C++ Library for Competitive Programming
#include "emthrm/data_structure/binary_trie.hpp"
非負整数を管理するトライ木である。
$O(B)$
template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie;
B
:ビット幅T
:要素型名前 | 説明 |
---|---|
std::shared_ptr<Node> root; |
根のポインタ |
名前 | 効果・戻り値 | 備考 |
---|---|---|
BinaryTrie(); |
集合 $S \mathrel{:=} \emptyset$ を表すオブジェクトを構築する | |
void clear(); |
$S \gets \emptyset$ | |
bool empty() const; |
$S = \emptyset$ を満たすか。 | |
int size() const; |
$\lvert S \rvert$ | |
void erase(const T& x); |
$S \gets S \setminus \lbrace x \rbrace$ | $x \notin S$ を満たすときは何もしない。 |
std::shared_ptr<Node> find(const T& x) const; |
$x$ を指すノード。ただし $x \notin S$ を満たすときは nullptr を返す。 |
|
std::pair<std::shared_ptr<Node>, T> operator[](const int n) const; |
$n$ 番目 (0-based) の要素を指すノードと値 | |
std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const; |
$\lbrace s \oplus x \mid s \in S \rbrace$ における $n$ 番目 (0-based) の要素を表すノードと値 | |
std::shared_ptr<Node> insert(const T& x); |
$S \gets S \cup \lbrace x \rbrace$ の操作後、$x$ を表すノードを返す。 | |
int less_than(const T& x) const; |
$\lvert \lbrace s \in S \mid s < x \rbrace \rvert$ | |
int count(const T& l, const T& r) const; |
$\lvert \lbrace s \in S \mid l \leq x < r \rbrace \rvert$ | |
int count(const T& x) const; |
$\lvert \lbrace s \in S \mid s = x \rbrace \rvert$ | |
std::pair<std::shared_ptr<Node>, T> lower_bound(const T& x) const; |
$x$ より小さくない最初の要素を表すノードと値。ただしそのような要素が存在しないときは (nullptr , $-1$) を返す。 |
|
std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(const T& x) const; |
$x$ より小さくない最初の要素を表すノードと値。ただしそのような要素が存在しないときは (nullptr, std::nullopt) を返す。 |
|
std::pair<std::shared_ptr<Node>, T> upper_bound(const T& x) const; |
$x$ より大きい最初の要素を表すノードと値。ただし存在しないときは (nullptr , $-1$) を返す。 |
|
std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(const T& x) const; |
$x$ より大きい最初の要素を表すノードと値。ただし存在しないときは (nullptr, std::nullopt) を返す。 |
|
std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const; |
$\mathrm{argmax} \lbrace s \oplus x \mid s \in S \rbrace$ | |
std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const; |
$\mathrm{argmin} \lbrace s \oplus x \mid s \in S \rbrace$ |
名前 | 説明 |
---|---|
Node |
ノードを表す構造体 |
struct Node;
名前 | 説明 |
---|---|
std::array<std::shared_ptr<Node>, 2> nxt |
子のポインタ |
int child |
部分木に属する要素の個数 |
名前 | 効果 |
---|---|
Node(); |
コンストラクタ |
https://judge.yosupo.jp/submission/33239
#ifndef EMTHRM_DATA_STRUCTURE_BINARY_TRIE_HPP_
#define EMTHRM_DATA_STRUCTURE_BINARY_TRIE_HPP_
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
namespace emthrm {
template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie {
struct Node {
std::array<std::shared_ptr<Node>, 2> nxt;
int child;
Node() : nxt{nullptr, nullptr}, child(0) {}
};
std::shared_ptr<Node> root;
BinaryTrie() : root(nullptr) {}
void clear() { root.reset(); }
bool empty() const { return !root; }
int size() const { return root ? root->child : 0; }
void erase(const T& x) {
if (root) [[likely]] erase(&root, x, B - 1);
}
std::shared_ptr<Node> find(const T& x) const {
if (!root) [[unlikely]] return nullptr;
std::shared_ptr<Node> node = root;
for (int b = B - 1; b >= 0; --b) {
const bool digit = x >> b & 1;
if (!node->nxt[digit]) return nullptr;
node = node->nxt[digit];
}
return node;
}
std::pair<std::shared_ptr<Node>, T> operator[](const int n) const {
return find_nth(n, 0);
}
std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const {
assert(0 <= n && n < size());
std::shared_ptr<Node> node = root;
T res = 0;
for (int b = B - 1; b >= 0; --b) {
bool digit = x >> b & 1;
const int l_child = (node->nxt[digit] ? node->nxt[digit]->child : 0);
if (n >= l_child) {
n -= l_child;
digit = !digit;
}
node = node->nxt[digit];
if (digit) res |= static_cast<T>(1) << b;
}
return {node, res};
}
std::shared_ptr<Node> insert(const T& x) {
if (!root) [[unlikely]] root = std::make_shared<Node>();
std::shared_ptr<Node> node = root;
++node->child;
for (int b = B - 1; b >= 0; --b) {
const bool digit = x >> b & 1;
if (!node->nxt[digit]) node->nxt[digit] = std::make_shared<Node>();
node = node->nxt[digit];
++node->child;
}
return node;
}
int less_than(const T& x) const {
int res = 0;
std::shared_ptr<Node> node = root;
for (int b = B - 1; node && b >= 0; --b) {
const bool digit = x >> b & 1;
if (digit && node->nxt[0]) res += node->nxt[0]->child;
node = node->nxt[digit];
}
return res;
}
int count(const T& l, const T& r) const {
return less_than(r) - less_than(l);
}
int count(const T& x) const {
const std::shared_ptr<Node> ptr = find(x);
return ptr ? ptr->child : 0;
}
std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(
const T& x) const {
const int lt = less_than(x);
if (lt == size()) return std::make_pair(nullptr, std::nullopt);
const auto [node, value] = find_nth(lt, 0);
return std::make_pair(node, std::make_optional(value));
}
std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(
const T& x) const {
return lower_bound(x + 1);
}
std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const {
return min_element(~x);
}
std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const {
assert(root);
std::shared_ptr<Node> node = root;
T res = 0;
for (int b = B - 1; b >= 0; --b) {
bool digit = x >> b & 1;
if (!node->nxt[digit]) digit = !digit;
node = node->nxt[digit];
if (digit) res |= static_cast<T>(1) << b;
}
return {node, res};
}
private:
void erase(std::shared_ptr<Node>* node, const T& x, int b) {
if (b == -1) {
if (--(*node)->child == 0) node->reset();
return;
}
const bool digit = x >> b & 1;
if (!(*node)->nxt[digit]) return;
(*node)->child -= (*node)->nxt[digit]->child;
erase(&(*node)->nxt[digit], x, b - 1);
if ((*node)->nxt[digit]) {
(*node)->child += (*node)->nxt[digit]->child;
} else if ((*node)->child == 0) {
node->reset();
}
}
};
} // namespace emthrm
#endif // EMTHRM_DATA_STRUCTURE_BINARY_TRIE_HPP_
#line 1 "include/emthrm/data_structure/binary_trie.hpp"
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
namespace emthrm {
template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie {
struct Node {
std::array<std::shared_ptr<Node>, 2> nxt;
int child;
Node() : nxt{nullptr, nullptr}, child(0) {}
};
std::shared_ptr<Node> root;
BinaryTrie() : root(nullptr) {}
void clear() { root.reset(); }
bool empty() const { return !root; }
int size() const { return root ? root->child : 0; }
void erase(const T& x) {
if (root) [[likely]] erase(&root, x, B - 1);
}
std::shared_ptr<Node> find(const T& x) const {
if (!root) [[unlikely]] return nullptr;
std::shared_ptr<Node> node = root;
for (int b = B - 1; b >= 0; --b) {
const bool digit = x >> b & 1;
if (!node->nxt[digit]) return nullptr;
node = node->nxt[digit];
}
return node;
}
std::pair<std::shared_ptr<Node>, T> operator[](const int n) const {
return find_nth(n, 0);
}
std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const {
assert(0 <= n && n < size());
std::shared_ptr<Node> node = root;
T res = 0;
for (int b = B - 1; b >= 0; --b) {
bool digit = x >> b & 1;
const int l_child = (node->nxt[digit] ? node->nxt[digit]->child : 0);
if (n >= l_child) {
n -= l_child;
digit = !digit;
}
node = node->nxt[digit];
if (digit) res |= static_cast<T>(1) << b;
}
return {node, res};
}
std::shared_ptr<Node> insert(const T& x) {
if (!root) [[unlikely]] root = std::make_shared<Node>();
std::shared_ptr<Node> node = root;
++node->child;
for (int b = B - 1; b >= 0; --b) {
const bool digit = x >> b & 1;
if (!node->nxt[digit]) node->nxt[digit] = std::make_shared<Node>();
node = node->nxt[digit];
++node->child;
}
return node;
}
int less_than(const T& x) const {
int res = 0;
std::shared_ptr<Node> node = root;
for (int b = B - 1; node && b >= 0; --b) {
const bool digit = x >> b & 1;
if (digit && node->nxt[0]) res += node->nxt[0]->child;
node = node->nxt[digit];
}
return res;
}
int count(const T& l, const T& r) const {
return less_than(r) - less_than(l);
}
int count(const T& x) const {
const std::shared_ptr<Node> ptr = find(x);
return ptr ? ptr->child : 0;
}
std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(
const T& x) const {
const int lt = less_than(x);
if (lt == size()) return std::make_pair(nullptr, std::nullopt);
const auto [node, value] = find_nth(lt, 0);
return std::make_pair(node, std::make_optional(value));
}
std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(
const T& x) const {
return lower_bound(x + 1);
}
std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const {
return min_element(~x);
}
std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const {
assert(root);
std::shared_ptr<Node> node = root;
T res = 0;
for (int b = B - 1; b >= 0; --b) {
bool digit = x >> b & 1;
if (!node->nxt[digit]) digit = !digit;
node = node->nxt[digit];
if (digit) res |= static_cast<T>(1) << b;
}
return {node, res};
}
private:
void erase(std::shared_ptr<Node>* node, const T& x, int b) {
if (b == -1) {
if (--(*node)->child == 0) node->reset();
return;
}
const bool digit = x >> b & 1;
if (!(*node)->nxt[digit]) return;
(*node)->child -= (*node)->nxt[digit]->child;
erase(&(*node)->nxt[digit], x, b - 1);
if ((*node)->nxt[digit]) {
(*node)->child += (*node)->nxt[digit]->child;
} else if ((*node)->child == 0) {
node->reset();
}
}
};
} // namespace emthrm