diff options
Diffstat (limited to 'math/transforms/ntt.cpp')
| -rw-r--r-- | math/transforms/ntt.cpp | 43 |
1 files changed, 19 insertions, 24 deletions
diff --git a/math/transforms/ntt.cpp b/math/transforms/ntt.cpp index e1e4588..18d5bd8 100644 --- a/math/transforms/ntt.cpp +++ b/math/transforms/ntt.cpp @@ -1,28 +1,23 @@ -constexpr ll mod = 998244353; -constexpr ll root = 3; +constexpr ll mod = 998244353, root = 3; -void fft(vector<ll>& a, bool inverse = 0) { +void fft(vector<ll>& a, bool inv = false) { int n = sz(a); - for (int i = 0, j = 1; j < n - 1; ++j) { - for (int k = n >> 1; k > (i ^= k); k >>= 1); - if (j < i) swap(a[i], a[j]); + auto b = a; + ll r = inv ? powMod(root, mod - 2, mod) : root; + + for (int s = n / 2; s > 0; s /= 2) { + ll ws = powMod(r, (mod - 1) / (n / s), mod), w = 1; + for (int j = 0; j < n / 2; j += s) { + for (int k = j; k < j + s; k++) { + ll u = a[j + k], t = a[j + s + k] * w % mod; + b[k] = (u + t) % mod; + b[n/2 + k] = (u - t + mod) % mod; + } + w = w * ws % mod; + } + swap(a, b); } - for (int s = 1; s < n; s *= 2) { - ll ws = powMod(root, (mod - 1) / s >> 1, mod); - if (inverse) ws = powMod(ws, mod - 2, mod); - for (int j = 0; j < n; j+= 2 * s) { - ll w = 1; - for (int k = 0; k < s; k++) { - ll u = a[j + k], t = a[j + s + k] * w; - t %= mod; - a[j + k] = (u + t) % mod; - a[j + s + k] = (u - t + mod) % mod; - w *= ws; - w %= mod; - }}} - if (inverse) { + if (inv) { ll div = powMod(n, mod - 2, mod); - for (ll i = 0; i < n; i++) { - a[i] *= div; - a[i] %= mod; -}}} + for (auto& x : a) x = x * div % mod; +}} |
