summaryrefslogtreecommitdiff
path: root/graph
diff options
context:
space:
mode:
Diffstat (limited to 'graph')
-rw-r--r--graph/hld.cpp22
1 files changed, 9 insertions, 13 deletions
diff --git a/graph/hld.cpp b/graph/hld.cpp
index 7e3331f..fae6030 100644
--- a/graph/hld.cpp
+++ b/graph/hld.cpp
@@ -7,7 +7,7 @@ void dfs_sz(int v = 0, int from = -1) {
dfs_sz(u, v);
sz[v] += sz[u];
if (adj[v][0] == from || sz[u] > sz[adj[v][0]]) {
- swap(u, adj[v][0]);
+ swap(u, adj[v][0]); //changes adj!
}}}
void dfs_hld(int v = 0, int from = -1) {
@@ -25,24 +25,20 @@ void init(int root = 0) {
sz.assign(n, 1), nxt.assign(n, 0), par.assign(n, -1);
in.resize(n), out.resize(n);
counter = 0;
- dfs_sz(root); dfs_hld(root);
+ dfs_sz(root);
+ dfs_hld(root);
}
-vector<pair<int, int>> get_intervals(int u, int v) {
- vector<pair<int, int>> res;
- while (true) {
+template<typename F>
+void for_intervals(int u, int v, F&& f) {
+ for (;; v = par[nxt[v]]) {
if (in[v] < in[u]) swap(u, v);
- if (in[nxt[v]] <= in[u]) {
- res.emplace_back(in[u], in[v] + 1);
- return res;
- }
- res.emplace_back(in[nxt[v]], in[v] + 1);
- v = par[nxt[v]];
+ f(max(in[u], in[nxt[v]]), in[v] + 1);
+ if (in[nxt[v]] <= in[u]) return;
}}
int get_lca(int u, int v) {
- while (true) {
+ for (;; v = par[nxt[v]]) {
if (in[v] < in[u]) swap(u, v);
if (in[nxt[v]] <= in[u]) return u;
- v = par[nxt[v]];
}}