diff options
Diffstat (limited to 'math/transforms')
| -rw-r--r-- | math/transforms/all.cpp | 62 | ||||
| -rw-r--r-- | math/transforms/andTransform.cpp | 21 | ||||
| -rw-r--r-- | math/transforms/bitwiseTransforms.cpp | 12 | ||||
| -rw-r--r-- | math/transforms/fft.cpp | 6 | ||||
| -rw-r--r-- | math/transforms/ntt.cpp | 43 | ||||
| -rw-r--r-- | math/transforms/orTransform.cpp | 21 | ||||
| -rw-r--r-- | math/transforms/xorTransform.cpp | 21 |
7 files changed, 53 insertions, 133 deletions
diff --git a/math/transforms/all.cpp b/math/transforms/all.cpp deleted file mode 100644 index ebed881..0000000 --- a/math/transforms/all.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/*constexpr ll mod = 998244353; @\hl{NTT only}@ -constexpr ll root = 3;*/ - -using cplx = complex<double>; - -//void fft(vector<ll> &a, bool inverse = 0) { @\hl{NTT, xor, or, and}@ -void fft(vector<cplx>& a, bool inverse = 0) { - int n = sz(a); - for (int i = 0, j = 1; j < n - 1; ++j) { - for (int k = n >> 1; k > (i ^= k); k >>= 1); - if (j < i) swap(a[i], a[j]); - } - static vector<cplx> ws(2, 1); - for (static int k = 2; k < n; k *= 2) { - ws.resize(n); - cplx w = polar(1.0, acos(-1.0) / k); - for (int i=k; i<2*k; i++) ws[i] = ws[i/2] * (i % 2 ? w : 1); - } - for (int s = 1; s < n; s *= 2) { - /*ll ws = powMod(root, (mod - 1) / s >> 1, mod); @\hl{NTT only}@ - if (inverse) ws = powMod(ws, mod - 2, mod);*/ - for (int j = 0; j < n; j += 2 * s) { - //ll w = 1; @\hl{NTT only}@ - for (int k = 0; k < s; k++) { - /*ll u = a[j + k], t = a[j + s + k] * w; @\hl{NTT only}@ - t %= mod; - a[j + k] = (u + t) % mod; - a[j + s + k] = (u - t + mod) % mod; - w = (w * ws) % mod;*/ - /*ll u = a[j + k], t = a[j + s + k]; @\hl{xor only}@ - a[j + k] = u + t; - a[j + s + k] = u - t;*/ - /*if (!inverse) { @\hl{or only}@ - a[j + k] = u + t; - a[j + s + k] = u; - } else { - a[j + k] = t; - a[j + s + k] = u - t; - }*/ - /*if (!inverse) { @\hl{and only}@ - a[j + k] = t; - a[j + s + k] = u + t; - } else { - a[j + k] = t - u; - a[j + s + k] = u; - }*/ - cplx u = a[j + k], t = a[j + s + k]; - t *= (inverse ? conj(ws[s + k]) : ws[s + k]); - a[j + k] = u + t; - a[j + s + k] = u - t; - if (inverse) a[j + k] /= 2, a[j + s + k] /= 2; - }}} - /*if (inverse) { @\hl{NTT only}@ - ll div = powMod(n, mod - 2, mod); - for (ll i = 0; i < n; i++) { - a[i] = (a[i] * div) % mod; - }}*/ - /*if (inverse) { @\hl{xor only}@ - for (ll i = 0; i < n; i++) { - a[i] /= n; - }}*/ -} diff --git a/math/transforms/andTransform.cpp b/math/transforms/andTransform.cpp index ab2bd40..1fd9f5c 100644 --- a/math/transforms/andTransform.cpp +++ b/math/transforms/andTransform.cpp @@ -1,17 +1,8 @@ -void fft(vector<cplx>& a, bool inverse = 0) { +void fft(vector<ll>& a, bool inv = false) { int n = sz(a); - for (int i = 0, j = 1; j < n - 1; ++j) { - for (int k = n >> 1; k > (i ^= k); k >>= 1); - if (j < i) swap(a[i], a[j]); - } for (int s = 1; s < n; s *= 2) { - for (int j = 0; j < n; j+= 2 * s) { - for (int k = 0; k < s; k++) { - ll u = a[j + k], t = a[j + s + k]; - if (!inverse) { - a[j + k] = t; - a[j + s + k] = u + t; - } else { - a[j + k] = t - u; - a[j + s + k] = u; -}}}}} + for (int i = 0; i < n; i += 2 * s) { + for (int j = i; j < i + s; j++) { + ll& u = a[j], &v = a[j + s]; + tie(u, v) = inv ? pair(v - u, u) : pair(v, u + v); +}}}} diff --git a/math/transforms/bitwiseTransforms.cpp b/math/transforms/bitwiseTransforms.cpp new file mode 100644 index 0000000..7d1f80d --- /dev/null +++ b/math/transforms/bitwiseTransforms.cpp @@ -0,0 +1,12 @@ +void fft(vector<ll>& a, bool inv = false) { + int n = sz(a); + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < n; i += 2 * s) { + for (int j = i; j < i + s; j++) { + ll& u = a[j], &v = a[j + s]; + tie(u, v) = inv ? pair(v - u, u) : pair(v, u + v); // AND + //tie(u, v) = inv ? pair(v, u - v) : pair(u + v, u); //OR + //tie(u, v) = pair(u + v, u - v); // XOR + }}} + //if (inv) for (ll& x : a) x /= n; // XOR (careful with MOD) +} diff --git a/math/transforms/fft.cpp b/math/transforms/fft.cpp index 53a2d8d..2bd95b2 100644 --- a/math/transforms/fft.cpp +++ b/math/transforms/fft.cpp @@ -1,6 +1,6 @@ using cplx = complex<double>; -void fft(vector<cplx>& a, bool inverse = 0) { +void fft(vector<cplx>& a, bool inv = false) { int n = sz(a); for (int i = 0, j = 1; j < n - 1; ++j) { for (int k = n >> 1; k > (i ^= k); k >>= 1); @@ -16,8 +16,8 @@ void fft(vector<cplx>& a, bool inverse = 0) { for (int j = 0; j < n; j += 2 * s) { for (int k = 0; k < s; k++) { cplx u = a[j + k], t = a[j + s + k]; - t *= (inverse ? conj(ws[s + k]) : ws[s + k]); + t *= (inv ? conj(ws[s + k]) : ws[s + k]); a[j + k] = u + t; a[j + s + k] = u - t; - if (inverse) a[j + k] /= 2, a[j + s + k] /= 2; + if (inv) a[j + k] /= 2, a[j + s + k] /= 2; }}}} diff --git a/math/transforms/ntt.cpp b/math/transforms/ntt.cpp index e1e4588..18d5bd8 100644 --- a/math/transforms/ntt.cpp +++ b/math/transforms/ntt.cpp @@ -1,28 +1,23 @@ -constexpr ll mod = 998244353; -constexpr ll root = 3; +constexpr ll mod = 998244353, root = 3; -void fft(vector<ll>& a, bool inverse = 0) { +void fft(vector<ll>& a, bool inv = false) { int n = sz(a); - for (int i = 0, j = 1; j < n - 1; ++j) { - for (int k = n >> 1; k > (i ^= k); k >>= 1); - if (j < i) swap(a[i], a[j]); + auto b = a; + ll r = inv ? powMod(root, mod - 2, mod) : root; + + for (int s = n / 2; s > 0; s /= 2) { + ll ws = powMod(r, (mod - 1) / (n / s), mod), w = 1; + for (int j = 0; j < n / 2; j += s) { + for (int k = j; k < j + s; k++) { + ll u = a[j + k], t = a[j + s + k] * w % mod; + b[k] = (u + t) % mod; + b[n/2 + k] = (u - t + mod) % mod; + } + w = w * ws % mod; + } + swap(a, b); } - for (int s = 1; s < n; s *= 2) { - ll ws = powMod(root, (mod - 1) / s >> 1, mod); - if (inverse) ws = powMod(ws, mod - 2, mod); - for (int j = 0; j < n; j+= 2 * s) { - ll w = 1; - for (int k = 0; k < s; k++) { - ll u = a[j + k], t = a[j + s + k] * w; - t %= mod; - a[j + k] = (u + t) % mod; - a[j + s + k] = (u - t + mod) % mod; - w *= ws; - w %= mod; - }}} - if (inverse) { + if (inv) { ll div = powMod(n, mod - 2, mod); - for (ll i = 0; i < n; i++) { - a[i] *= div; - a[i] %= mod; -}}} + for (auto& x : a) x = x * div % mod; +}} diff --git a/math/transforms/orTransform.cpp b/math/transforms/orTransform.cpp index fdb5bb8..eb1da44 100644 --- a/math/transforms/orTransform.cpp +++ b/math/transforms/orTransform.cpp @@ -1,17 +1,8 @@ -void fft(vector<ll>& a, bool inverse = 0) { +void fft(vector<ll>& a, bool inv = false) { int n = sz(a); - for (int i = 0, j = 1; j < n - 1; ++j) { - for (int k = n >> 1; k > (i ^= k); k >>= 1); - if (j < i) swap(a[i], a[j]); - } for (int s = 1; s < n; s *= 2) { - for (int j = 0; j < n; j+= 2 * s) { - for (int k = 0; k < s; k++) { - ll u = a[j + k], t = a[j + s + k]; - if (!inverse) { - a[j + k] = u + t; - a[j + s + k] = u; - } else { - a[j + k] = t; - a[j + s + k] = u - t; -}}}}} + for (int i = 0; i < n; i += 2 * s) { + for (int j = i; j < i + s; j++) { + ll& u = a[j], &v = a[j + s]; + tie(u, v) = inv ? pair(v, u - v) : pair(u + v, u); +}}}} diff --git a/math/transforms/xorTransform.cpp b/math/transforms/xorTransform.cpp index 48e4df2..f9d1d82 100644 --- a/math/transforms/xorTransform.cpp +++ b/math/transforms/xorTransform.cpp @@ -1,17 +1,10 @@ -void fft(vector<ll>& a, bool inverse = 0) { +void fft(vector<ll>& a, bool inv = false) { int n = sz(a); - for (int i = 0, j = 1; j < n - 1; ++j) { - for (int k = n >> 1; k > (i ^= k); k >>= 1); - if (j < i) swap(a[i], a[j]); - } for (int s = 1; s < n; s *= 2) { - for (int j = 0; j < n; j+= 2 * s) { - for (int k = 0; k < s; k++) { - ll u = a[j + k], t = a[j + s + k]; - a[j + k] = u + t; - a[j + s + k] = u - t; + for (int i = 0; i < n; i += 2 * s) { + for (int j = i; j < i + s; j++) { + ll& u = a[j], &v = a[j + s]; + tie(u, v) = pair(u + v, u - v); }}} - if (inverse) { - for (ll i = 0; i < n; i++) { - a[i] /= n; -}}} + if (inv) for (ll& x : a) x /= n; +} |
