summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--graph/centroid.cpp27
-rw-r--r--graph/hld.cpp32
2 files changed, 27 insertions, 32 deletions
diff --git a/graph/centroid.cpp b/graph/centroid.cpp
index d2855e2..c5187a5 100644
--- a/graph/centroid.cpp
+++ b/graph/centroid.cpp
@@ -1,22 +1,21 @@
vector<int> s;
-void dfs1(int u, int v = -1) {
- s[u] = 1;
- for (int w : adj[u]) {
- if (w == v) continue;
- dfs1(w, u);
- s[u] += s[w];
+void dfs_sz(int v, int parent = -1) {
+ s[v] = 1;
+ for (int u : adj[v]) if (u != parent) {
+ dfs_sz(u, v);
+ s[v] += s[u];
}}
-pair<int, int> dfs2(int u, int v, int n) {
- for (int w : adj[u]) {
- if (2 * s[w] == n) return {u, w};
- if (w != v && 2 * s[w] > n) return dfs2(w, u, n);
+pair<int, int> dfs_cent(int v, int parent, int n) {
+ for (int u : adj[v]) if (u != parent) {
+ if (2 * s[u] == n) return {v, u};
+ if (2 * s[u] > n) return dfs_cent(u, v, n);
}
- return {u, -1};
+ return {v, -1};
}
pair<int, int> find_centroid(int root) {
- // s muss nicht initialisiert werden, nur groß genug sein
- dfs1(root);
- return dfs2(root, -1, s[root]);
+ s.resize(sz(adj));
+ dfs_sz(root);
+ return dfs_cent(root, -1, s[root]);
}
diff --git a/graph/hld.cpp b/graph/hld.cpp
index 9431782..7e3331f 100644
--- a/graph/hld.cpp
+++ b/graph/hld.cpp
@@ -1,35 +1,31 @@
vector<vector<int>> adj;
vector<int> sz, in, out, nxt, par;
-int t;
+int counter;
void dfs_sz(int v = 0, int from = -1) {
- sz[v] = 1;
- for (auto& u : adj[v]) {
- if (u != from) {
- dfs_sz(u, v);
- sz[v] += sz[u];
- }
+ for (auto& u : adj[v]) if (u != from) {
+ dfs_sz(u, v);
+ sz[v] += sz[u];
if (adj[v][0] == from || sz[u] > sz[adj[v][0]]) {
swap(u, adj[v][0]);
}}}
void dfs_hld(int v = 0, int from = -1) {
par[v] = from;
- in[v] = t++;
- for (int u : adj[v]) {
- if (u == from) continue;
- nxt[u] = (u == adj[v][0] ? nxt[v] : u);
+ in[v] = counter++;
+ for (int u : adj[v]) if (u != from) {
+ nxt[u] = (u == adj[v][0]) ? nxt[v] : u;
dfs_hld(u, v);
}
- out[v] = t;
+ out[v] = counter;
}
-void init() {
+void init(int root = 0) {
int n = sz(adj);
- sz.assign(n, 0); in.assign(n, 0); out.assign(n, 0);
- nxt.assign(n, 0); par.assign(n, -1);
- t = 0;
- dfs_sz(); dfs_hld();
+ sz.assign(n, 1), nxt.assign(n, 0), par.assign(n, -1);
+ in.resize(n), out.resize(n);
+ counter = 0;
+ dfs_sz(root); dfs_hld(root);
}
vector<pair<int, int>> get_intervals(int u, int v) {
@@ -47,6 +43,6 @@ vector<pair<int, int>> get_intervals(int u, int v) {
int get_lca(int u, int v) {
while (true) {
if (in[v] < in[u]) swap(u, v);
- if (in[nxt[v]] <= in[u]) return in[u];
+ if (in[nxt[v]] <= in[u]) return u;
v = par[nxt[v]];
}}