cp-library

C++ Library for Competitive Programming

View the Project on GitHub emthrm/cp-library

✔ グラフ/木/全方位木 DP (Graph / Tree / Rerooting DP)
(test/graph/tree/rerooting_dp.test.cpp)

Depends on: include/emthrm/graph/edge.hpp, include/emthrm/graph/tree/rerooting_dp.hpp

Code

/*
 * @title グラフ/木/全方位木 DP
 *
 * verification-helper: PROBLEM https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=GRL_5_A
 */

#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

#include "emthrm/graph/edge.hpp"
#include "emthrm/graph/tree/rerooting_dp.hpp"

int main() {
  int n;
  std::cin >> n;
  std::vector<std::vector<emthrm::Edge<int>>> graph(n);
  for (int i = 0; i < n - 1; ++i) {
    int s, t, w;
    std::cin >> s >> t >> w;
    graph[s].emplace_back(s, t, w);
    graph[t].emplace_back(t, s, w);
  }
  const std::vector<std::pair<int, int>> ans = emthrm::rerooting_dp(
      graph, std::vector<std::pair<int, int>>(n, {0, 0}),
      [](const std::pair<int, int>& x, const std::pair<int, int>& y)
          -> std::pair<int, int> {
        int tmp[]{x.first, x.second, y.first, y.second};
        std::sort(tmp, tmp + 4, std::greater<int>());
        return {tmp[0], tmp[1]};
      },
      [](const std::pair<int, int>& x, const emthrm::Edge<int>& e)
          -> std::pair<int, int> { return {x.first + e.cost, 0}; },
      [](const std::pair<int, int>& x, const int) -> std::pair<int, int> {
        return x;
      });
  int diameter = 0;
  for (int i = 0; i < n; ++i) {
    diameter = std::max(diameter, ans[i].first + ans[i].second);
  }
  std::cout << diameter << '\n';
  return 0;
}
#line 1 "test/graph/tree/rerooting_dp.test.cpp"
/*
 * @title グラフ/木/全方位木 DP
 *
 * verification-helper: PROBLEM https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=GRL_5_A
 */

#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

#line 1 "include/emthrm/graph/edge.hpp"
/**
 * @title 辺
 */

#ifndef EMTHRM_GRAPH_EDGE_HPP_
#define EMTHRM_GRAPH_EDGE_HPP_

#include <compare>

namespace emthrm {

/**
 * An edge leading from vertex `src` to vertex `dst` with weight `cost`.
 *
 * `cost` is declared first, and a defaulted three-way comparison compares
 * members in declaration order. Edges therefore order by cost, then by
 * `src`, then by `dst`. The defaulted `<=>` also implies a defaulted `==`.
 */
template <typename CostType>
struct Edge {
  CostType cost;
  int src;
  int dst;

  explicit Edge(const int src, const int dst, const CostType cost = 0)
      : cost(cost), src(src), dst(dst) {}

  auto operator<=>(const Edge&) const = default;
};

}  // namespace emthrm

#endif  // EMTHRM_GRAPH_EDGE_HPP_
#line 1 "include/emthrm/graph/tree/rerooting_dp.hpp"



#line 6 "include/emthrm/graph/tree/rerooting_dp.hpp"

#line 1 "include/emthrm/graph/edge.hpp"
/**
 * @title 辺
 */

#ifndef EMTHRM_GRAPH_EDGE_HPP_
#define EMTHRM_GRAPH_EDGE_HPP_

#include <compare>

namespace emthrm {

/**
 * An edge leading from vertex `src` to vertex `dst` with weight `cost`.
 *
 * `cost` is declared first, and a defaulted three-way comparison compares
 * members in declaration order. Edges therefore order by cost, then by
 * `src`, then by `dst`. The defaulted `<=>` also implies a defaulted `==`.
 */
template <typename CostType>
struct Edge {
  CostType cost;
  int src;
  int dst;

  explicit Edge(const int src, const int dst, const CostType cost = 0)
      : cost(cost), src(src), dst(dst) {}

  auto operator<=>(const Edge&) const = default;
};

}  // namespace emthrm

#endif  // EMTHRM_GRAPH_EDGE_HPP_
#line 8 "include/emthrm/graph/tree/rerooting_dp.hpp"

namespace emthrm {

template <typename CostType, typename CommutativeSemigroup,
          typename E, typename F, typename G>
std::vector<CommutativeSemigroup> rerooting_dp(
    const std::vector<std::vector<Edge<CostType>>>& graph,
    const std::vector<CommutativeSemigroup>& def,
    const E merge, const F f, const G g) {
  const int n = graph.size();
  if (n == 0) [[unlikely]] return {};
  if (n == 1) [[unlikely]] return {g(def[0], 0)};
  std::vector<std::vector<CommutativeSemigroup>> children(n);
  const auto dfs1 = [&graph, &def, merge, f, g, &children](
      auto dfs1, const int par, const int ver) -> CommutativeSemigroup {
    children[ver].reserve(graph[ver].size());
    CommutativeSemigroup dp = def[ver];
    for (const Edge<CostType>& e : graph[ver]) {
      if (e.dst == par) {
        children[ver].emplace_back();
      } else {
        children[ver].emplace_back(f(dfs1(dfs1, ver, e.dst), e));
        dp = merge(dp, children[ver].back());
      }
    }
    return g(dp, ver);
  };
  dfs1(dfs1, -1, 0);
  std::vector<CommutativeSemigroup> dp = def;
  const auto dfs2 = [&graph, &def, merge, f, g, &children, &dp](
      auto dfs2, const int par, const int ver, const CommutativeSemigroup& m)
          -> void {
    const int c = graph[ver].size();
    for (int i = 0; i < c; ++i) {
      if (graph[ver][i].dst == par) {
        children[ver][i] = f(m, graph[ver][i]);
        break;
      }
    }
    std::vector<CommutativeSemigroup> left{def[ver]}, right;
    left.reserve(c);
    for (int i = 0; i < c - 1; ++i) {
      left.emplace_back(merge(left[i], children[ver][i]));
    }
    dp[ver] = g(merge(left.back(), children[ver].back()), ver);
    if (c >= 2) {
      right.reserve(c - 1);
      right.emplace_back(children[ver].back());
      for (int i = c - 2; i > 0; --i) {
        right.emplace_back(merge(children[ver][i], right[c - 2 - i]));
      }
      std::reverse(right.begin(), right.end());
    }
    for (int i = 0; i < c; ++i) {
      if (graph[ver][i].dst != par) {
        dfs2(dfs2, ver, graph[ver][i].dst,
             g(i + 1 == c ? left[i] : merge(left[i], right[i]), ver));
      }
    }
  };
  dfs2(dfs2, -1, 0, CommutativeSemigroup());
  return dp;
}

}  // namespace emthrm


#line 15 "test/graph/tree/rerooting_dp.test.cpp"

int main() {
  int n;
  std::cin >> n;
  std::vector<std::vector<emthrm::Edge<int>>> graph(n);
  for (int i = 0; i < n - 1; ++i) {
    int s, t, w;
    std::cin >> s >> t >> w;
    graph[s].emplace_back(s, t, w);
    graph[t].emplace_back(t, s, w);
  }
  const std::vector<std::pair<int, int>> ans = emthrm::rerooting_dp(
      graph, std::vector<std::pair<int, int>>(n, {0, 0}),
      [](const std::pair<int, int>& x, const std::pair<int, int>& y)
          -> std::pair<int, int> {
        int tmp[]{x.first, x.second, y.first, y.second};
        std::sort(tmp, tmp + 4, std::greater<int>());
        return {tmp[0], tmp[1]};
      },
      [](const std::pair<int, int>& x, const emthrm::Edge<int>& e)
          -> std::pair<int, int> { return {x.first + e.cost, 0}; },
      [](const std::pair<int, int>& x, const int) -> std::pair<int, int> {
        return x;
      });
  int diameter = 0;
  for (int i = 0; i < n; ++i) {
    diameter = std::max(diameter, ans[i].first + ans[i].second);
  }
  std::cout << diameter << '\n';
  return 0;
}
Back to top page