summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGloria Mundi <gloria@gloria-mundi.eu>2025-06-07 21:20:34 +0200
committerGloria Mundi <gloria@gloria-mundi.eu>2025-06-07 21:20:34 +0200
commit88d04413ebaab961f849ac6ef3d6ff2179253d41 (patch)
tree075e5f245f160cf3d8a03f728a4ebe41e010c5df
parentf8f53c2f9e63f0ac89b67dc4d413ec9a76415a73 (diff)
make union find a struct, remove kruskal
-rw-r--r--content/datastructures/datastructures.tex9
-rw-r--r--content/datastructures/unionFind.cpp42
-rw-r--r--content/graph/cycleCounting.cpp6
-rw-r--r--content/graph/graph.tex12
-rw-r--r--content/graph/kruskal.cpp20
-rw-r--r--test/datastructures/unionFind.cpp48
-rw-r--r--test/graph/articulationPoints.bcc.cpp6
-rw-r--r--test/graph/cycleCounting.cpp8
-rw-r--r--test/graph/kruskal.cpp27
9 files changed, 96 insertions, 82 deletions
diff --git a/content/datastructures/datastructures.tex b/content/datastructures/datastructures.tex
index c4bd312..1c51475 100644
--- a/content/datastructures/datastructures.tex
+++ b/content/datastructures/datastructures.tex
@@ -123,11 +123,12 @@
\begin{algorithm}{Union-Find}
\begin{methods}
- \method{init}{legt $n$ einzelne Unions an}{n}
- \method{findSet}{findet den Repräsentanten}{\log(n)}
- \method{unionSets}{vereint 2 Mengen}{\log(n)}
+ \method{UnionFind}{legt $n$ einzelne Elemente an}{n}
+ \method{find}{findet den Repräsentanten}{\log(n)}
+ \method{link}{vereint 2 Mengen}{\log(n)}
\method{size}{zählt Elemente in Menge, die $a$ enthält}{\log(n)}
- \method{m\*findSet + n\*unionSets}{Folge von Befehlen}{n+m\*\alpha(n)}
+ \method{add}{fügt neues einzelnes Element ein}{1}
+ \method{m\*find + n\*link}{Folge von Befehlen}{n+m\*\alpha(n)}
\end{methods}
\sourcecode{datastructures/unionFind.cpp}
\end{algorithm}
diff --git a/content/datastructures/unionFind.cpp b/content/datastructures/unionFind.cpp
index dd5a569..36a4b45 100644
--- a/content/datastructures/unionFind.cpp
+++ b/content/datastructures/unionFind.cpp
@@ -1,26 +1,26 @@
-// unions[i] >= 0 => unions[i] = parent
-// unions[i] < 0 => unions[i] = -size
-vector<int> unions;
+struct UnionFind {
+ vector<int> unions; // unions[i] = parent or unions[i] = -size
-void init(int n) { //Initialisieren
- unions.assign(n, -1);
-}
+ UnionFind(int n): unions(n, -1) {}
-int findSet(int a) { // Pfadkompression
- if (unions[a] < 0) return a;
- return unions[a] = findSet(unions[a]);
-}
+ int find(int a) {
+ return unions[a] < 0 ? a : unions[a] = find(unions[a]);
+ }
-void linkSets(int a, int b) { // Union by size.
- if (unions[b] > unions[a]) swap(a, b);
- unions[b] += unions[a];
- unions[a] = b;
-}
+ bool link(int a, int b) {
+ if ((a = find(a)) == (b = find(b))) return false;
+ if (unions[b] > unions[a]) swap(a, b);
+ unions[b] += unions[a];
+ unions[a] = b;
+ return true;
+ }
-void unionSets(int a, int b) { // Diese Funktion aufrufen.
- if (findSet(a) != findSet(b)) linkSets(findSet(a), findSet(b));
-}
+ int size(int a) {
+ return -unions[find(a)];
+ }
-int size(int a) {
- return -unions[findSet(a)];
-}
+ int add() {
+ unions.push_back(-1);
+ return ssize(unions) - 1;
+ }
+};
diff --git a/content/graph/cycleCounting.cpp b/content/graph/cycleCounting.cpp
index deac71e..b7545d5 100644
--- a/content/graph/cycleCounting.cpp
+++ b/content/graph/cycleCounting.cpp
@@ -38,13 +38,11 @@ struct cycles {
bool isCycle(cycle cur) {// cycle must be constructed from base
if (cur.none()) return false;
- init(ssize(adj)); // union find @\sourceref{datastructures/unionFind.cpp}@
+ UnionFind uf(ssize(adj)); // union find @\sourceref{datastructures/unionFind.cpp}@
for (int i = 0; i < ssize(edges); i++) {
if (cur[i]) {
cur[i] = false;
- if (findSet(edges[i].first) ==
- findSet(edges[i].second)) break;
- unionSets(edges[i].first, edges[i].second);
+ if (!uf.link(edges[i].first, edges[i].second)) break;
}}
return cur.none();
}
diff --git a/content/graph/graph.tex b/content/graph/graph.tex
index e46ad07..bf51d74 100644
--- a/content/graph/graph.tex
+++ b/content/graph/graph.tex
@@ -10,11 +10,13 @@
Für jeden Kreis $K$ im Graphen gilt:
Die schwerste Kante auf dem Kreis ist nicht Teil des minimalen Spannbaums.
- \subsection{\textsc{Kruskal}}
- \begin{methods}[ll]
- berechnet den Minimalen Spannbaum & \runtime{\abs{E}\cdot\log(\abs{E})} \\
- \end{methods}
- \sourcecode{graph/kruskal.cpp}
+ \optional{
+ \subsubsection{\textsc{Kruskal}'s Algorithm \opthint}
+ \begin{methods}[ll]
+ berechnet den Minimalen Spannbaum & \runtime{\abs{E}\cdot\log(\abs{E})} \\
+ \end{methods}
+ \sourcecode{graph/kruskal.cpp}
+ }
\end{algorithm}
\begin{algorithm}{Heavy-Light Decomposition}
diff --git a/content/graph/kruskal.cpp b/content/graph/kruskal.cpp
index d42800d..98a2682 100644
--- a/content/graph/kruskal.cpp
+++ b/content/graph/kruskal.cpp
@@ -1,9 +1,11 @@
-ranges::sort(edges, less{});
-vector<Edge> mst;
-ll cost = 0;
-for (Edge& e : edges) {
- if (findSet(e.from) != findSet(e.to)) {
- unionSets(e.from, e.to);
- mst.push_back(e);
- cost += e.cost;
-}}
+ll kruskal(int n, vector<Edge> edges, vector<Edge> &mst) {
+ ranges::sort(edges, less{});
+ ll cost = 0;
+ UnionFind uf(n); // union find @\sourceref{datastructures/unionFind.cpp}@
+ for (Edge &e: edges) {
+ if (uf.link(e.from, e.to)) {
+ mst.push_back(e);
+ cost += e.cost;
+ }}
+ return cost;
+}
diff --git a/test/datastructures/unionFind.cpp b/test/datastructures/unionFind.cpp
index 2afdc86..4783f6b 100644
--- a/test/datastructures/unionFind.cpp
+++ b/test/datastructures/unionFind.cpp
@@ -1,8 +1,5 @@
#include "../util.h"
-struct UF {
- UF(int n) {init(n);}
- #include <datastructures/unionFind.cpp>
-};
+#include <datastructures/unionFind.cpp>
struct Naive {
vector<vector<int>> adj;
@@ -28,15 +25,18 @@ struct Naive {
}
}
- int findSet(int a) {
+ int find(int a) {
int res = a;
dfs(a, [&](int x){res = min(res, x);});
return res;
}
- void unionSets(int a, int b) {
+ bool link(int a, int b) {
+ bool linked = false;
+ dfs(a, [&](int x) { linked |= x == b; });
adj[a].push_back(b);
adj[b].push_back(a);
+ return !linked;
}
int size(int a) {
@@ -44,22 +44,38 @@ struct Naive {
dfs(a, [&](int /**/){res++;});
return res;
}
+
+ int add() {
+ int idx = ssize(adj);
+ adj.emplace_back();
+ seen.push_back(counter);
+ return idx;
+ }
};
void stress_test() {
ll queries = 0;
for (int tries = 0; tries < 200; tries++) {
int n = Random::integer<int>(1, 100);
- UF uf(n);
+ UnionFind uf(n);
Naive naive(n);
- for (int i = 0; i < n; i++) {
+ int rounds = n;
+ for (int i = 0; i < rounds; i++) {
for (int j = 0; j < 10; j++) {
int a = Random::integer<int>(0, n);
int b = Random::integer<int>(0, n);
- uf.unionSets(a, b);
- naive.unionSets(a, b);
+ auto got = uf.link(a, b);
+ auto expected = naive.link(a, b);
+ if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL;
}
- UF tmp = uf;
+ {
+ auto got = uf.add();
+ auto expected = naive.add();
+ assert(expected == n);
+ if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL;
+ n++;
+ }
+ UnionFind tmp = uf;
for (int j = 0; j < n; j++) {
{
auto got = tmp.size(j);
@@ -69,8 +85,8 @@ void stress_test() {
{
int a = Random::integer<int>(0, n);
int b = Random::integer<int>(0, n);
- bool got = tmp.findSet(a) == tmp.findSet(b);
- bool expected = naive.findSet(a) == naive.findSet(b);
+ bool got = tmp.find(a) == tmp.find(b);
+ bool expected = naive.find(a) == naive.find(b);
if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL;
}
}
@@ -84,7 +100,7 @@ constexpr int N = 2'000'000;
void performance_test() {
timer t;
t.start();
- UF uf(N);
+ UnionFind uf(N);
t.stop();
hash_t hash = 0;
for (int operations = 0; operations < N; operations++) {
@@ -92,9 +108,9 @@ void performance_test() {
int j = Random::integer<int>(0, N);
int k = Random::integer<int>(0, N);
int l = Random::integer<int>(0, N);
-
+
t.start();
- uf.unionSets(i, j);
+ uf.link(i, j);
hash += uf.size(k);
hash += uf.size(l);
t.stop();
diff --git a/test/graph/articulationPoints.bcc.cpp b/test/graph/articulationPoints.bcc.cpp
index cee2d0b..f112338 100644
--- a/test/graph/articulationPoints.bcc.cpp
+++ b/test/graph/articulationPoints.bcc.cpp
@@ -8,7 +8,7 @@ struct edge {
#include <datastructures/unionFind.cpp>
vector<vector<int>> naiveBCC(int m) {
- init(m);
+ UnionFind uf(m);
vector<int> seen(ssize(adj), -1);
int run = 0;
@@ -28,13 +28,13 @@ vector<vector<int>> naiveBCC(int m) {
}
}
for (auto ee : adj[i]) {
- if (seen[ee.to] == run) unionSets(ee.id, e.id);
+ if (seen[ee.to] == run) uf.link(ee.id, e.id);
}
}
}
vector<vector<int>> res(m);
for (int i = 0; i < m; i++) {
- res[findSet(i)].push_back(i);
+ res[uf.find(i)].push_back(i);
}
for (auto& v : res) ranges::sort(v);
res.erase(begin(ranges::remove_if(res, [](const vector<int>& v){return ssize(v) <= 1;})), end(res));
diff --git a/test/graph/cycleCounting.cpp b/test/graph/cycleCounting.cpp
index 9c7bf0c..82caf16 100644
--- a/test/graph/cycleCounting.cpp
+++ b/test/graph/cycleCounting.cpp
@@ -6,18 +6,14 @@ int naive(const vector<pair<int, int>>& edges, int n) {
int res = 0;
for (int i = 1; i < (1ll << ssize(edges)); i++) {
vector<int> deg(n);
- init(n);
+ UnionFind uf(n);
int cycles = 0;
for (int j = 0; j < ssize(edges); j++) {
if (((i >> j) & 1) != 0) {
auto [a, b] = edges[j];
deg[a]++;
deg[b]++;
- if (findSet(a) != findSet(b)) {
- unionSets(a, b);
- } else {
- cycles++;
- }
+ if (!uf.link(a, b)) cycles++;
}
}
bool ok = cycles == 1;
diff --git a/test/graph/kruskal.cpp b/test/graph/kruskal.cpp
index f6245b9..bc1cce5 100644
--- a/test/graph/kruskal.cpp
+++ b/test/graph/kruskal.cpp
@@ -1,22 +1,19 @@
#include "../util.h"
#include <datastructures/unionFind.cpp>
-struct edge {
+#define Edge Edge_ // we have a struct named Edge in util.h
+
+struct Edge {
int from, to;
ll cost;
- bool operator<(const edge& o) const {
+ bool operator<(const Edge& o) const {
return cost > o.cost;
}
};
-ll kruskal(vector<edge>& edges, int n) {
- init(n);
- #define Edge edge
- #include <graph/kruskal.cpp>
- #undef Edge
- return cost;
-}
-ll prim(vector<edge>& edges, int n) {
+#include <graph/kruskal.cpp>
+
+ll prim(vector<Edge>& edges, int n) {
vector<vector<pair<ll, int>>> adj(n);
for (auto [a, b, d] : edges) {
adj[a].emplace_back(d, b);
@@ -51,13 +48,14 @@ void stress_test() {
Graph<NoData> g(n);
g.erdosRenyi(m);
- vector<edge> edges;
+ vector<Edge> edges;
g.forEdges([&](int a, int b){
ll w = Random::integer<ll>(-1'000'000'000ll, 1'000'000'000ll);
edges.push_back({a, b, w});
});
- ll got = kruskal(edges, n);
+ vector<Edge> mst;
+ ll got = kruskal(n, edges, mst);
ll expected = prim(edges, n);
if (got != expected) cerr << "got: " << got << ", expected: " << expected << FAIL;
@@ -72,14 +70,15 @@ void performance_test() {
timer t;
Graph<NoData> g(N);
g.erdosRenyi(M);
- vector<edge> edges;
+ vector<Edge> edges;
g.forEdges([&](int a, int b){
ll w = Random::integer<ll>(-1'000'000'000ll, 1'000'000'000ll);
edges.push_back({a, b, w});
});
t.start();
- hash_t hash = kruskal(edges, N);
+ vector<Edge> mst;
+ hash_t hash = kruskal(N, edges, mst);
t.stop();
if (t.time > 1000) cerr << "too slow: " << t.time << FAIL;
cerr << "tested performance: " << t.time << "ms (hash: " << hash << ")" << endl;