summaryrefslogtreecommitdiff
path: root/content
diff options
context:
space:
mode:
Diffstat (limited to 'content')
-rw-r--r--content/graph/reroot.cpp86
1 files changed, 36 insertions, 50 deletions
diff --git a/content/graph/reroot.cpp b/content/graph/reroot.cpp
index 4c6a748..379c839 100644
--- a/content/graph/reroot.cpp
+++ b/content/graph/reroot.cpp
@@ -1,62 +1,48 @@
-// Usual Tree DP can be broken down in 4 steps:
-// - Initialize dp[v] = identity
-// - Iterate over all children w and take a value for w
-// by looking at dp[w] and possibly the edge label of v -> w
-// - combine the values of those children
-// usually this operation should be commutative and associative
-// - finalize the dp[v] after iterating over all children
+using W = ll; // edge weight type
+vector<vector<pair<int, W>>> adj;
+
struct Reroot {
- using T = ll;
+ using T = ll; // dp type
- // identity element
- T E() {}
- // x: dp value of child
- // e: index of edge going to child
- T takeChild(T x, int e) {}
- T comb(T x, T y) {}
- // called after combining all dp values of children
- T fin(T x, int v) {}
+ static constexpr T E = 0; // neutral element
+ T takeChild(int v, int c, W w, T x) {} // move child along edge
+ static T comb(T x, T y) {}
+ T fin(int v, T x) {} // add v to own dp value x
- vector<vector<pair<int, int>>> g;
- vector<int> ord, pae;
vector<T> dp;
- T dfs(int v) {
- ord.push_back(v);
- for (auto [w, e] : g[v]) {
- g[w].erase(find(all(g[w]), pair(v, e^1)));
- pae[w] = e^1;
- dp[v] = comb(dp[v], takeChild(dfs(w), e));
+ T dfs0(int v, int from = -1) {
+ T val = E;
+ for (auto [u, w] : adj[v]) {
+ if (u == from) continue;
+ val = comb(val, takeChild(v, u, w, dfs0(u, v)));
}
- return dp[v] = fin(dp[v], v);
+ return dp[v] = fin(v, val);
}
- vector<T> solve(int n, vector<pair<int, int>> edges) {
- g.resize(n);
- for (int i = 0; i < n-1; i++) {
- g[edges[i].first].emplace_back(edges[i].second, 2*i);
- g[edges[i].second].emplace_back(edges[i].first, 2*i+1);
+ void dfs1(int v, int from = -1) {
+ vector<T> pref = {E};
+ for (auto [u, w] : adj[v]) {
+ pref.push_back(takeChild(v, u, w, dp[u]));
}
- pae.assign(n, -1);
- dp.assign(n, E());
- dfs(0);
- vector<T> updp(n, E()), res(n, E());
- for (int v : ord) {
- vector<T> pref(sz(g[v])+1), suff(sz(g[v])+1);
- if (v != 0) pref[0] = takeChild(updp[v], pae[v]);
- for (int i = 0; i < sz(g[v]); i++){
- auto [u, w] = g[v][i];
- pref[i+1] = suff[i] = takeChild(dp[u], w);
- pref[i+1] = comb(pref[i], pref[i+1]);
- }
- for (int i = sz(g[v])-1; i >= 0; i--) {
- suff[i] = comb(suff[i], suff[i+1]);
- }
- for (int i = 0; i < sz(g[v]); i++) {
- updp[g[v][i].first] = fin(comb(pref[i], suff[i+1]), v);
- }
- res[v] = fin(pref.back(), v);
+ auto suf = pref;
+ partial_sum(all(pref), pref.begin(), comb);
+ exclusive_scan(suf.rbegin(), suf.rend(),
+ suf.rbegin(), E, comb);
+
+ for (int i = 0; i < sz(adj[v]); i++) {
+ auto [u, w] = adj[v][i];
+ if (u == from) continue;
+ dp[v] = fin(v, comb(pref[i], suf[i + 1]));
+ dfs1(u, v);
}
- return res;
+ dp[v] = fin(v, suf[0]);
+ }
+
+ auto solve() {
+ dp.assign(sz(adj), E);
+ dfs0(0);
+ dfs1(0);
+ return dp;
}
};