summaryrefslogtreecommitdiff
path: root/math/transforms/all.cpp
diff options
context:
space:
mode:
authorNoobie99 <noob999noob999@gmail.com>2023-05-20 14:38:47 +0200
committerNoobie99 <noob999noob999@gmail.com>2023-05-20 14:38:47 +0200
commit416962c482e7d6dcf3efd8739dd888bf9bb62288 (patch)
treea59fdc78db798f0361b6751ddfdeda03e2139ba2 /math/transforms/all.cpp
parentc778ffb21c86fde0e44ae5f300994eedc52bd23d (diff)
increase fft precision
Diffstat (limited to 'math/transforms/all.cpp')
-rw-r--r--math/transforms/all.cpp17
1 files changed, 10 insertions, 7 deletions
diff --git a/math/transforms/all.cpp b/math/transforms/all.cpp
index 66e6a41..ebed881 100644
--- a/math/transforms/all.cpp
+++ b/math/transforms/all.cpp
@@ -5,19 +5,22 @@ using cplx = complex<double>;
//void fft(vector<ll> &a, bool inverse = 0) { @\hl{NTT, xor, or, and}@
void fft(vector<cplx>& a, bool inverse = 0) {
- int n = a.size();
+ int n = sz(a);
for (int i = 0, j = 1; j < n - 1; ++j) {
for (int k = n >> 1; k > (i ^= k); k >>= 1);
if (j < i) swap(a[i], a[j]);
}
+ static vector<cplx> ws(2, 1);
+ for (static int k = 2; k < n; k *= 2) {
+ ws.resize(n);
+ cplx w = polar(1.0, acos(-1.0) / k);
+ for (int i=k; i<2*k; i++) ws[i] = ws[i/2] * (i % 2 ? w : 1);
+ }
for (int s = 1; s < n; s *= 2) {
/*ll ws = powMod(root, (mod - 1) / s >> 1, mod); @\hl{NTT only}@
if (inverse) ws = powMod(ws, mod - 2, mod);*/
- double angle = PI / s * (inverse ? -1 : 1);
- cplx ws(cos(angle), sin(angle));
- for (int j = 0; j < n; j+= 2 * s) {
+ for (int j = 0; j < n; j += 2 * s) {
//ll w = 1; @\hl{NTT only}@
- cplx w = 1;
for (int k = 0; k < s; k++) {
/*ll u = a[j + k], t = a[j + s + k] * w; @\hl{NTT only}@
t %= mod;
@@ -41,11 +44,11 @@ void fft(vector<cplx>& a, bool inverse = 0) {
a[j + k] = t - u;
a[j + s + k] = u;
}*/
- cplx u = a[j + k], t = a[j + s + k] * w;
+ cplx u = a[j + k], t = a[j + s + k];
+ t *= (inverse ? conj(ws[s + k]) : ws[s + k]);
a[j + k] = u + t;
a[j + s + k] = u - t;
if (inverse) a[j + k] /= 2, a[j + s + k] /= 2;
- w *= ws;
}}}
/*if (inverse) { @\hl{NTT only}@
ll div = powMod(n, mod - 2, mod);