From 5ab8a5088b729a9953b8dff1b2a985dc8fb2098b Mon Sep 17 00:00:00 2001 From: mzuenni Date: Mon, 27 Jun 2022 17:19:28 +0200 Subject: updated tcr --- math/bigint.cpp | 514 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 275 insertions(+), 239 deletions(-) (limited to 'math/bigint.cpp') diff --git a/math/bigint.cpp b/math/bigint.cpp index e25ebe3..1753200 100644 --- a/math/bigint.cpp +++ b/math/bigint.cpp @@ -1,240 +1,276 @@ -// Bislang keine Division. Multiplikation nach Schulmethode. -#define PLUS 0 -#define MINUS 1 -#define BASE 1000000000 -#define EXPONET 9 - +// base and base_digits must be consistent +constexpr ll base = 1000000; +constexpr ll base_digits = 6; struct bigint { - int sign; - vector digits; - - // Initialisiert mit 0. - bigint(void) { sign = PLUS; } - - // Initialisiert mit kleinem Wert. - bigint(ll value) { - if (value == 0) sign = PLUS; - else { - sign = value >= 0 ? PLUS : MINUS; - value = abs(value); - while (value) { - digits.push_back(value % BASE); - value /= BASE; - }}} - - // Initialisiert mit C-String. Kann nicht mit Vorzeichen umgehen. - bigint(char *str, int length) { - int base = 1; - ll digit = 0; - for (int i = length - 1; i >= 0; i--) { - digit += base * (str[i] - '0'); - if (base * 10 == BASE) { - digits.push_back(digit); - digit = 0; - base = 1; - } else base *= 10; - } - if (digit != 0) digits.push_back(digit); - sign = PLUS; - } - - // Löscht führende Nullen und macht -0 zu 0. - void trim() { - while (digits.size() > 0 && digits[digits.size() - 1] == 0) - digits.pop_back(); - if (digits.size() == 0 && sign == MINUS) sign = PLUS; - } - - // Gibt die Zahl aus. - void print() { - if (digits.size() == 0) { printf("0"); return; } - if (sign == MINUS) printf("-"); - printf("%lld", digits[digits.size() - 1]); - for (int i = digits.size() - 2; i >= 0; i--) { - printf("%09lld", digits[i]); // Anpassen, wenn andere Basis gewählt wird. - }} -}; - -// Kleiner-oder-gleich-Vergleich. -bool operator<=(bigint &a, bigint &b) { - if (a.digits.size() == b.digits.size()) { - int idx = a.digits.size() - 1; - while (idx >= 0) { - if (a.digits[idx] < b.digits[idx]) return true; - else if (a.digits[idx] > b.digits[idx]) return false; - idx--; - } - return true; - } - return a.digits.size() < b.digits.size(); -} - -// Kleiner-Vergeleich. -bool operator<(bigint &a, bigint &b) { - if (a.digits.size() == b.digits.size()) { - int idx = a.digits.size() - 1; - while (idx >= 0) { - if (a.digits[idx] < b.digits[idx]) return true; - else if (a.digits[idx] > b.digits[idx]) return false; - idx--; - } - return false; - } - return a.digits.size() < b.digits.size(); -} - -void sub(bigint *a, bigint *b, bigint *c); - -// a + b = c. a, b, c dürfen gleich sein. -void add(bigint *a, bigint *b, bigint *c) { - if (a->sign == b->sign) c->sign = a->sign; - else { - if (a->sign == MINUS) { - a->sign ^= 1; - sub(b, a, c); - a->sign ^= 1; - } else { - b->sign ^= 1; - sub(a, b, c); - b->sign ^= 1; - } - return; - } - - c->digits.resize(max(a->digits.size(), b->digits.size())); - ll carry = 0; - int i = 0; - for (; i < (int)min(a->digits.size(), b->digits.size()); i++) { - ll sum = carry + a->digits[i] + b->digits[i]; - c->digits[i] = sum % BASE; - carry = sum / BASE; - } - if (i < (int)a->digits.size()) { - for (; i< (int)a->digits.size(); i++) { - ll sum = carry + a->digits[i]; - c->digits[i] = sum % BASE; - carry = sum / BASE; - } - } else { - for (; i< (int)b->digits.size(); i++) { - ll sum = carry + b->digits[i]; - c->digits[i] = sum % BASE; - carry = sum / BASE; - }} - if (carry) c->digits.push_back(carry); -} - -// a - b = c. c darf a oder b sein. a und b müssen verschieden sein. -void sub(bigint *a, bigint *b, bigint *c) { - if (a->sign == MINUS || b->sign == MINUS) { - b->sign ^= 1; - add(a, b, c); - b->sign ^= 1; - return; - } - - if (a < b) { - sub(b, a, c); - c->sign = MINUS; - c->trim(); - return; - } - - c->digits.resize(a->digits.size()); - ll borrow = 0; - int i = 0; - for (; i < (int)b->digits.size(); i++) { - ll diff = a->digits[i] - borrow - b->digits[i]; - if (a->digits[i] > 0) borrow = 0; - if (diff < 0) { - diff += BASE; - borrow = 1; - } - c->digits[i] = diff % BASE; - } - for (; i < (int)a->digits.size(); i++) { - ll diff = a->digits[i] - borrow; - if (a->digits[i] > 0) borrow = 0; - if (diff < 0) { - diff += BASE; - borrow = 1; - } - c->digits[i] = diff % BASE; - } - c->trim(); -} - -// Ziffernmultiplikation a * b = c. b und c dürfen gleich sein. -// a muss kleiner BASE sein. -void digitMul(ll a, bigint *b, bigint *c) { - if (a == 0) { - c->digits.clear(); - c->sign = PLUS; - return; - } - c->digits.resize(b->digits.size()); - ll carry = 0; - for (int i = 0; i < (int)b->digits.size(); i++) { - ll prod = carry + b->digits[i] * a; - c->digits[i] = prod % BASE; - carry = prod / BASE; - } - if (carry) c->digits.push_back(carry); - c->sign = (a > 0) ? b->sign : 1 ^ b->sign; - c->trim(); -} - -// Zifferndivision b / a = c. b und c dürfen gleich sein. -// a muss kleiner BASE sein. -void digitDiv(ll a, bigint *b, bigint *c) { - c->digits.resize(b->digits.size()); - ll carry = 0; - for (int i = (int)b->digits.size() - 1; i>= 0; i--) { - ll quot = (carry * BASE + b->digits[i]) / a; - carry = carry * BASE + b->digits[i] - quot * a; - c->digits[i] = quot; - } - c->sign = b->sign ^ (a < 0); - c->trim(); -} - -// a * b = c. c darf weder a noch b sein. a und b dürfen gleich sein. -void mult(bigint *a, bigint *b, bigint *c) { - bigint row = *a, tmp; - c->digits.clear(); - for (int i = 0; i < (int)b->digits.size(); i++) { - digitMul(b->digits[i], &row, &tmp); - add(&tmp, c, c); - row.digits.insert(row.digits.begin(), 0); - } - c->sign = a->sign != b->sign; - c->trim(); -} - -// Berechnet eine kleine Zehnerpotenz. -inline ll pow10(int n) { - ll res = 1; - for (int i = 0; i < n; i++) res *= 10; - return res; -} - -// Berechnet eine große Zehnerpotenz. -void power10(ll e, bigint *out) { - out->digits.assign(e / EXPONET + 1, 0); - if (e % EXPONET) - out->digits[out->digits.size() - 1] = pow10(e % EXPONET); - else out->digits[out->digits.size() - 1] = 1; -} - -// Nimmt eine Zahl module einer Zehnerpotenz 10^e. -void mod10(int e, bigint *a) { - int idx = e / EXPONET; - if ((int)a->digits.size() < idx + 1) return; - if (e % EXPONET) { - a->digits.resize(idx + 1); - a->digits[idx] %= pow10(e % EXPONET); - } else { - a->digits.resize(idx); - } - a->trim(); -} + vll a; ll sign; + + bigint() : sign(1) {} + + bigint(ll v) {*this = v;} + + bigint(const string &s) {read(s);} + + void operator=(const bigint& v) { + sign = v.sign; + a = v.a; + } + + void operator=(ll v) { + sign = 1; + if (v < 0) sign = -1, v = -v; + a.clear(); + for (; v > 0; v = v / base) + a.push_back(v % base); + } + + bigint operator+(const bigint& v) const { + if (sign == v.sign) { + bigint res = v; + for (ll i = 0, carry = 0; i < (ll)max(a.size(), v.a.size()) || carry; ++i) { + if (i == (ll)res.a.size()) + res.a.push_back(0); + res.a[i] += carry + (i < (ll)a.size() ? a[i] : 0); + carry = res.a[i] >= base; + if (carry) + res.a[i] -= base; + } + return res; + } + return *this - (-v); + } + + bigint operator-(const bigint& v) const { + if (sign == v.sign) { + if (abs() >= v.abs()) { + bigint res = *this; + for (ll i = 0, carry = 0; i < (ll)v.a.size() || carry; ++i) { + res.a[i] -= carry + (i < (ll)v.a.size() ? v.a[i] : 0); + carry = res.a[i] < 0; + if (carry) res.a[i] += base; + } + res.trim(); + return res; + } + return -(v - *this); + } + return *this + (-v); + } + + void operator*=(ll v) { + if (v < 0) sign = -sign, v = -v; + for (ll i = 0, carry = 0; i < (ll)a.size() || carry; ++i) { + if (i == (ll)a.size()) a.push_back(0); + ll cur = a[i] * v + carry; + carry = cur / base; + a[i] = cur % base; + } + trim(); + } + + bigint operator*(ll v) const { + bigint res = *this; + res *= v; + return res; + } + + friend pair divmod(const bigint& a1, const bigint& b1) { + ll norm = base / (b1.a.back() + 1); + bigint a = a1.abs() * norm; + bigint b = b1.abs() * norm; + bigint q, r; + q.a.resize(a.a.size()); + for (ll i = (ll)a.a.size() - 1; i >= 0; i--) { + r *= base; + r += a.a[i]; + ll s1 = r.a.size() <= b.a.size() ? 0 : r.a[b.a.size()]; + ll s2 = r.a.size() <= b.a.size() - 1 ? 0 : r.a[b.a.size() - 1]; + ll d = (base * s1 + s2) / b.a.back(); + r -= b * d; + while (r < 0) + r += b, --d; + q.a[i] = d; + } + q.sign = a1.sign * b1.sign; + r.sign = a1.sign; + q.trim(); + r.trim(); + return make_pair(q, r / norm); + } + + bigint operator/(const bigint& v) const { + return divmod(*this, v).first; + } + + bigint operator%(const bigint& v) const { + return divmod(*this, v).second; + } + + void operator/=(ll v) { + if (v < 0) sign = -sign, v = -v; + for (ll i = (ll)a.size() - 1, rem = 0; i >= 0; --i) { + ll cur = a[i] + rem * base; + a[i] = cur / v; + rem = cur % v; + } + trim(); + } + + bigint operator/(ll v) const { + bigint res = *this; + res /= v; + return res; + } + + ll operator%(ll v) const { + if (v < 0) v = -v; + ll m = 0; + for (ll i = (ll)a.size() - 1; i >= 0; --i) + m = (a[i] + m * base) % v; + return m * sign; + } + + void operator+=(const bigint& v) { + *this = *this + v; + } + void operator-=(const bigint& v) { + *this = *this - v; + } + void operator*=(const bigint& v) { + *this = *this * v; + } + void operator/=(const bigint& v) { + *this = *this / v; + } + + bool operator<(const bigint& v) const { + if (sign != v.sign) return sign < v.sign; + if (a.size() != v.a.size()) + return a.size() * sign < v.a.size() * v.sign; + for (ll i = (ll)a.size() - 1; i >= 0; i--) + if (a[i] != v.a[i]) + return a[i] * sign < v.a[i] * sign; + return false; + } + + bool operator>(const bigint& v) const { + return v < *this; + } + bool operator<=(const bigint& v) const { + return !(v < *this); + } + bool operator>=(const bigint& v) const { + return !(*this < v); + } + bool operator==(const bigint& v) const { + return !(*this < v) && !(v < *this); + } + bool operator!=(const bigint& v) const { + return *this < v || v < *this; + } + + void trim() { + while (!a.empty() && !a.back()) a.pop_back(); + if (a.empty()) sign = 1; + } + + bool isZero() const { + return a.empty() || (a.size() == 1 && a[0] == 0); + } + + bigint operator-() const { + bigint res = *this; + res.sign = -sign; + return res; + } + + bigint abs() const { + bigint res = *this; + res.sign *= res.sign; + return res; + } + + ll longValue() const { + ll res = 0; + for (ll i = (ll)a.size() - 1; i >= 0; i--) + res = res * base + a[i]; + return res * sign; + } + + void read(const string& s) { + sign = 1; + a.clear(); + ll pos = 0; + while (pos < (ll)s.size() && (s[pos] == '-' || s[pos] == '+')) { + if (s[pos] == '-') sign = -sign; + ++pos; + } + for (ll i = (ll)s.size() - 1; i >= pos; i -= base_digits) { + ll x = 0; + for (ll j = max(pos, i - base_digits + 1); j <= i; j++) + x = x * 10 + s[j] - '0'; + a.push_back(x); + } + trim(); + } + + friend istream& operator>>(istream& stream, bigint& v) { + string s; + stream >> s; + v.read(s); + return stream; + } + + friend ostream& operator<<(ostream& stream, const bigint& v) { + if (v.sign == -1) stream << '-'; + stream << (v.a.empty() ? 0 : v.a.back()); + for (ll i = (ll)v.a.size() - 2; i >= 0; --i) + stream << setw(base_digits) << setfill('0') << v.a[i]; + return stream; + } + + static vll karatsubaMultiply(const vll& a, const vll& b) { + ll n = a.size(); + vll res(n + n); + if (n <= 32) { + for (ll i = 0; i < n; i++) + for (ll j = 0; j < n; j++) + res[i + j] += a[i] * b[j]; + return res; + } + ll k = n >> 1; + vll a1(a.begin(), a.begin() + k); + vll a2(a.begin() + k, a.end()); + vll b1(b.begin(), b.begin() + k); + vll b2(b.begin() + k, b.end()); + vll a1b1 = karatsubaMultiply(a1, b1); + vll a2b2 = karatsubaMultiply(a2, b2); + for (ll i = 0; i < k; i++) a2[i] += a1[i]; + for (ll i = 0; i < k; i++) b2[i] += b1[i]; + vll r = karatsubaMultiply(a2, b2); + for (ll i = 0; i < (ll)a1b1.size(); i++) r[i] -= a1b1[i]; + for (ll i = 0; i < (ll)a2b2.size(); i++) r[i] -= a2b2[i]; + for (ll i = 0; i < (ll)r.size(); i++) res[i + k] += r[i]; + for (ll i = 0; i < (ll)a1b1.size(); i++) res[i] += a1b1[i]; + for (ll i = 0; i < (ll)a2b2.size(); i++) res[i + n] += a2b2[i]; + return res; + } + + bigint operator*(const bigint& v) const { + vll a(this->a.begin(), this->a.end()); + vll b(v.a.begin(), v.a.end()); + while (a.size() < b.size()) a.push_back(0); + while (b.size() < a.size()) b.push_back(0); + while (a.size() & (a.size() - 1)) + a.push_back(0), b.push_back(0); + vll c = karatsubaMultiply(a, b); + bigint res; + res.sign = sign * v.sign; + for (ll i = 0, carry = 0; i < (ll)c.size(); i++) { + ll cur = c[i] + carry; + res.a.push_back(cur % base); + carry = cur / base; + } + res.trim(); + return res; + } +}; \ No newline at end of file -- cgit v1.2.3