cp-library

C++ Library for Competitive Programming

View the Project on GitHub emthrm/cp-library

:heavy_check_mark: 全方位木 DP
(include/emthrm/graph/tree/rerooting_dp.hpp)

仕様

名前 戻り値
template <typename CostType, typename CommutativeSemigroup, typename E, typename F, typename G>
std::vector<CommutativeSemigroup> rerooting_dp(const std::vector<std::vector<Edge<CostType>>>& graph, const std::vector<CommutativeSemigroup>& def, const E merge, const F f, const G g);
木 $\mathrm{graph}$ に対する全方位木 DP

参考文献

Submissions

https://judge.yosupo.jp/submission/138885

Depends on

Verified with

Code

#ifndef EMTHRM_GRAPH_TREE_REROOTING_DP_HPP_
#define EMTHRM_GRAPH_TREE_REROOTING_DP_HPP_

#include <algorithm>
#include <vector>

#include "emthrm/graph/edge.hpp"

namespace emthrm {

template <typename CostType, typename CommutativeSemigroup,
          typename E, typename F, typename G>
std::vector<CommutativeSemigroup> rerooting_dp(
    const std::vector<std::vector<Edge<CostType>>>& graph,
    const std::vector<CommutativeSemigroup>& def,
    const E merge, const F f, const G g) {
  const int n = graph.size();
  if (n == 0) [[unlikely]] return {};
  if (n == 1) [[unlikely]] return {g(def[0], 0)};
  std::vector<std::vector<CommutativeSemigroup>> children(n);
  const auto dfs1 = [&graph, &def, merge, f, g, &children](
      auto dfs1, const int par, const int ver) -> CommutativeSemigroup {
    children[ver].reserve(graph[ver].size());
    CommutativeSemigroup dp = def[ver];
    for (const Edge<CostType>& e : graph[ver]) {
      if (e.dst == par) {
        children[ver].emplace_back();
      } else {
        children[ver].emplace_back(f(dfs1(dfs1, ver, e.dst), e));
        dp = merge(dp, children[ver].back());
      }
    }
    return g(dp, ver);
  };
  dfs1(dfs1, -1, 0);
  std::vector<CommutativeSemigroup> dp = def;
  const auto dfs2 = [&graph, &def, merge, f, g, &children, &dp](
      auto dfs2, const int par, const int ver, const CommutativeSemigroup& m)
          -> void {
    const int c = graph[ver].size();
    for (int i = 0; i < c; ++i) {
      if (graph[ver][i].dst == par) {
        children[ver][i] = f(m, graph[ver][i]);
        break;
      }
    }
    std::vector<CommutativeSemigroup> left{def[ver]}, right;
    left.reserve(c);
    for (int i = 0; i < c - 1; ++i) {
      left.emplace_back(merge(left[i], children[ver][i]));
    }
    dp[ver] = g(merge(left.back(), children[ver].back()), ver);
    if (c >= 2) {
      right.reserve(c - 1);
      right.emplace_back(children[ver].back());
      for (int i = c - 2; i > 0; --i) {
        right.emplace_back(merge(children[ver][i], right[c - 2 - i]));
      }
      std::reverse(right.begin(), right.end());
    }
    for (int i = 0; i < c; ++i) {
      if (graph[ver][i].dst != par) {
        dfs2(dfs2, ver, graph[ver][i].dst,
             g(i + 1 == c ? left[i] : merge(left[i], right[i]), ver));
      }
    }
  };
  dfs2(dfs2, -1, 0, CommutativeSemigroup());
  return dp;
}

}  // namespace emthrm

#endif  // EMTHRM_GRAPH_TREE_REROOTING_DP_HPP_
#line 1 "include/emthrm/graph/tree/rerooting_dp.hpp"



#include <algorithm>
#include <vector>

#line 1 "include/emthrm/graph/edge.hpp"
/**
 * @title 辺
 */

#ifndef EMTHRM_GRAPH_EDGE_HPP_
#define EMTHRM_GRAPH_EDGE_HPP_

#include <compare>

namespace emthrm {

/**
 * Edge from `src` to `dst` with an associated `cost`.
 *
 * The defaulted three-way comparison compares members in declaration order:
 * `cost` first, then `src`, then `dst`.
 */
template <typename CostType>
struct Edge {
  CostType cost;  // declared first so that comparisons order by cost
  int src;
  int dst;

  explicit Edge(const int src, const int dst, const CostType cost = 0)
      : cost(cost), src(src), dst(dst) {}

  auto operator<=>(const Edge& other) const = default;
};

}  // namespace emthrm

#endif  // EMTHRM_GRAPH_EDGE_HPP_
#line 8 "include/emthrm/graph/tree/rerooting_dp.hpp"

namespace emthrm {

template <typename CostType, typename CommutativeSemigroup,
          typename E, typename F, typename G>
std::vector<CommutativeSemigroup> rerooting_dp(
    const std::vector<std::vector<Edge<CostType>>>& graph,
    const std::vector<CommutativeSemigroup>& def,
    const E merge, const F f, const G g) {
  const int n = graph.size();
  if (n == 0) [[unlikely]] return {};
  if (n == 1) [[unlikely]] return {g(def[0], 0)};
  std::vector<std::vector<CommutativeSemigroup>> children(n);
  const auto dfs1 = [&graph, &def, merge, f, g, &children](
      auto dfs1, const int par, const int ver) -> CommutativeSemigroup {
    children[ver].reserve(graph[ver].size());
    CommutativeSemigroup dp = def[ver];
    for (const Edge<CostType>& e : graph[ver]) {
      if (e.dst == par) {
        children[ver].emplace_back();
      } else {
        children[ver].emplace_back(f(dfs1(dfs1, ver, e.dst), e));
        dp = merge(dp, children[ver].back());
      }
    }
    return g(dp, ver);
  };
  dfs1(dfs1, -1, 0);
  std::vector<CommutativeSemigroup> dp = def;
  const auto dfs2 = [&graph, &def, merge, f, g, &children, &dp](
      auto dfs2, const int par, const int ver, const CommutativeSemigroup& m)
          -> void {
    const int c = graph[ver].size();
    for (int i = 0; i < c; ++i) {
      if (graph[ver][i].dst == par) {
        children[ver][i] = f(m, graph[ver][i]);
        break;
      }
    }
    std::vector<CommutativeSemigroup> left{def[ver]}, right;
    left.reserve(c);
    for (int i = 0; i < c - 1; ++i) {
      left.emplace_back(merge(left[i], children[ver][i]));
    }
    dp[ver] = g(merge(left.back(), children[ver].back()), ver);
    if (c >= 2) {
      right.reserve(c - 1);
      right.emplace_back(children[ver].back());
      for (int i = c - 2; i > 0; --i) {
        right.emplace_back(merge(children[ver][i], right[c - 2 - i]));
      }
      std::reverse(right.begin(), right.end());
    }
    for (int i = 0; i < c; ++i) {
      if (graph[ver][i].dst != par) {
        dfs2(dfs2, ver, graph[ver][i].dst,
             g(i + 1 == c ? left[i] : merge(left[i], right[i]), ver));
      }
    }
  };
  dfs2(dfs2, -1, 0, CommutativeSemigroup());
  return dp;
}

}  // namespace emthrm
Back to top page