cp-library

C++ Library for Competitive Programming

View the Project on GitHub emthrm/cp-library

✔ データ構造/binary trie (data structure: binary trie)
(test/data_structure/binary_trie.test.cpp)

Depends on: include/emthrm/data_structure/binary_trie.hpp

Code

/*
 * @title データ構造/binary trie
 *
 * verification-helper: PROBLEM https://judge.yosupo.jp/problem/set_xor_min
 */

#include <iostream>

#include "emthrm/data_structure/binary_trie.hpp"

// Answers Library Checker "set_xor_min" queries on a set of integers:
//   0 x : insert x if it is not already present (set semantics),
//   1 x : erase x (no-op if absent),
//   2 x : print min over stored v of (v ^ x).
int main() {
  // The input may contain many queries and nothing here mixes iostreams with
  // C stdio, so use unsynchronized, untied streams for throughput.
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  // B = 30 presumes every x is below 2^30 (set_xor_min constraints).
  constexpr int B = 30;
  emthrm::BinaryTrie<B> binary_trie;
  int q;
  std::cin >> q;
  while (q--) {
    int type, x;
    std::cin >> type >> x;
    if (type == 0) {
      // The trie is a multiset; guard to keep at most one copy of x.
      if (!binary_trie.find(x)) binary_trie.insert(x);
    } else if (type == 1) {
      binary_trie.erase(x);
    } else if (type == 2) {
      // min_element(x) returns the stored value v minimizing v ^ x.
      std::cout << (binary_trie.min_element(x).second ^ x) << '\n';
    }
  }
  return 0;
}
#line 1 "test/data_structure/binary_trie.test.cpp"
/*
 * @title データ構造/binary trie
 *
 * verification-helper: PROBLEM https://judge.yosupo.jp/problem/set_xor_min
 */

#include <iostream>

#line 1 "include/emthrm/data_structure/binary_trie.hpp"



#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

namespace emthrm {

// Binary trie storing a multiset of integer keys of type T, branching on the
// low B bits of each key (bit B-1 first, bit 0 last).
//
// Only bits B-1, ..., 0 of a key are ever inspected, so keys that agree on
// those bits are treated as the same key (e.g. with B = 4, 1 and 17 collide).
// Duplicates are kept: the leaf reached by a key stores its multiplicity in
// `child`.
//
// Nodes are owned through std::shared_ptr, so a copy of a BinaryTrie shares
// its nodes with the original; inserting/erasing through one copy changes
// the counts seen by the other.
//
// Template parameters:
//   B: number of key bits the trie branches on.
//   T: key type; presumably an unsigned integer (x >> b is applied to it).
template <int B = 32, typename T = std::uint32_t>
struct BinaryTrie {
  struct Node {
    // Children for the next bit being 0 (nxt[0]) or 1 (nxt[1]).
    std::array<std::shared_ptr<Node>, 2> nxt;
    // Number of keys (with multiplicity) stored in this subtree.
    int child;

    Node() : nxt{nullptr, nullptr}, child(0) {}
  };

  // Null when the trie is empty. insert() increments the count of every node
  // it visits and erase() releases nodes whose count reaches zero, so every
  // non-null node holds at least one key.
  std::shared_ptr<Node> root;

  BinaryTrie() : root(nullptr) {}

  // Removes every key.
  void clear() { root.reset(); }

  bool empty() const { return !root; }

  // Number of stored keys, counting duplicates.
  int size() const { return root ? root->child : 0; }

  // Removes one occurrence of x; does nothing if x is absent.
  void erase(const T& x) {
    if (root) [[likely]] erase(&root, x, B - 1);
  }

  // Returns the leaf for x (its `child` is x's multiplicity), or nullptr if
  // x is absent.
  std::shared_ptr<Node> find(const T& x) const {
    if (!root) [[unlikely]] return nullptr;
    // Walk through references to the owning slots instead of copying a
    // shared_ptr (an atomic refcount update) at every level.
    const std::shared_ptr<Node>* slot = &root;
    for (int b = B - 1; b >= 0; --b) {
      const bool digit = x >> b & 1;
      const std::shared_ptr<Node>& next = (*slot)->nxt[digit];
      if (!next) return nullptr;
      slot = &next;
    }
    return *slot;
  }

  // Returns the n-th smallest key (0-indexed, duplicates counted) and its
  // leaf. Requires 0 <= n < size().
  std::pair<std::shared_ptr<Node>, T> operator[](const int n) const {
    return find_nth(n, 0);
  }

  // Returns the n-th (0-indexed, duplicates counted) stored key in ascending
  // order of key ^ x, together with its leaf. The returned value is the key
  // itself, not key ^ x. Requires 0 <= n < size().
  std::pair<std::shared_ptr<Node>, T> find_nth(int n, const T& x) const {
    assert(0 <= n && n < size());
    const std::shared_ptr<Node>* slot = &root;
    T res = 0;
    for (int b = B - 1; b >= 0; --b) {
      bool digit = x >> b & 1;
      // The child matching x's bit yields the smaller key ^ x, so it comes
      // first; skip over it when n lies beyond its count.
      const std::shared_ptr<Node>& preferred = (*slot)->nxt[digit];
      const int l_child = (preferred ? preferred->child : 0);
      if (n >= l_child) {
        n -= l_child;
        digit = !digit;
      }
      slot = &(*slot)->nxt[digit];
      if (digit) res |= static_cast<T>(1) << b;
    }
    return {*slot, res};
  }

  // Adds one occurrence of x and returns its leaf.
  std::shared_ptr<Node> insert(const T& x) {
    if (!root) [[unlikely]] root = std::make_shared<Node>();
    std::shared_ptr<Node>* slot = &root;
    ++(*slot)->child;
    for (int b = B - 1; b >= 0; --b) {
      const bool digit = x >> b & 1;
      std::shared_ptr<Node>& next = (*slot)->nxt[digit];
      if (!next) next = std::make_shared<Node>();
      slot = &next;
      ++next->child;
    }
    return *slot;
  }

  // Number of stored keys (duplicates counted) strictly less than x,
  // comparing low B bits only.
  int less_than(const T& x) const {
    int res = 0;
    const Node* node = root.get();
    for (int b = B - 1; node && b >= 0; --b) {
      const bool digit = x >> b & 1;
      // Where x has a 1, every key taking the 0 branch here is smaller.
      if (digit && node->nxt[0]) res += node->nxt[0]->child;
      node = node->nxt[digit].get();
    }
    return res;
  }

  // Number of stored keys k (duplicates counted) with l <= k < r.
  int count(const T& l, const T& r) const {
    return less_than(r) - less_than(l);
  }

  // Multiplicity of x.
  int count(const T& x) const {
    const std::shared_ptr<Node> ptr = find(x);
    return ptr ? ptr->child : 0;
  }

  // Smallest stored key >= x and its leaf, or {nullptr, nullopt} if none.
  std::pair<std::shared_ptr<Node>, std::optional<T>> lower_bound(
      const T& x) const {
    const int lt = less_than(x);
    if (lt == size()) return std::make_pair(nullptr, std::nullopt);
    auto [node, value] = find_nth(lt, 0);
    return {std::move(node), std::make_optional(value)};
  }

  // Smallest stored key > x and its leaf, or {nullptr, nullopt} if none.
  std::pair<std::shared_ptr<Node>, std::optional<T>> upper_bound(
      const T& x) const {
    // Rank of the first key strictly greater than x. This is deliberately
    // not lower_bound(x + 1): when x is the largest B-bit key, x + 1 either
    // overflows T or equals 2^B, whose low B bits are all zero, and both
    // would wrongly report the smallest key instead of "none".
    const int le = less_than(x) + count(x);
    if (le == size()) return std::make_pair(nullptr, std::nullopt);
    auto [node, value] = find_nth(le, 0);
    return {std::move(node), std::make_optional(value)};
  }

  // Stored key maximizing key ^ x (over the low B bits) and its leaf.
  // Requires a non-empty trie.
  std::pair<std::shared_ptr<Node>, T> max_element(const T& x = 0) const {
    return min_element(~x);
  }

  // Stored key minimizing key ^ x (over the low B bits) and its leaf.
  // Requires a non-empty trie.
  std::pair<std::shared_ptr<Node>, T> min_element(const T& x = 0) const {
    assert(root);
    const std::shared_ptr<Node>* slot = &root;
    T res = 0;
    for (int b = B - 1; b >= 0; --b) {
      // Follow x's bit when possible (xor bit 0); otherwise the other child,
      // which must exist because non-null nodes are never empty.
      bool digit = x >> b & 1;
      if (!(*slot)->nxt[digit]) digit = !digit;
      slot = &(*slot)->nxt[digit];
      if (digit) res |= static_cast<T>(1) << b;
    }
    return {*slot, res};
  }

 private:
  // Removes one occurrence of x from the subtree owned by *node, whose
  // children branch on bit b (b == -1 means *node is x's leaf). Subtree
  // counts stay consistent, and a node whose count reaches zero is released
  // by resetting its owning pointer.
  void erase(std::shared_ptr<Node>* node, const T& x, int b) {
    if (b == -1) {
      if (--(*node)->child == 0) node->reset();
      return;
    }
    const bool digit = x >> b & 1;
    if (!(*node)->nxt[digit]) return;  // x is not stored.
    // Take the child's contribution out, recurse, then add back whatever
    // remains, so this node's count drops by exactly what was removed below.
    (*node)->child -= (*node)->nxt[digit]->child;
    erase(&(*node)->nxt[digit], x, b - 1);
    if ((*node)->nxt[digit]) {
      (*node)->child += (*node)->nxt[digit]->child;
    } else if ((*node)->child == 0) {
      node->reset();
    }
  }
};

}  // namespace emthrm


#line 10 "test/data_structure/binary_trie.test.cpp"

// Answers Library Checker "set_xor_min" queries on a set of integers:
//   0 x : insert x if it is not already present (set semantics),
//   1 x : erase x (no-op if absent),
//   2 x : print min over stored v of (v ^ x).
int main() {
  // The input may contain many queries and nothing here mixes iostreams with
  // C stdio, so use unsynchronized, untied streams for throughput.
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  // B = 30 presumes every x is below 2^30 (set_xor_min constraints).
  constexpr int B = 30;
  emthrm::BinaryTrie<B> binary_trie;
  int q;
  std::cin >> q;
  while (q--) {
    int type, x;
    std::cin >> type >> x;
    if (type == 0) {
      // The trie is a multiset; guard to keep at most one copy of x.
      if (!binary_trie.find(x)) binary_trie.insert(x);
    } else if (type == 1) {
      binary_trie.erase(x);
    } else if (type == 2) {
      // min_element(x) returns the stored value v minimizing v ^ x.
      std::cout << (binary_trie.min_element(x).second ^ x) << '\n';
    }
  }
  return 0;
}
Back to top page