summaryrefslogtreecommitdiff
path: root/graph/blossom.cpp
blob: b3983ada3216a876e0cf6807adfa8c6deeb0daae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Maximum cardinality matching in a general (not necessarily bipartite) graph,
// following Gabow's implementation of Edmonds' blossom algorithm: blossoms are
// never contracted explicitly, they are tracked through labels and the
// `first` pointers instead.
//
// Usage: construct with the vertex count n, insert every undirected edge in
// both directions (adjlist[a].push_back(b); adjlist[b].push_back(a);), then
// call match(). Afterwards pairs[v] is the partner of v, or n if v is free.
//
// Vertex index n is a dummy sentinel: it is the "partner" of every unmatched
// vertex and the first-pointer target of the current search root.
struct GM {
	std::vector<std::vector<int>> adjlist;
	// pairs[v]: matched partner of v, or n if v is unmatched.
	// first[v]: for an outer vertex v, points towards the first non-outer
	//           vertex on v's alternating path to the root. It may be stale
	//           (point at a vertex that has since become outer) until
	//           findFirst() compresses it.
	// que:      FIFO queue (indices head/tail) of the vertices made outer in
	//           the current search; each vertex is enqueued at most once.
	std::vector<int> pairs, first, que;
	// label[v].first >= 0 : v is outer. If label[v].second == -1 this is a
	//                       vertex label (the outer vertex that reached v's
	//                       partner; n for the root), otherwise {x, y} is an
	//                       edge label (the outer-outer edge closing a blossom).
	// label[v] == {-1, -1}: v is unlabeled.
	// other label[v].first < 0: temporary flag {~x, y} set by relabel().
	// NOTE: labels are cleared only after a successful augmentation. Vertices
	// labeled by a failed search keep their labels. A vertex without an
	// augmenting path never gains one later, so that part of the graph cannot
	// contribute again, and the stale labels keep later searches out of it.
	std::vector<std::pair<int, int>> label;
	int head = 0, tail = 0;

	GM(int n) : adjlist(n), pairs(n + 1, n), first(n + 1, n),
	            que(n), label(n + 1, {-1, -1}) {}

	// Makes v the partner of w and flips the matching along the alternating
	// path behind v (Gabow's procedure R). This uses an explicit stack instead
	// of recursion, so long augmenting paths cannot overflow the call stack.
	// Every step does all of its work before scheduling follow-up steps, so
	// LIFO processing reproduces the recursive (pre-order) sequence exactly.
	void rematch(int v, int w) {
		std::vector<std::pair<int, int>> pending;
		pending.push_back(std::make_pair(v, w));
		while (!pending.empty()) {
			const int a = pending.back().first;
			const int b = pending.back().second;
			pending.pop_back();
			const int t = pairs[a];
			pairs[a] = b;
			// a was free (t == n) or its old partner has already been
			// re-paired: nothing further to flip on this branch.
			if (pairs[t] != a) continue;
			if (label[a].second == -1) {
				// Vertex label: t's new partner is the outer vertex that
				// labeled a; continue flipping from there.
				pairs[t] = label[a].first;
				pending.push_back(std::make_pair(pairs[t], t));
			} else {
				// Edge label {x, y}: flip both halves of the blossom path,
				// (x, y) first, so push it last.
				const int x = label[a].first;
				const int y = label[a].second;
				pending.push_back(std::make_pair(y, x));
				pending.push_back(std::make_pair(x, y));
			}
		}
	}

	// Follows first[] from u to the first vertex whose label is negative
	// (non-outer, flagged, or the sentinel n) and returns it. All first[]
	// entries on the way are redirected to it (path compression). This is
	// iterative for the same stack-depth reason as rematch().
	int findFirst(int u) {
		int target = first[u];
		while (label[target].first >= 0) target = first[target];
		for (int v = u; first[v] != target;) {
			const int next = first[v];
			first[v] = target;
			v = next;
		}
		return target;
	}

	// Handles an edge (x, y) between two outer vertices, i.e. a blossom
	// (Gabow's procedure L). Walks alternately from findFirst(x) and
	// findFirst(y) towards the root, flagging every visited vertex with h,
	// until a walk reaches an already flagged vertex: the join. Every
	// non-outer vertex on both paths before the join becomes outer with edge
	// label {x, y}, gets first = join and is enqueued.
	void relabel(int x, int y) {
		const int n = static_cast<int>(adjlist.size());
		int r = findFirst(x);
		int s = findFirst(y);
		if (r == s) return;  // x and y already share their first vertex
		// ~x < 0, so the flag can never be mistaken for an outer label.
		const std::pair<int, int> h(~x, y);
		label[r] = label[s] = h;
		int join;
		while (true) {
			// Alternate between the two walks, unless the other walk is
			// already parked at the sentinel n.
			if (s != n) std::swap(r, s);
			r = findFirst(label[pairs[r]].first);
			if (label[r] == h) {
				join = r;
				break;
			}
			label[r] = h;
		}
		for (int v : {first[x], first[y]}) {
			for (; v != join; v = first[label[pairs[v]].first]) {
				label[v] = {x, y};
				first[v] = join;
				que[tail++] = v;
			}
		}
	}

	// Searches for an augmenting path starting at the unmatched vertex u
	// (Gabow's search, procedure E). Returns true and applies the
	// augmentation if one is found, false otherwise. On return,
	// que[0, tail) holds every vertex made outer during this search.
	bool augment(int u) {
		const int n = static_cast<int>(adjlist.size());
		label[u] = {n, -1};
		first[u] = n;
		head = tail = 0;
		for (que[tail++] = u; head < tail;) {
			const int x = que[head++];
			for (int y : adjlist[x]) {
				if (pairs[y] == n && y != u) {
					// y is free: augmenting path (root ... x, y) found.
					pairs[y] = x;
					rematch(x, y);
					return true;
				} else if (label[y].first >= 0) {
					// y is outer too: the edge closes a blossom.
					relabel(x, y);
				} else if (label[pairs[y]].first == -1) {
					// y is non-outer and its partner is unlabeled: grow the
					// tree by making the partner outer (vertex label x).
					label[pairs[y]].first = x;
					first[pairs[y]] = y;
					que[tail++] = pairs[y];
				}
			}
		}
		return false;
	}

	// Runs one search from every currently unmatched vertex and returns the
	// number of successful augmentations. This equals the size of the
	// maximum matching when pairs starts out empty. The matching itself is
	// left in pairs.
	int match() {
		const int n = static_cast<int>(adjlist.size());
		int matching = head = tail = 0;
		for (int u = 0; u < n; u++) {
			if (pairs[u] < n || !augment(u)) continue;
			matching++;
			// Clear this search's labels: every outer vertex, its (new)
			// partner, and the sentinel n. Labels from failed searches are
			// deliberately kept (see the struct comment).
			for (int i = 0; i < tail; i++)
				label[que[i]] = label[pairs[que[i]]] = {-1, -1};
			label[n] = {-1, -1};
		}
		return matching;
	}

};