summaryrefslogtreecommitdiff
path: root/datastructures
diff options
context:
space:
mode:
Diffstat (limited to 'datastructures')
-rw-r--r--datastructures/segmentTree.cpp4
-rw-r--r--datastructures/test/segmentTree.cpp30
-rw-r--r--datastructures/test/segmentTree2.cpp26
3 files changed, 58 insertions, 2 deletions
diff --git a/datastructures/segmentTree.cpp b/datastructures/segmentTree.cpp
index 79c6cae..2cbf466 100644
--- a/datastructures/segmentTree.cpp
+++ b/datastructures/segmentTree.cpp
@@ -4,9 +4,9 @@ struct SegTree {
vector<T> tree;
static constexpr T E = 0; // Neutral element for combine
- SegTree(vector<T>& a) : n(sz(a)), tree(2 * n) {
+ SegTree(vector<T>& a) : n(ssize(a)), tree(2 * n) {
//SegTree(int size, T val = E) : n(size), tree(2 * n, val) {
- copy(all(a), tree.begin() + n);
+ ranges::copy(a, tree.begin() + n);
for (int i = n - 1; i > 0; i--) { // remove for range update
tree[i] = comb(tree[2 * i], tree[2 * i + 1]);
}}
diff --git a/datastructures/test/segmentTree.cpp b/datastructures/test/segmentTree.cpp
new file mode 100644
index 0000000..79c16e6
--- /dev/null
+++ b/datastructures/test/segmentTree.cpp
@@ -0,0 +1,30 @@
+#include "segmentTree.tmp.cpp"
+
+void test(int n) {
+ vector<ll> a(n);
+ for (ll &x: a) x = util::randint();
+ SegTree seg(a);
+ for (int i = 0; i < 5*n; i++) {
+ {
+ int j = util::randint(n);
+ ll v = util::randint();
+ a[j] = v;
+ seg.update(j, v);
+ }
+ {
+ int l = util::randint(n+1);
+ int r = util::randint(n+1);
+ if (l > r) swap(l, r);
+ assert(
+ seg.query(l, r)
+ ==
+ accumulate(a.begin() + l, a.begin() + r, 0ll)
+ );
+ }
+ }
+}
+
+int main() {
+ test(1000);
+ test(1);
+}
diff --git a/datastructures/test/segmentTree2.cpp b/datastructures/test/segmentTree2.cpp
new file mode 100644
index 0000000..f403a1d
--- /dev/null
+++ b/datastructures/test/segmentTree2.cpp
@@ -0,0 +1,26 @@
+#include "segmentTree2.tmp.cpp"
+
+void test(int n) {
+ vector<ll> a(n);
+ for (ll &x: a) x = util::randint();
+ SegTree seg(a);
+ for (int i = 0; i < 5*n; i++) {
+ {
+ int l = util::randint(n+1);
+ int r = util::randint(n+1);
+ if (l > r) swap(l, r);
+ ll v = util::randint();
+ for (int i = l; i < r; i++) a[i] += v;
+ seg.modify(l, r, v);
+ }
+ {
+ int j = util::randint(n);
+ assert(seg.query(j) == a[j]);
+ }
+ }
+}
+
+int main() {
+ test(1000);
+ test(1);
+}