summaryrefslogtreecommitdiff
path: root/datastructures/lazyPropagation.cpp
diff options
context:
space:
mode:
authorNoobie99 <noob999noob999@gmail.com>2023-02-16 20:34:47 +0100
committerNoobie99 <noob999noob999@gmail.com>2023-02-16 20:34:47 +0100
commit3acf0a0820da8c7357f381ac3c2a4be3bee08184 (patch)
tree3a8363d3a66c508fda16c9be38114656024d0f63 /datastructures/lazyPropagation.cpp
parent48f36b5c91b3dd4a5c43390ce14f1a7e05174929 (diff)
Improved Lazy Segment Tree
Diffstat (limited to 'datastructures/lazyPropagation.cpp')
-rw-r--r--datastructures/lazyPropagation.cpp85
1 files changed, 85 insertions, 0 deletions
diff --git a/datastructures/lazyPropagation.cpp b/datastructures/lazyPropagation.cpp
new file mode 100644
index 0000000..4817b16
--- /dev/null
+++ b/datastructures/lazyPropagation.cpp
@@ -0,0 +1,85 @@
+struct SegTree {
+ int size, height;
+ static constexpr ll neutral = 0; // Neutral element for combine
+ static constexpr ll updateFlag = 0; // Unused value by updates
+ vector<ll> tree, lazy;
+
+ SegTree(const vector<ll>& a) : SegTree(sz(a)) {
+ copy(all(a), tree.begin() + size);
+ for (int i = size - 1; i > 0; i--)
+ tree[i] = combine(tree[2 * i], tree[2 * i + 1]);
+ }
+
+ SegTree(int n) : size(n), height(__lg(2 * n)),
+ tree(2 * n, neutral), lazy(n, updateFlag) {}
+
+ ll combine(ll a, ll b) {return a + b;} // Modify this + neutral
+
+ void apply(int i, ll val, int k) { // And this + updateFlag
+ tree[i] = val * k;
+ if (i < size) lazy[i] = val; // Don't forget this
+ }
+
+ void push_down(int i, int k) {
+ if (lazy[i] != updateFlag) {
+ apply(2 * i, lazy[i], k);
+ apply(2 * i + 1, lazy[i], k);
+ lazy[i] = updateFlag;
+ }}
+
+ void push(int i) {
+ for (int s = height, k = 1 << (height-1); s > 0; s--, k /= 2)
+ push_down(i >> s, k);
+ }
+
+ void build(int i) {
+ for (int k = 2; i /= 2; k *= 2) {
+ push_down(i, k / 2);
+ tree[i] = combine(tree[2 * i], tree[2 * i + 1]);
+ }}
+
+ void update(int l, int r, ll val) { // data[l..r) = val
+ l += size, r += size;
+ int l0 = l, r0 = r;
+ push(l0), push(r0 - 1);
+ for (int k = 1; l < r; l /= 2, r /= 2, k *= 2) {
+ if (l&1) apply(l++, val, k);
+ if (r&1) apply(--r, val, k);
+ }
+ build(l0), build(r0 - 1);
+ }
+
+ ll query(int l, int r) { // sum[l..r)
+ l += size, r += size;
+ push(l), push(r - 1);
+ ll resL = neutral, resR = neutral;
+ for (; l < r; l /= 2, r /= 2) {
+ if (l&1) resL = combine(resL, tree[l++]);
+ if (r&1) resR = combine(tree[--r], resR);
+ }
+ return combine(resL, resR);
+ }
+
+ // Optional:
+ ll find_first(int l, int r, int x) {
+ l += size, r += size;
+ push(l), push(r - 1);
+ vector<pair<int, int>> a; stack<pair<int, int>> st;
+ for (int k = 1; l < r; l /= 2, r /= 2, k *= 2) {
+ if (l&1) a.emplace_back(l++, k);
+ if (r&1) st.emplace(--r, k);
+ }
+ while (!st.empty()) a.push_back(st.top()), st.pop();
+ for (auto [i, k] : a) {
+ if (tree[i] >= x) return find(i, x, k); // Modify this
+ }
+ return -1;
+ }
+
+ ll find(int i, int x, int k) {
+ if (i >= size) return i - size;
+ push_down(i, k / 2);
+ if (tree[2*i] >= x) return find(2 * i, x, k / 2); // And this
+ else return find(2 * i + 1, x, k / 2);
+ }
+};