From c93c7956ebf2d572ea8fb341a636b60e6f97ad84 Mon Sep 17 00:00:00 2001 From: Noobie99 Date: Mon, 3 Apr 2023 22:44:38 +0200 Subject: minor changes wavelet tree --- datastructures/waveletTree.cpp | 86 ++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 46 deletions(-) (limited to 'datastructures') diff --git a/datastructures/waveletTree.cpp b/datastructures/waveletTree.cpp index 36c1b56..4860a40 100644 --- a/datastructures/waveletTree.cpp +++ b/datastructures/waveletTree.cpp @@ -1,46 +1,40 @@ -struct WaveletTree { - using it = vector::iterator; - WaveletTree *ln, *rn; - ll lo, hi; - vector b; -private: - WaveletTree(it from, it to, ll x, ll y) - : ln(nullptr), rn(nullptr), lo(x), hi(y), b(1) { - ll mid = (lo + hi) / 2; - auto f = [&](ll x){return x < mid;}; - for (it c = from; c != to; c++) { - b.push_back(b.back() + f(*c)); - } - if (lo + 1 >= hi || from == to) return; - it pivot = stable_partition(from, to, f); - ln = new WaveletTree(from, pivot, lo, mid); - rn = new WaveletTree(pivot, to, mid, hi); - } -public: - WaveletTree(vector in) : WaveletTree(all(in), - *min_element(all(in)), *max_element(all(in)) + 1){} - - // kth element in sort[l, r) all 0-indexed - ll kth(int l, int r, int k) { - if (l >= r || k >= r - l) return -1; - if (lo + 1 >= hi) return lo; - int inLeft = b[r] - b[l]; - if (k < inLeft) { - return ln->kth(b[l], b[r], k); - } else { - return rn->kth(l-b[l], r-b[r], k-inLeft); - }} - - // count elements in[l, r) smaller than k - int countSmaller(int l, int r, ll k) { - if (l >= r || k <= lo) return 0; - if (hi <= k) return r - l; - return ln->countSmaller(b[l], b[r], k) + - rn->countSmaller(l-b[l], r-b[r], k); - } - - ~WaveletTree(){ - delete ln; - delete rn; - } -}; +struct WaveletTree { + using it = vector::iterator; + WaveletTree *ln = nullptr, *rn = nullptr; + vector b = {0}; + ll lo, hi; + + WaveletTree(vector a) : WaveletTree(all(a), // call this + *min_element(all(a)), *max_element(all(a)) + 1) {} + + WaveletTree(it from, it to, ll x, ll y) : lo(x), hi(y) { + if (lo + 1 >= hi || from == to) return; + ll mid = (lo + hi) / 2; + auto f = [&](ll x){return x < mid;}; + for (it c = from; c != to; c++) { + b.push_back(b.back() + f(*c)); + } + it pivot = stable_partition(from, to, f); + ln = new WaveletTree(from, pivot, lo, mid); + rn = new WaveletTree(pivot, to, mid, hi); + } + + // kth element in sort[l, r) all 0-indexed + ll kth(int l, int r, int k) { + if (l >= r || k >= r - l) return -1; + if (lo + 1 >= hi) return lo; + int inLeft = b[r] - b[l]; + if (k < inLeft) return ln->kth(b[l], b[r], k); + else return rn->kth(l-b[l], r-b[r], k-inLeft); + } + + // count elements in[l, r) smaller than k + int countSmaller(int l, int r, ll k) { + if (l >= r || k <= lo) return 0; + if (hi <= k) return r - l; + return ln->countSmaller(b[l], b[r], k) + + rn->countSmaller(l-b[l], r-b[r], k); + } + + ~WaveletTree() {delete ln; delete rn;} +}; -- cgit v1.2.3