summaryrefslogtreecommitdiff
path: root/datastructures/waveletTree.cpp
diff options
context:
space:
mode:
authorGloria Mundi <gloria@gloria-mundi.eu>2024-05-02 20:42:43 +0200
committerGloria Mundi <gloria@gloria-mundi.eu>2024-05-02 20:42:43 +0200
commit34c882ab75a60699429421684a9867cce0a22110 (patch)
tree09a14183b60c205196b27f5175c6693e4bf16e78 /datastructures/waveletTree.cpp
parente0beaa56b648367bc52dc8c7d44162ac1c8b45fe (diff)
wavelet tree changes + tests
Diffstat (limited to 'datastructures/waveletTree.cpp')
-rw-r--r--datastructures/waveletTree.cpp26
1 files changed, 9 insertions, 17 deletions
diff --git a/datastructures/waveletTree.cpp b/datastructures/waveletTree.cpp
index 476658e..95ff207 100644
--- a/datastructures/waveletTree.cpp
+++ b/datastructures/waveletTree.cpp
@@ -1,25 +1,20 @@
struct WaveletTree {
- using it = vector<ll>::iterator;
- WaveletTree *ln = nullptr, *rn = nullptr;
+ unique_ptr<WaveletTree> ln, rn;
vector<int> b = {0};
ll lo, hi;
- WaveletTree(vector<ll> in) : WaveletTree(all(in)) {}
-
- WaveletTree(it from, it to) : // call above one
- lo(*min_element(from, to)), hi(*max_element(from, to) + 1) {
+ WaveletTree(auto in) : lo(*ranges::min_element(in)),
+ hi(*ranges::max_element(in) + 1) {
ll mid = (lo + hi) / 2;
- auto f = [&](ll x) {return x < mid;};
- for (it c = from; c != to; c++) {
- b.push_back(b.back() + f(*c));
- }
+ auto f = [&](ll x) { return x < mid; };
+ for (ll x: in) b.push_back(b.back() + f(x));
if (lo + 1 >= hi) return;
- it pivot = stable_partition(from, to, f);
- ln = new WaveletTree(from, pivot);
- rn = new WaveletTree(pivot, to);
+ auto right = ranges::stable_partition(in, f);
+ ln = make_unique<WaveletTree>(
+ ranges::subrange(begin(in), begin(right)));
+ rn = make_unique<WaveletTree>(right);
}
- // kth element in sort[l, r) all 0-indexed
ll kth(int l, int r, int k) {
if (l >= r || k >= r - l) return -1;
if (lo + 1 >= hi) return lo;
@@ -28,13 +23,10 @@ struct WaveletTree {
else return rn->kth(l-b[l], r-b[r], k-inLeft);
}
- // count elements in[l, r) smaller than k
int countSmaller(int l, int r, ll k) {
if (l >= r || k <= lo) return 0;
if (hi <= k) return r - l;
return ln->countSmaller(b[l], b[r], k) +
rn->countSmaller(l-b[l], r-b[r], k);
}
-
- ~WaveletTree() {delete ln; delete rn;}
};