summaryrefslogtreecommitdiff
path: root/test/math/lgsFp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'test/math/lgsFp.cpp')
-rw-r--r--test/math/lgsFp.cpp42
1 files changed, 17 insertions, 25 deletions
diff --git a/test/math/lgsFp.cpp b/test/math/lgsFp.cpp
index d7680ea..376a067 100644
--- a/test/math/lgsFp.cpp
+++ b/test/math/lgsFp.cpp
@@ -1,6 +1,5 @@
#include "../util.h"
#include <math/shortModInv.cpp>
-vector<vector<ll>> mat;
constexpr ll mod = 1'000'000'007;
namespace lgs {
#include <math/lgsFp.cpp>
@@ -8,30 +7,26 @@ namespace lgs {
vector<vector<ll>> inverseMat(const vector<vector<ll>>& m) {
- int n = sz(m);
- mat = m;
+ int n = ssize(m);
+ vector<vector<ll>> mat = m;
for (int i = 0; i < n; i++) {
- if (sz(mat[i]) != n) cerr << "error: no square matrix" << FAIL;
+ if (ssize(mat[i]) != n) cerr << "error: no square matrix" << FAIL;
mat[i].resize(2*n);
mat[i][n+i] = 1;
}
- lgs::gauss(sz(mat), sz(mat[0]));
- vector<vector<ll>> res(m);
+ vector<int> pivots = lgs::gauss(mat);
for (int i = 0; i < n; i++) {
- res[i] = vector<ll>(mat[i].begin() + n, mat[i].end());
- for (int j = 0; j < n; j++) {
- if (j != i && mat[i][j] != 0) cerr << "error: not full rank?" << FAIL;
- if (j == i && mat[i][j] != 1) cerr << "error: not full rank?" << FAIL;
- }
+ if (pivots[i] != i) cerr << "error: not full rank?" << FAIL;
+ mat[i].erase(begin(mat[i]), begin(mat[i]) + n);
}
- return res;
+ return mat;
}
vector<vector<ll>> mul(const vector<vector<ll>>& a, const vector<vector<ll>>& b) {
- int n = sz(a);
- int m = sz(b[0]);
- int x = sz(b);
- if (sz(a[0]) != sz(b)) cerr << "error: wrong dimensions" << FAIL;
+ int n = ssize(a);
+ int m = ssize(b[0]);
+ int x = ssize(b);
+ if (ssize(a[0]) != ssize(b)) cerr << "error: wrong dimensions" << FAIL;
vector<vector<ll>> res(n, vector<ll>(m));
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
@@ -53,18 +48,17 @@ void test_square() {
vector<vector<ll>> m(n);
for (auto& v : m) v = Random::integers<ll>(n, 0, mod);
- mat = m;
- lgs::gauss(sz(mat), sz(mat[0]));
+ lgs::gauss(m);
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
- hash += mat[i][j];
+ hash += m[i][j];
}
}
queries += n;
}
- cerr << "tested sqaures: " << queries << " (hash: " << hash << ")" << endl;;
+ cerr << "tested squares: " << queries << " (hash: " << hash << ")" << endl;;
}
void stress_test_inv() {
@@ -82,8 +76,7 @@ void stress_test_inv() {
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
- if (i == j && prod[i][j] != 1) cerr << "error: not inverted" << FAIL;
- if (i != j && prod[i][j] != 0) cerr << "error: not inverted" << FAIL;
+ if (prod[i][j] != (i == j)) cerr << "error: not inverted" << FAIL;
}
}
@@ -98,15 +91,14 @@ void performance_test() {
vector<vector<ll>> m(N);
for (auto& v : m) v = Random::integers<ll>(N, 0, mod);
- mat = m;
t.start();
- lgs::gauss(sz(mat), sz(mat[0]));
+ lgs::gauss(m);
t.stop();
hash_t hash = 0;
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
- hash += mat[i][j];
+ hash += m[i][j];
}
}
if (t.time > 500) cerr << "too slow: " << t.time << FAIL;