summaryrefslogtreecommitdiff
path: root/content/math/linearRecurrence.cpp
diff options
context:
space:
mode:
authorGloria Mundi <gloria@gloria-mundi.eu>2024-11-16 15:39:23 +0100
committerGloria Mundi <gloria@gloria-mundi.eu>2024-11-16 15:39:23 +0100
commit72bd993483453ed8ebc462f1a33385cd355d486f (patch)
treec5592ba1ed2fed79e26ba6158d097c9ceb43f061 /content/math/linearRecurrence.cpp
parent98567ec798aa8ca2cfbcb85c774dd470f30e30d4 (diff)
parent35d485bcf6a9ed0a9542628ce4aa94a3326d0884 (diff)
merge mzuenni changes
Diffstat (limited to 'content/math/linearRecurrence.cpp')
-rw-r--r--content/math/linearRecurrence.cpp53
1 files changed, 25 insertions, 28 deletions
diff --git a/content/math/linearRecurrence.cpp b/content/math/linearRecurrence.cpp
index 2501e64..a8adacd 100644
--- a/content/math/linearRecurrence.cpp
+++ b/content/math/linearRecurrence.cpp
@@ -1,33 +1,30 @@
-constexpr ll mod = 1'000'000'007;
-vector<ll> modMul(const vector<ll>& a, const vector<ll>& b,
- const vector<ll>& c) {
- ll n = sz(c);
- vector<ll> res(n * 2 + 1);
- for (int i = 0; i <= n; i++) { //a*b
- for (int j = 0; j <= n; j++) {
- res[i + j] += a[i] * b[j];
- res[i + j] %= mod;
+constexpr ll mod = 998244353;
+// oder ntt mul @\sourceref{math/transforms/ntt.cpp}@
+vector<ll> mul(const vector<ll>& a, const vector<ll>& b) {
+ vector<ll> c(sz(a) + sz(b) - 1);
+ for (int i = 0; i < sz(a); i++) {
+ for (int j = 0; j < sz(b); j++) {
+ c[i+j] += a[i]*b[j] % mod;
}}
- for (int i = 2 * n; i > n; i--) { //res%c
- for (int j = 0; j < n; j++) {
- res[i - 1 - j] += res[i] * c[j];
- res[i - 1 - j] %= mod;
- }}
- res.resize(n + 1);
- return res;
+ for (ll& x : c) x %= mod;
+ return c;
}
ll kthTerm(const vector<ll>& f, const vector<ll>& c, ll k) {
- assert(sz(f) == sz(c));
- vector<ll> tmp(sz(c) + 1), a(sz(c) + 1);
- tmp[0] = a[1] = 1; //tmp = (x^k) % c
-
- for (k++; k > 0; k /= 2) {
- if (k & 1) tmp = modMul(tmp, a, c);
- a = modMul(a, a, c);
- }
-
- ll res = 0;
- for (int i = 0; i < sz(c); i++) res += (tmp[i+1] * f[i]) % mod;
- return res % mod;
+ int n = sz(c);
+ vector<ll> q(n + 1, 1);
+ for (int i = 0; i < n; i++) q[i + 1] = (mod - c[i]) % mod;
+ vector<ll> p = mul(f, q);
+ p.resize(n);
+ p.push_back(0);
+ do {
+ vector<ll> q2 = q;
+ for (int i = 1; i <= n; i += 2) q2[i] = (mod - q2[i]) % mod;
+ vector<ll> x = mul(p, q2), y = mul(q, q2);
+ for (int i = 0; i <= n; i++){
+ p[i] = i == n ? 0 : x[2*i + (k&1)];
+ q[i] = y[2*i];
+ }
+ } while (k /= 2);
+ return p[0];
}