diff options
| -rw-r--r-- | graph/centroid.cpp | 27 | ||||
| -rw-r--r-- | graph/hld.cpp | 32 |
2 files changed, 27 insertions, 32 deletions
diff --git a/graph/centroid.cpp b/graph/centroid.cpp index d2855e2..c5187a5 100644 --- a/graph/centroid.cpp +++ b/graph/centroid.cpp @@ -1,22 +1,21 @@ vector<int> s; -void dfs1(int u, int v = -1) { - s[u] = 1; - for (int w : adj[u]) { - if (w == v) continue; - dfs1(w, u); - s[u] += s[w]; +void dfs_sz(int v, int parent = -1) { + s[v] = 1; + for (int u : adj[v]) if (u != parent) { + dfs_sz(u, v); + s[v] += s[u]; }} -pair<int, int> dfs2(int u, int v, int n) { - for (int w : adj[u]) { - if (2 * s[w] == n) return {u, w}; - if (w != v && 2 * s[w] > n) return dfs2(w, u, n); +pair<int, int> dfs_cent(int v, int parent, int n) { + for (int u : adj[v]) if (u != parent) { + if (2 * s[u] == n) return {v, u}; + if (2 * s[u] > n) return dfs_cent(u, v, n); } - return {u, -1}; + return {v, -1}; } pair<int, int> find_centroid(int root) { - // s muss nicht initialisiert werden, nur groß genug sein - dfs1(root); - return dfs2(root, -1, s[root]); + s.resize(sz(adj)); + dfs_sz(root); + return dfs_cent(root, -1, s[root]); } diff --git a/graph/hld.cpp b/graph/hld.cpp index 9431782..7e3331f 100644 --- a/graph/hld.cpp +++ b/graph/hld.cpp @@ -1,35 +1,31 @@ vector<vector<int>> adj; vector<int> sz, in, out, nxt, par; -int t; +int counter; void dfs_sz(int v = 0, int from = -1) { - sz[v] = 1; - for (auto& u : adj[v]) { - if (u != from) { - dfs_sz(u, v); - sz[v] += sz[u]; - } + for (auto& u : adj[v]) if (u != from) { + dfs_sz(u, v); + sz[v] += sz[u]; if (adj[v][0] == from || sz[u] > sz[adj[v][0]]) { swap(u, adj[v][0]); }}} void dfs_hld(int v = 0, int from = -1) { par[v] = from; - in[v] = t++; - for (int u : adj[v]) { - if (u == from) continue; - nxt[u] = (u == adj[v][0] ? nxt[v] : u); + in[v] = counter++; + for (int u : adj[v]) if (u != from) { + nxt[u] = (u == adj[v][0]) ? nxt[v] : u; dfs_hld(u, v); } - out[v] = t; + out[v] = counter; } -void init() { +void init(int root = 0) { int n = sz(adj); - sz.assign(n, 0); in.assign(n, 0); out.assign(n, 0); - nxt.assign(n, 0); par.assign(n, -1); - t = 0; - dfs_sz(); dfs_hld(); + sz.assign(n, 1), nxt.assign(n, 0), par.assign(n, -1); + in.resize(n), out.resize(n); + counter = 0; + dfs_sz(root); dfs_hld(root); } vector<pair<int, int>> get_intervals(int u, int v) { @@ -47,6 +43,6 @@ vector<pair<int, int>> get_intervals(int u, int v) { int get_lca(int u, int v) { while (true) { if (in[v] < in[u]) swap(u, v); - if (in[nxt[v]] <= in[u]) return in[u]; + if (in[nxt[v]] <= in[u]) return u; v = par[nxt[v]]; }} |
