summaryrefslogtreecommitdiff
path: root/datastructures/waveletTree.cpp
diff options
context:
space:
mode:
authorNoobie99 <noob999noob999@gmail.com>2023-04-03 22:44:38 +0200
committerNoobie99 <noob999noob999@gmail.com>2023-04-03 22:44:38 +0200
commitc93c7956ebf2d572ea8fb341a636b60e6f97ad84 (patch)
tree7d6b6d8f6c6bc3dd9caf47e16a47d95384630706 /datastructures/waveletTree.cpp
parent107d427b32e3b60072a9008dc19be37b0a2b3ce1 (diff)
minor changes wavelet tree
Diffstat (limited to 'datastructures/waveletTree.cpp')
-rw-r--r--datastructures/waveletTree.cpp86
1 files changed, 40 insertions, 46 deletions
diff --git a/datastructures/waveletTree.cpp b/datastructures/waveletTree.cpp
index 36c1b56..4860a40 100644
--- a/datastructures/waveletTree.cpp
+++ b/datastructures/waveletTree.cpp
@@ -1,46 +1,40 @@
-struct WaveletTree {
- using it = vector<ll>::iterator;
- WaveletTree *ln, *rn;
- ll lo, hi;
- vector<int> b;
-private:
- WaveletTree(it from, it to, ll x, ll y)
- : ln(nullptr), rn(nullptr), lo(x), hi(y), b(1) {
- ll mid = (lo + hi) / 2;
- auto f = [&](ll x){return x < mid;};
- for (it c = from; c != to; c++) {
- b.push_back(b.back() + f(*c));
- }
- if (lo + 1 >= hi || from == to) return;
- it pivot = stable_partition(from, to, f);
- ln = new WaveletTree(from, pivot, lo, mid);
- rn = new WaveletTree(pivot, to, mid, hi);
- }
-public:
- WaveletTree(vector<ll> in) : WaveletTree(all(in),
- *min_element(all(in)), *max_element(all(in)) + 1){}
-
- // kth element in sort[l, r) all 0-indexed
- ll kth(int l, int r, int k) {
- if (l >= r || k >= r - l) return -1;
- if (lo + 1 >= hi) return lo;
- int inLeft = b[r] - b[l];
- if (k < inLeft) {
- return ln->kth(b[l], b[r], k);
- } else {
- return rn->kth(l-b[l], r-b[r], k-inLeft);
- }}
-
- // count elements in[l, r) smaller than k
- int countSmaller(int l, int r, ll k) {
- if (l >= r || k <= lo) return 0;
- if (hi <= k) return r - l;
- return ln->countSmaller(b[l], b[r], k) +
- rn->countSmaller(l-b[l], r-b[r], k);
- }
-
- ~WaveletTree(){
- delete ln;
- delete rn;
- }
-};
+struct WaveletTree {
+ using it = vector<ll>::iterator;
+ WaveletTree *ln = nullptr, *rn = nullptr;
+ vector<int> b = {0};
+ ll lo, hi;
+
+ WaveletTree(vector<ll> a) : WaveletTree(all(a), // call this
+ *min_element(all(a)), *max_element(all(a)) + 1) {}
+
+ WaveletTree(it from, it to, ll x, ll y) : lo(x), hi(y) {
+ if (lo + 1 >= hi || from == to) return;
+ ll mid = (lo + hi) / 2;
+ auto f = [&](ll x){return x < mid;};
+ for (it c = from; c != to; c++) {
+ b.push_back(b.back() + f(*c));
+ }
+ it pivot = stable_partition(from, to, f);
+ ln = new WaveletTree(from, pivot, lo, mid);
+ rn = new WaveletTree(pivot, to, mid, hi);
+ }
+
+ // kth element in sort[l, r) all 0-indexed
+ ll kth(int l, int r, int k) {
+ if (l >= r || k >= r - l) return -1;
+ if (lo + 1 >= hi) return lo;
+ int inLeft = b[r] - b[l];
+ if (k < inLeft) return ln->kth(b[l], b[r], k);
+ else return rn->kth(l-b[l], r-b[r], k-inLeft);
+ }
+
+ // count elements in[l, r) smaller than k
+ int countSmaller(int l, int r, ll k) {
+ if (l >= r || k <= lo) return 0;
+ if (hi <= k) return r - l;
+ return ln->countSmaller(b[l], b[r], k) +
+ rn->countSmaller(l-b[l], r-b[r], k);
+ }
+
+ ~WaveletTree() {delete ln; delete rn;}
+};