cp-library

C++ Library for Competitive Programming

View the Project on GitHub emthrm/cp-library

:heavy_check_mark: binary trie
(include/emthrm/data_structure/binary_trie.hpp)

非負整数を管理するトライ木である。

時間計算量

$O(B)$

仕様

template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie;

メンバ変数

名前 説明
std::shared_ptr<Node> root; 根のポインタ

メンバ関数

名前 効果・戻り値 備考
BinaryTrie(); 集合 $S \mathrel{:=} \emptyset$ を表すオブジェクトを構築する  
void clear(); $S \gets \emptyset$  
bool empty() const; $S = \emptyset$ を満たすか。  
int size() const; $\lvert S \rvert$  
void erase(const T& x); $S \gets S \setminus \lbrace x \rbrace$ $x \notin S$ を満たすときは何もしない。
std::shared_ptr<Node> find(const T& x) const; $x$ を指すノード。ただし $x \notin S$ を満たすときは nullptr を返す。  
std::pair<std::shared_ptr<Node>, T> operator[](const int n) const; $n$ 番目 (0-based) の要素を指すノードと値  
std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const; $\lbrace s \oplus x \mid s \in S \rbrace$ における $n$ 番目 (0-based) の要素を表すノードと値  
std::shared_ptr<Node> insert(const T& x); $S \gets S \cup \lbrace x \rbrace$ の操作後、$x$ を表すノードを返す。  
int less_than(const T& x) const; $\lvert \lbrace s \in S \mid s < x \rbrace \rvert$  
int count(const T& l, const T& r) const; $\lvert \lbrace s \in S \mid l \leq x < r \rbrace \rvert$  
int count(const T& x) const; $\lvert \lbrace s \in S \mid s = x \rbrace \rvert$  
std::pair<std::shared_ptr<Node>, T> lower_bound(const T& x) const; $x$ より小さくない最初の要素を表すノードと値。ただしそのような要素が存在しないときは (nullptr, $-1$) を返す。  
std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(const T& x) const; $x$ より小さくない最初の要素を表すノードと値。ただしそのような要素が存在しないときは (nullptr, std::nullopt) を返す。  
std::pair<std::shared_ptr<Node>, T> upper_bound(const T& x) const; $x$ より大きい最初の要素を表すノードと値。ただし存在しないときは (nullptr, $-1$) を返す。  
std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(const T& x) const; $x$ より大きい最初の要素を表すノードと値。ただし存在しないときは (nullptr, std::nullopt) を返す。  
std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const; $\mathrm{argmax} \lbrace s \oplus x \mid s \in S \rbrace$  
std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const; $\mathrm{argmin} \lbrace s \oplus x \mid s \in S \rbrace$  

メンバ型

名前 説明
Node ノードを表す構造体
struct Node;

メンバ変数

名前 説明
std::array<std::shared_ptr<Node>, 2> nxt 子のポインタ
int child 部分木に属する要素の個数

メンバ関数

名前 効果
Node(); コンストラクタ

参考文献

TODO

Submissons

https://judge.yosupo.jp/submission/33239

Verified with

Code

#ifndef EMTHRM_DATA_STRUCTURE_BINARY_TRIE_HPP_
#define EMTHRM_DATA_STRUCTURE_BINARY_TRIE_HPP_

#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

namespace emthrm {

template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie {
  struct Node {
    std::array<std::shared_ptr<Node>, 2> nxt;
    int child;

    Node() : nxt{nullptr, nullptr}, child(0) {}
  };

  std::shared_ptr<Node> root;

  BinaryTrie() : root(nullptr) {}

  void clear() { root.reset(); }

  bool empty() const { return !root; }

  int size() const { return root ? root->child : 0; }

  void erase(const T& x) {
    if (root) [[likely]] erase(&root, x, B - 1);
  }

  std::shared_ptr<Node> find(const T& x) const {
    if (!root) [[unlikely]] return nullptr;
    std::shared_ptr<Node> node = root;
    for (int b = B - 1; b >= 0; --b) {
      const bool digit = x >> b & 1;
      if (!node->nxt[digit]) return nullptr;
      node = node->nxt[digit];
    }
    return node;
  }

  std::pair<std::shared_ptr<Node>, T> operator[](const int n) const {
    return find_nth(n, 0);
  }

  std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const {
    assert(0 <= n && n < size());
    std::shared_ptr<Node> node = root;
    T res = 0;
    for (int b = B - 1; b >= 0; --b) {
      bool digit = x >> b & 1;
      const int l_child = (node->nxt[digit] ? node->nxt[digit]->child : 0);
      if (n >= l_child) {
        n -= l_child;
        digit = !digit;
      }
      node = node->nxt[digit];
      if (digit) res |= static_cast<T>(1) << b;
    }
    return {node, res};
  }

  std::shared_ptr<Node> insert(const T& x) {
    if (!root) [[unlikely]] root = std::make_shared<Node>();
    std::shared_ptr<Node> node = root;
    ++node->child;
    for (int b = B - 1; b >= 0; --b) {
      const bool digit = x >> b & 1;
      if (!node->nxt[digit]) node->nxt[digit] = std::make_shared<Node>();
      node = node->nxt[digit];
      ++node->child;
    }
    return node;
  }

  int less_than(const T& x) const {
    int res = 0;
    std::shared_ptr<Node> node = root;
    for (int b = B - 1; node && b >= 0; --b) {
      const bool digit = x >> b & 1;
      if (digit && node->nxt[0]) res += node->nxt[0]->child;
      node = node->nxt[digit];
    }
    return res;
  }

  int count(const T& l, const T& r) const {
    return less_than(r) - less_than(l);
  }

  int count(const T& x) const {
    const std::shared_ptr<Node> ptr = find(x);
    return ptr ? ptr->child : 0;
  }

  std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(
      const T& x) const {
    const int lt = less_than(x);
    if (lt == size()) return std::make_pair(nullptr, std::nullopt);
    const auto [node, value] = find_nth(lt, 0);
    return std::make_pair(node, std::make_optional(value));
  }

  std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(
      const T& x) const {
    return lower_bound(x + 1);
  }

  std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const {
    return min_element(~x);
  }

  std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const {
    assert(root);
    std::shared_ptr<Node> node = root;
    T res = 0;
    for (int b = B - 1; b >= 0; --b) {
      bool digit = x >> b & 1;
      if (!node->nxt[digit]) digit = !digit;
      node = node->nxt[digit];
      if (digit) res |= static_cast<T>(1) << b;
    }
    return {node, res};
  }

 private:
  void erase(std::shared_ptr<Node>* node, const T& x, int b) {
    if (b == -1) {
      if (--(*node)->child == 0) node->reset();
      return;
    }
    const bool digit = x >> b & 1;
    if (!(*node)->nxt[digit]) return;
    (*node)->child -= (*node)->nxt[digit]->child;
    erase(&(*node)->nxt[digit], x, b - 1);
    if ((*node)->nxt[digit]) {
      (*node)->child += (*node)->nxt[digit]->child;
    } else if ((*node)->child == 0) {
      node->reset();
    }
  }
};

}  // namespace emthrm

#endif  // EMTHRM_DATA_STRUCTURE_BINARY_TRIE_HPP_
#line 1 "include/emthrm/data_structure/binary_trie.hpp"



#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

namespace emthrm {

template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie {
  struct Node {
    std::array<std::shared_ptr<Node>, 2> nxt;
    int child;

    Node() : nxt{nullptr, nullptr}, child(0) {}
  };

  std::shared_ptr<Node> root;

  BinaryTrie() : root(nullptr) {}

  void clear() { root.reset(); }

  bool empty() const { return !root; }

  int size() const { return root ? root->child : 0; }

  void erase(const T& x) {
    if (root) [[likely]] erase(&root, x, B - 1);
  }

  std::shared_ptr<Node> find(const T& x) const {
    if (!root) [[unlikely]] return nullptr;
    std::shared_ptr<Node> node = root;
    for (int b = B - 1; b >= 0; --b) {
      const bool digit = x >> b & 1;
      if (!node->nxt[digit]) return nullptr;
      node = node->nxt[digit];
    }
    return node;
  }

  std::pair<std::shared_ptr<Node>, T> operator[](const int n) const {
    return find_nth(n, 0);
  }

  std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const {
    assert(0 <= n && n < size());
    std::shared_ptr<Node> node = root;
    T res = 0;
    for (int b = B - 1; b >= 0; --b) {
      bool digit = x >> b & 1;
      const int l_child = (node->nxt[digit] ? node->nxt[digit]->child : 0);
      if (n >= l_child) {
        n -= l_child;
        digit = !digit;
      }
      node = node->nxt[digit];
      if (digit) res |= static_cast<T>(1) << b;
    }
    return {node, res};
  }

  std::shared_ptr<Node> insert(const T& x) {
    if (!root) [[unlikely]] root = std::make_shared<Node>();
    std::shared_ptr<Node> node = root;
    ++node->child;
    for (int b = B - 1; b >= 0; --b) {
      const bool digit = x >> b & 1;
      if (!node->nxt[digit]) node->nxt[digit] = std::make_shared<Node>();
      node = node->nxt[digit];
      ++node->child;
    }
    return node;
  }

  int less_than(const T& x) const {
    int res = 0;
    std::shared_ptr<Node> node = root;
    for (int b = B - 1; node && b >= 0; --b) {
      const bool digit = x >> b & 1;
      if (digit && node->nxt[0]) res += node->nxt[0]->child;
      node = node->nxt[digit];
    }
    return res;
  }

  int count(const T& l, const T& r) const {
    return less_than(r) - less_than(l);
  }

  int count(const T& x) const {
    const std::shared_ptr<Node> ptr = find(x);
    return ptr ? ptr->child : 0;
  }

  std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(
      const T& x) const {
    const int lt = less_than(x);
    if (lt == size()) return std::make_pair(nullptr, std::nullopt);
    const auto [node, value] = find_nth(lt, 0);
    return std::make_pair(node, std::make_optional(value));
  }

  std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(
      const T& x) const {
    return lower_bound(x + 1);
  }

  std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const {
    return min_element(~x);
  }

  std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const {
    assert(root);
    std::shared_ptr<Node> node = root;
    T res = 0;
    for (int b = B - 1; b >= 0; --b) {
      bool digit = x >> b & 1;
      if (!node->nxt[digit]) digit = !digit;
      node = node->nxt[digit];
      if (digit) res |= static_cast<T>(1) << b;
    }
    return {node, res};
  }

 private:
  void erase(std::shared_ptr<Node>* node, const T& x, int b) {
    if (b == -1) {
      if (--(*node)->child == 0) node->reset();
      return;
    }
    const bool digit = x >> b & 1;
    if (!(*node)->nxt[digit]) return;
    (*node)->child -= (*node)->nxt[digit]->child;
    erase(&(*node)->nxt[digit], x, b - 1);
    if ((*node)->nxt[digit]) {
      (*node)->child += (*node)->nxt[digit]->child;
    } else if ((*node)->child == 0) {
      node->reset();
    }
  }
};

}  // namespace emthrm
Back to top page