summaryrefslogtreecommitdiff
path: root/datastructures/waveletTree.cpp
blob: 36c1b567fdca5c7f3fe19466de139f12ea0f266c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
// Wavelet tree over a static sequence of integers.
//
// Each node covers the half-open value interval [lo, hi) and represents,
// in their original relative order, the elements whose value lies in it.
// With mid = (lo + hi) / 2, elements < mid belong to the left child ln
// (interval [lo, mid)) and all others to the right child rn (interval
// [mid, hi)). Nodes with hi - lo <= 1 or without elements have no
// children.
//
// Positions are 0-indexed and ranges are half-open [l, r); a range that
// sticks out of the sequence is clamped to it.
// Preconditions: max(in) + 1 and lo + hi must be representable in ll.
//
// The node owns its children, so copying is disabled (an implicit copy
// would delete the same subtrees twice); moving is supported.
struct WaveletTree {
	using it = vector<ll>::iterator;
	WaveletTree *ln, *rn; // owned children, nullptr if the node has none
	ll lo, hi;            // this node covers values in [lo, hi)
	vector<int> b;        // b[i] = how many of the first i elements go left
private:
	// Builds the node for the elements in [from, to) with value bounds
	// [x, y). The range is reordered by a stable partition around mid.
	WaveletTree(it from, it to, ll x, ll y)
	: ln(nullptr), rn(nullptr), lo(x), hi(y), b(1) {
		ll mid = (lo + hi) / 2;
		auto f = [&](ll v){return v < mid;};
		b.reserve(to - from + 1);
		for (it c = from; c != to; ++c) {
			b.push_back(b.back() + f(*c));
		}
		if (lo + 1 >= hi || from == to) return;
		it pivot = stable_partition(from, to, f);
		ln = new WaveletTree(from, pivot, lo, mid);
		rn = new WaveletTree(pivot, to, mid, hi);
	}
public:
	// Builds the tree for the sequence in (taken by value because the
	// construction reorders it). An empty sequence gives an empty tree,
	// for which kth always returns -1 and countSmaller always returns 0;
	// min_element/max_element are not dereferenced in that case.
	WaveletTree(vector<ll> in) : WaveletTree(in.begin(), in.end(),
		in.empty() ? 0 : *min_element(in.begin(), in.end()),
		in.empty() ? 0 : *max_element(in.begin(), in.end()) + 1){}

	WaveletTree(const WaveletTree&) = delete;
	WaveletTree& operator=(const WaveletTree&) = delete;

	// Takes over the subtrees of o; o is left as an empty tree.
	WaveletTree(WaveletTree&& o) noexcept
	: ln(o.ln), rn(o.rn), lo(o.lo), hi(o.hi), b() {
		b.swap(o.b);
		o.ln = o.rn = nullptr;
	}

	// Releases the own subtrees and takes over those of o; o is left as
	// an empty tree.
	WaveletTree& operator=(WaveletTree&& o) noexcept {
		if (this != &o) {
			delete ln;
			delete rn;
			ln = o.ln;
			rn = o.rn;
			lo = o.lo;
			hi = o.hi;
			b.swap(o.b);
			o.ln = o.rn = nullptr;
			o.b.clear(); // keeps queries on o away from the null children
		}
		return *this;
	}

	// kth element in sort[l, r) all 0-indexed
	// Returns -1 if the (clamped) range is empty or k is not in
	// [0, r - l).
	ll kth(int l, int r, int k) {
		int n = (int)b.size() - 1; // number of elements in this node
		if (l < 0) l = 0;
		if (r > n) r = n;
		if (l >= r || k < 0 || k >= r - l) return -1;
		if (lo + 1 >= hi) return lo; // single value left
		int inLeft = b[r] - b[l];    // elements of [l, r) that went left
		if (k < inLeft) {
			return ln->kth(b[l], b[r], k);
		}
		// i - b[i] of the first i elements went right
		return rn->kth(l-b[l], r-b[r], k-inLeft);
	}

	// count elements in[l, r) smaller than k
	int countSmaller(int l, int r, ll k) {
		int n = (int)b.size() - 1; // number of elements in this node
		if (l < 0) l = 0;
		if (r > n) r = n;
		if (l >= r || k <= lo) return 0; // nothing here is below k
		if (hi <= k) return r - l;       // everything here is below k
		return ln->countSmaller(b[l], b[r], k) +
		       rn->countSmaller(l-b[l], r-b[r], k);
	}

	~WaveletTree(){
		delete ln;
		delete rn;
	}
};