From 88d04413ebaab961f849ac6ef3d6ff2179253d41 Mon Sep 17 00:00:00 2001 From: Gloria Mundi Date: Sat, 7 Jun 2025 21:20:34 +0200 Subject: make union find a struct, remove kruskal --- content/datastructures/datastructures.tex | 9 +++--- content/datastructures/unionFind.cpp | 42 +++++++++++++-------------- content/graph/cycleCounting.cpp | 6 ++-- content/graph/graph.tex | 12 ++++---- content/graph/kruskal.cpp | 20 +++++++------ test/datastructures/unionFind.cpp | 48 ++++++++++++++++++++----------- test/graph/articulationPoints.bcc.cpp | 6 ++-- test/graph/cycleCounting.cpp | 8 ++---- test/graph/kruskal.cpp | 27 +++++++++-------- 9 files changed, 96 insertions(+), 82 deletions(-) diff --git a/content/datastructures/datastructures.tex b/content/datastructures/datastructures.tex index c4bd312..1c51475 100644 --- a/content/datastructures/datastructures.tex +++ b/content/datastructures/datastructures.tex @@ -123,11 +123,12 @@ \begin{algorithm}{Union-Find} \begin{methods} - \method{init}{legt $n$ einzelne Unions an}{n} - \method{findSet}{findet den Repräsentanten}{\log(n)} - \method{unionSets}{vereint 2 Mengen}{\log(n)} + \method{UnionFind}{legt $n$ einzelne Elemente an}{n} + \method{find}{findet den Repräsentanten}{\log(n)} + \method{link}{vereint 2 Mengen}{\log(n)} \method{size}{zählt Elemente in Menge, die $a$ enthält}{\log(n)} - \method{m\*findSet + n\*unionSets}{Folge von Befehlen}{n+m\*\alpha(n)} + \method{add}{fügt neues einzelnes Element ein}{1} + \method{m\*find + n\*link}{Folge von Befehlen}{n+m\*\alpha(n)} \end{methods} \sourcecode{datastructures/unionFind.cpp} \end{algorithm} diff --git a/content/datastructures/unionFind.cpp b/content/datastructures/unionFind.cpp index dd5a569..36a4b45 100644 --- a/content/datastructures/unionFind.cpp +++ b/content/datastructures/unionFind.cpp @@ -1,26 +1,26 @@ -// unions[i] >= 0 => unions[i] = parent -// unions[i] < 0 => unions[i] = -size -vector unions; +struct UnionFind { + vector unions; // unions[i] = parent or unions[i] = -size -void init(int n) { //Initialisieren - unions.assign(n, -1); -} + UnionFind(int n): unions(n, -1) {} -int findSet(int a) { // Pfadkompression - if (unions[a] < 0) return a; - return unions[a] = findSet(unions[a]); -} + int find(int a) { + return unions[a] < 0 ? a : unions[a] = find(unions[a]); + } -void linkSets(int a, int b) { // Union by size. - if (unions[b] > unions[a]) swap(a, b); - unions[b] += unions[a]; - unions[a] = b; -} + bool link(int a, int b) { + if ((a = find(a)) == (b = find(b))) return false; + if (unions[b] > unions[a]) swap(a, b); + unions[b] += unions[a]; + unions[a] = b; + return true; + } -void unionSets(int a, int b) { // Diese Funktion aufrufen. - if (findSet(a) != findSet(b)) linkSets(findSet(a), findSet(b)); -} + int size(int a) { + return -unions[find(a)]; + } -int size(int a) { - return -unions[findSet(a)]; -} + int add() { + unions.push_back(-1); + return ssize(unions) - 1; + } +}; diff --git a/content/graph/cycleCounting.cpp b/content/graph/cycleCounting.cpp index deac71e..b7545d5 100644 --- a/content/graph/cycleCounting.cpp +++ b/content/graph/cycleCounting.cpp @@ -38,13 +38,11 @@ struct cycles { bool isCycle(cycle cur) {// cycle must be constructed from base if (cur.none()) return false; - init(ssize(adj)); // union find @\sourceref{datastructures/unionFind.cpp}@ + UnionFind uf(ssize(adj)); // union find @\sourceref{datastructures/unionFind.cpp}@ for (int i = 0; i < ssize(edges); i++) { if (cur[i]) { cur[i] = false; - if (findSet(edges[i].first) == - findSet(edges[i].second)) break; - unionSets(edges[i].first, edges[i].second); + if (!uf.link(edges[i].first, edges[i].second)) break; }} return cur.none(); } diff --git a/content/graph/graph.tex b/content/graph/graph.tex index e46ad07..bf51d74 100644 --- a/content/graph/graph.tex +++ b/content/graph/graph.tex @@ -10,11 +10,13 @@ Für jeden Kreis $K$ im Graphen gilt: Die schwerste Kante auf dem Kreis ist nicht Teil des minimalen Spannbaums. - \subsection{\textsc{Kruskal}} - \begin{methods}[ll] - berechnet den Minimalen Spannbaum & \runtime{\abs{E}\cdot\log(\abs{E})} \\ - \end{methods} - \sourcecode{graph/kruskal.cpp} + \optional{ + \subsubsection{\textsc{Kruskal}'s Algorithm \opthint} + \begin{methods}[ll] + berechnet den Minimalen Spannbaum & \runtime{\abs{E}\cdot\log(\abs{E})} \\ + \end{methods} + \sourcecode{graph/kruskal.cpp} + } \end{algorithm} \begin{algorithm}{Heavy-Light Decomposition} diff --git a/content/graph/kruskal.cpp b/content/graph/kruskal.cpp index d42800d..98a2682 100644 --- a/content/graph/kruskal.cpp +++ b/content/graph/kruskal.cpp @@ -1,9 +1,11 @@ -ranges::sort(edges, less{}); -vector mst; -ll cost = 0; -for (Edge& e : edges) { - if (findSet(e.from) != findSet(e.to)) { - unionSets(e.from, e.to); - mst.push_back(e); - cost += e.cost; -}} +ll kruskal(int n, vector edges, vector &mst) { + ranges::sort(edges, less{}); + ll cost = 0; + UnionFind uf(n); // union find @\sourceref{datastructures/unionFind.cpp}@ + for (Edge &e: edges) { + if (uf.link(e.from, e.to)) { + mst.push_back(e); + cost += e.cost; + }} + return cost; +} diff --git a/test/datastructures/unionFind.cpp b/test/datastructures/unionFind.cpp index 2afdc86..4783f6b 100644 --- a/test/datastructures/unionFind.cpp +++ b/test/datastructures/unionFind.cpp @@ -1,8 +1,5 @@ #include "../util.h" -struct UF { - UF(int n) {init(n);} - #include -}; +#include struct Naive { vector> adj; @@ -28,15 +25,18 @@ struct Naive { } } - int findSet(int a) { + int find(int a) { int res = a; dfs(a, [&](int x){res = min(res, x);}); return res; } - void unionSets(int a, int b) { + bool link(int a, int b) { + bool linked = false; + dfs(a, [&](int x) { linked |= x == b; }); adj[a].push_back(b); adj[b].push_back(a); + return !linked; } int size(int a) { @@ -44,22 +44,38 @@ struct Naive { dfs(a, [&](int /**/){res++;}); return res; } + + int add() { + int idx = ssize(adj); + adj.emplace_back(); + seen.push_back(counter); + return idx; + } }; void stress_test() { ll queries = 0; for (int tries = 0; tries < 200; tries++) { int n = Random::integer(1, 100); - UF uf(n); + UnionFind uf(n); Naive naive(n); - for (int i = 0; i < n; i++) { + int rounds = n; + for (int i = 0; i < rounds; i++) { for (int j = 0; j < 10; j++) { int a = Random::integer(0, n); int b = Random::integer(0, n); - uf.unionSets(a, b); - naive.unionSets(a, b); + auto got = uf.link(a, b); + auto expected = naive.link(a, b); + if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL; } - UF tmp = uf; + { + auto got = uf.add(); + auto expected = naive.add(); + assert(expected == n); + if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL; + n++; + } + UnionFind tmp = uf; for (int j = 0; j < n; j++) { { auto got = tmp.size(j); @@ -69,8 +85,8 @@ void stress_test() { { int a = Random::integer(0, n); int b = Random::integer(0, n); - bool got = tmp.findSet(a) == tmp.findSet(b); - bool expected = naive.findSet(a) == naive.findSet(b); + bool got = tmp.find(a) == tmp.find(b); + bool expected = naive.find(a) == naive.find(b); if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL; } } @@ -84,7 +100,7 @@ constexpr int N = 2'000'000; void performance_test() { timer t; t.start(); - UF uf(N); + UnionFind uf(N); t.stop(); hash_t hash = 0; for (int operations = 0; operations < N; operations++) { @@ -92,9 +108,9 @@ void performance_test() { int j = Random::integer(0, N); int k = Random::integer(0, N); int l = Random::integer(0, N); - + t.start(); - uf.unionSets(i, j); + uf.link(i, j); hash += uf.size(k); hash += uf.size(l); t.stop(); diff --git a/test/graph/articulationPoints.bcc.cpp b/test/graph/articulationPoints.bcc.cpp index cee2d0b..f112338 100644 --- a/test/graph/articulationPoints.bcc.cpp +++ b/test/graph/articulationPoints.bcc.cpp @@ -8,7 +8,7 @@ struct edge { #include vector> naiveBCC(int m) { - init(m); + UnionFind uf(m); vector seen(ssize(adj), -1); int run = 0; @@ -28,13 +28,13 @@ vector> naiveBCC(int m) { } } for (auto ee : adj[i]) { - if (seen[ee.to] == run) unionSets(ee.id, e.id); + if (seen[ee.to] == run) uf.link(ee.id, e.id); } } } vector> res(m); for (int i = 0; i < m; i++) { - res[findSet(i)].push_back(i); + res[uf.find(i)].push_back(i); } for (auto& v : res) ranges::sort(v); res.erase(begin(ranges::remove_if(res, [](const vector& v){return ssize(v) <= 1;})), end(res)); diff --git a/test/graph/cycleCounting.cpp b/test/graph/cycleCounting.cpp index 9c7bf0c..82caf16 100644 --- a/test/graph/cycleCounting.cpp +++ b/test/graph/cycleCounting.cpp @@ -6,18 +6,14 @@ int naive(const vector>& edges, int n) { int res = 0; for (int i = 1; i < (1ll << ssize(edges)); i++) { vector deg(n); - init(n); + UnionFind uf(n); int cycles = 0; for (int j = 0; j < ssize(edges); j++) { if (((i >> j) & 1) != 0) { auto [a, b] = edges[j]; deg[a]++; deg[b]++; - if (findSet(a) != findSet(b)) { - unionSets(a, b); - } else { - cycles++; - } + if (!uf.link(a, b)) cycles++; } } bool ok = cycles == 1; diff --git a/test/graph/kruskal.cpp b/test/graph/kruskal.cpp index f6245b9..bc1cce5 100644 --- a/test/graph/kruskal.cpp +++ b/test/graph/kruskal.cpp @@ -1,22 +1,19 @@ #include "../util.h" #include -struct edge { +#define Edge Edge_ // we have a struct named Edge in util.h + +struct Edge { int from, to; ll cost; - bool operator<(const edge& o) const { + bool operator<(const Edge& o) const { return cost > o.cost; } }; -ll kruskal(vector& edges, int n) { - init(n); - #define Edge edge - #include - #undef Edge - return cost; -} -ll prim(vector& edges, int n) { +#include + +ll prim(vector& edges, int n) { vector>> adj(n); for (auto [a, b, d] : edges) { adj[a].emplace_back(d, b); @@ -51,13 +48,14 @@ void stress_test() { Graph g(n); g.erdosRenyi(m); - vector edges; + vector edges; g.forEdges([&](int a, int b){ ll w = Random::integer(-1'000'000'000ll, 1'000'000'000ll); edges.push_back({a, b, w}); }); - ll got = kruskal(edges, n); + vector mst; + ll got = kruskal(n, edges, mst); ll expected = prim(edges, n); if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL; @@ -72,14 +70,15 @@ void performance_test() { timer t; Graph g(N); g.erdosRenyi(M); - vector edges; + vector edges; g.forEdges([&](int a, int b){ ll w = Random::integer(-1'000'000'000ll, 1'000'000'000ll); edges.push_back({a, b, w}); }); t.start(); - hash_t hash = kruskal(edges, N); + vector mst; + hash_t hash = kruskal(N, edges, mst); t.stop(); if (t.time > 1000) cerr << "too slow: " << t.time << FAIL; cerr << "tested performance: " << t.time << "ms (hash: " << hash << ")" << endl; -- cgit v1.2.3