// 大数模板 (big-integer template)
/*
 * $File: test.cpp
 * $Date: Wed Feb 09 13:22:29 2011 +0800
 * $Author: Zhou Xinyu <zxytim@gmail.com>
 *
 * A simple arbitrary-precision signed integer implementation.
 *
 * Representation: sign-magnitude. bit[0..nbits-1] holds the magnitude as
 * little-endian limbs in base BASE (10^9); sign holds the sign. Every value
 * produced by this class is normalized: no leading zero limbs, and zero is
 * always nbits == 1, bit[0] == 0, sign == POSITIVE.
 *
 * NOTE: the arithmetic operators return a reference to a function-local
 * static result object (to avoid copying ~100 KB objects). Chaining such as
 * `a + b + c` or `a * b * c` works, but an expression that needs two live
 * results of the SAME operator at once, e.g. `(a + b) * (c + d)`, does not:
 * the second call overwrites the first. Store such intermediates in named
 * variables.
 */
#include <cstdio>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

// Selects the algorithm used by absMultiply():
//   MULTIPLICATION_SLOW    - schoolbook, O(n^2)
//   MULTIPLICATION_FASTER  - Karatsuba, O(n^1.585)
//   MULTIPLICATION_FASTEST - FFT (not implemented; falls back to Karatsuba)
#define MULTIPLICATION_FASTER

class Bignum
{
	private:
		typedef int Num_t;              // one limb, always in [0, BASE)
		typedef long long Num_bigger_t; // holds a limb product plus carries

		static const int N_BIT_MAX = 25000;   // capacity in limbs
		// Bound on Karatsuba recursion depth; >= log2(N_BIT_MAX) / 2-ish in
		// practice since recursion stops at 20-limb operands.
		static const int ITER_DEPTH_MAX = 14;
		static const int BASE = 1000000000;   // 10^9: nine decimal digits per limb

		int nbits;             // number of limbs in use (>= 1)
		Num_t bit[N_BIT_MAX];  // little-endian limbs of the magnitude

		// Values of @sign. An enum (not static const members) so the
		// constants never need an out-of-class definition.
		enum { POSITIVE = false, NEGATIVE = true };
		// when the number is 0, @sign is POSITIVE
		bool sign;

		bool isZero() const
		{
			return nbits == 1 && bit[0] == 0;
		}

		void setZero()
		{
			nbits = 1;
			bit[0] = 0;
			sign = POSITIVE;
		}

		/* Drop leading zero limbs, keeping at least one limb. */
		void trim()
		{
			while (nbits > 1 && bit[nbits - 1] == 0)
				nbits --;
		}

		/* Three-way comparison of magnitudes (signs ignored): -1, 0 or 1. */
		static int absCompare(const Bignum &a, const Bignum &b)
		{
			if (a.nbits != b.nbits)
				return a.nbits > b.nbits ? 1 : -1;
			for (int i = a.nbits - 1; i >= 0; i --)
				if (a.bit[i] != b.bit[i])
					return a.bit[i] > b.bit[i] ? 1 : -1;
			return 0;
		}

		/* |a| > |b| */
		static bool absGreater(const Bignum &a, const Bignum &b)
		{
			return absCompare(a, b) > 0;
		}

		/* |a| >= |b| */
		static bool absGreaterEqual(const Bignum &a, const Bignum &b)
		{
			return absCompare(a, b) >= 0;
		}

		/*
		 * |ret| = |a| + |b|. ret may alias a and/or b: limb i is written only
		 * after limb i of both inputs has been read, and the input lengths
		 * are captured up front. ret.sign is left to the caller (except when
		 * an operand is zero, in which case the other operand is copied
		 * verbatim, sign included).
		 */
		static void absPlus(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			if (a.isZero()) { ret = b; return; }
			if (b.isZero()) { ret = a; return; }
			const int na = a.nbits, nb = b.nbits;
			const int len = std::max(na, nb);
			Num_t carry = 0;
			for (int i = 0; i < len; i ++) {
				Num_t s = carry;
				if (i < na) s += a.bit[i];
				if (i < nb) s += b.bit[i];
				carry = (s >= BASE);
				if (carry) s -= BASE;
				ret.bit[i] = s;
			}
			ret.nbits = len;
			if (carry) {
				assert(len < N_BIT_MAX);
				ret.bit[ret.nbits ++] = carry;
			}
		}

		/* Same as absPlus, kept for callers that pass ret aliased with an
		 * operand; absPlus itself is alias-safe. */
		static void absPlusSafe(const Bignum &a, const Bignum &b, Bignum &c)
		{
			absPlus(a, b, c);
		}

		/*
		 * |ret| = |a| - |b|; requires |a| >= |b|. ret may alias a and/or b.
		 * ret.sign is left to the caller (except for the b == 0 shortcut,
		 * which copies a verbatim).
		 */
		static void absMinus(const Bignum &a, const Bignum &b, Bignum &ret)
		{
#ifdef DEBUG
			assert(a.nbits >= b.nbits);
			assert(absGreaterEqual(a, b));
#endif
			if (b.isZero()) { ret = a; return; }
			const int na = a.nbits, nb = b.nbits;
			Num_t borrow = 0;
			for (int i = 0; i < na; i ++) {
				Num_t d = a.bit[i] - borrow - (i < nb ? b.bit[i] : 0);
				borrow = (d < 0);
				if (borrow) d += BASE;
				ret.bit[i] = d;
			}
			ret.nbits = na;
			ret.trim();
		}

		/*
		 * ret = a + (value with sign @bsign and magnitude |b|), fully signed.
		 * Shared by operator+ (bsign = b.sign) and operator- (bsign = !b.sign).
		 * Signs are captured before any write, so ret may alias a and/or b.
		 */
		static void signedAdd(const Bignum &a, bool bsign, const Bignum &b, Bignum &ret)
		{
			const bool asign = a.sign;
			if (asign == bsign) {
				absPlus(a, b, ret);
				ret.sign = asign;
			} else if (absGreaterEqual(a, b)) {
				absMinus(a, b, ret);
				ret.sign = asign;
			} else {
				absMinus(b, a, ret);
				ret.sign = bsign;
			}
			if (ret.isZero())
				ret.sign = POSITIVE;
		}

		/*---------------- algorithms of big integer multiplication --------------- */

		/*
		 * Schoolbook multiplication of magnitudes, O(n^2).
		 * ret must not alias a or b (it is cleared first). ret.sign = POSITIVE.
		 */
		static void absMultiply_square_n(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			if (a.isZero() || b.isZero()) {
				ret.setZero();
				return;
			}
			const int len = a.nbits + b.nbits;
			assert(len <= N_BIT_MAX);
			std::fill(ret.bit, ret.bit + len, 0);
			for (int i = 0; i < a.nbits; i ++) {
				Num_bigger_t carry = 0;
				for (int j = 0; j < b.nbits; j ++) {
					// at most (BASE-1)^2 + 2(BASE-1) < BASE^2: fits in 64 bits
					Num_bigger_t cur = (Num_bigger_t)a.bit[i] * b.bit[j]
						+ ret.bit[i + j] + carry;
					ret.bit[i + j] = (Num_t)(cur % BASE);
					carry = cur / BASE;
				}
				// this limb has not been written by earlier rows yet
				ret.bit[i + b.nbits] = (Num_t)carry;
			}
			ret.nbits = len;
			ret.trim();
			ret.sign = POSITIVE;
		}

		/* |ret| = |a| * m for 0 <= m < BASE. ret may alias a. ret.sign = POSITIVE. */
		static void absMultiplySmall(const Bignum &a, Num_t m, Bignum &ret)
		{
			if (a.isZero() || m == 0) {
				ret.setZero();
				return;
			}
			const int na = a.nbits;
			Num_bigger_t carry = 0;
			for (int i = 0; i < na; i ++) {
				Num_bigger_t cur = (Num_bigger_t)a.bit[i] * m + carry;
				ret.bit[i] = (Num_t)(cur % BASE);
				carry = cur / BASE;
			}
			ret.nbits = na;
			if (carry) {
				assert(na < N_BIT_MAX);
				ret.bit[ret.nbits ++] = (Num_t)carry;
			}
			ret.sign = POSITIVE;
		}

		/*
		 * Karatsuba multiplication.
		 *
		 * Let n be the limb count of the larger operand and h = ceil(n/2);
		 * the smaller operand is padded with leading zeros. Writing X = BASE^h:
		 *
		 *   a = A*X + B
		 *   b = C*X + D
		 *
		 *   a * b = AC*X^2 + (BC + AD)*X + BD
		 *         = AC*X^2 + ((A - B)(D - C) + AC + BD)*X + BD
		 *
		 * since (A - B)(D - C) = AD - AC - BD + BC. Only three products
		 * AC, BD and (A - B)(D - C) are needed, so
		 *   T(n) = 3T(n/2) + O(n)  =>  T(n) = O(n^(log2 3)) = O(n^1.58496)
		 *
		 * Works on magnitudes: operand signs are ignored, ret.sign = POSITIVE.
		 * ret must not alias a or b. Scratch space is static and indexed by
		 * @depth, so this is not reentrant.
		 */
		static void absMultiply_n_power_1p58496_iter(const Bignum &a, const Bignum &b, Bignum &ret, int depth)
		{
			static Bignum A[ITER_DEPTH_MAX], B[ITER_DEPTH_MAX], C[ITER_DEPTH_MAX], D[ITER_DEPTH_MAX];
			// products
			static Bignum AC[ITER_DEPTH_MAX], BD[ITER_DEPTH_MAX];
			// signed differences A - B and D - C
			static Bignum AB[ITER_DEPTH_MAX], DC[ITER_DEPTH_MAX];
			// |(A - B) * (D - C)|
			static Bignum ABDC[ITER_DEPTH_MAX];

			const int n_max = std::max(a.nbits, b.nbits);
			const int half_nbits = n_max - (n_max >> 1);
			const int min_nbits = std::min(a.nbits, b.nbits);

			ret.setZero();
			if (a.isZero() || b.isZero())
				return;

			// small or very unbalanced operands: schoolbook is faster
			static const int NBITS_TO_FORCE = 20;
			if (min_nbits <= NBITS_TO_FORCE || min_nbits <= std::sqrt((double)n_max)) {
				absMultiply_square_n(a, b, ret);
				return;
			}
			assert(depth < ITER_DEPTH_MAX);

			const Bignum *pa = &a, *pb = &b;
			if (absGreater(a, b))
				swap(pa, pb); // *pb is the greater one
			partition(*pa, A[depth], B[depth], n_max);
			partition(*pb, C[depth], D[depth], n_max);

			absMultiply_n_power_1p58496_iter(A[depth], C[depth], AC[depth], depth + 1);
			absMultiply_n_power_1p58496_iter(B[depth], D[depth], BD[depth], depth + 1);
#ifdef DEBUG
			static Bignum check;
			absMultiply_square_n(A[depth], C[depth], check);
			assert(check == AC[depth]);
			absMultiply_square_n(B[depth], D[depth], check);
			assert(check == BD[depth]);
#endif
			absPlus(AC[depth], BD[depth], ret);

			// A..D are magnitudes (partition guarantees POSITIVE signs), so
			// these are true signed differences.
			AB[depth] = A[depth] - B[depth];
			DC[depth] = D[depth] - C[depth];
			if (!(AB[depth].isZero() || DC[depth].isZero())) {
				absMultiply_n_power_1p58496_iter(AB[depth], DC[depth], ABDC[depth], depth + 1);
#ifdef DEBUG
				absMultiply_square_n(AB[depth], DC[depth], check);
				assert(check == ABDC[depth]);
#endif
				// AC + BD + (A-B)(D-C) = AD + BC >= 0, so the subtraction
				// never underflows.
				if (AB[depth].sign != DC[depth].sign)
					absMinus(ret, ABDC[depth], ret);
				else
					absPlusSafe(ret, ABDC[depth], ret);
			}
			left_shift_in_BASE_system(ret, half_nbits);
			left_shift_in_BASE_system(AC[depth], half_nbits << 1);
			absPlusSafe(ret, AC[depth], ret);
			absPlusSafe(ret, BD[depth], ret);
			ret.sign = POSITIVE;
		}

		static void swap(const Bignum *&a, const Bignum *&b)
		{
			const Bignum *t = a;
			a = b;
			b = t;
		}

		/*
		 * Split |n| (treated as having @total_nbits limbs, zero-padded) into
		 * a = high floor(total_nbits/2) limbs and b = low ceil(total_nbits/2)
		 * limbs. Both outputs are normalized magnitudes with POSITIVE sign.
		 */
		static void partition(const Bignum &n, Bignum &a, Bignum &b, int total_nbits)
		{
			const int half_nbits = total_nbits - (total_nbits >> 1);
			if (n.nbits <= half_nbits) {
				a.setZero();
				b = n;
				b.sign = POSITIVE;
				return;
			}
			a.nbits = n.nbits - half_nbits;
			std::copy(n.bit + half_nbits, n.bit + n.nbits, a.bit);
			a.sign = POSITIVE;
			a.trim();

			b.nbits = half_nbits;
			std::copy(n.bit, n.bit + half_nbits, b.bit);
			b.sign = POSITIVE;
			b.trim();
		}

		/* ret *= BASE^shift (append @shift zero limbs at the low end). */
		static void left_shift_in_BASE_system(Bignum &ret, int shift)
		{
			if (shift <= 0 || ret.isZero())
				return;
			assert(ret.nbits + shift <= N_BIT_MAX);
			std::copy_backward(ret.bit, ret.bit + ret.nbits, ret.bit + ret.nbits + shift);
			std::fill(ret.bit, ret.bit + shift, 0);
			ret.nbits += shift;
		}

		static void absMultiply_n_power_1p58496(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			absMultiply_n_power_1p58496_iter(a, b, ret, 0);
		}

		/*
		 * Big integer multiplication using the Fast Fourier Transform,
		 * O(n log n). NOT IMPLEMENTED: this stub does nothing and is not
		 * called; MULTIPLICATION_FASTEST falls back to Karatsuba.
		 */
		static void absMultiply_nlogn(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			(void)a; (void)b; (void)ret;
		}

		/*
		 * |ret| = |a| * |b| with the algorithm selected at compile time.
		 * ret may alias a or b (the result is then built in scratch space).
		 * ret.sign is left to the caller.
		 */
		static void absMultiply(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			if (a.isZero() || b.isZero()) {
				ret.setZero();
				return;
			}
			// the algorithms below clear ret before reading their inputs
			if (&ret == &a || &ret == &b) {
				static Bignum scratch;
				absMultiply(a, b, scratch);
				ret = scratch;
				return;
			}
#if defined(MULTIPLICATION_SLOW)
			absMultiply_square_n(a, b, ret);
#elif defined(MULTIPLICATION_FASTER) || defined(MULTIPLICATION_FASTEST)
			absMultiply_n_power_1p58496(a, b, ret);
#else
			absMultiply_square_n(a, b, ret);
#endif
#ifdef DEBUG
			assert(ret.bit[ret.nbits - 1] != 0);
#endif
		}

		/*------------ end of big integer multiplication -------------*/

		/*
		 * Long division of magnitudes: *q = |a| / |b| and r = |a| % |b|, both
		 * with POSITIVE sign. q may be NULL when only the remainder is wanted.
		 * q and r must not alias a, b or each other.
		 *
		 * Each quotient limb is estimated in long double and then corrected,
		 * so an inexact estimate can never corrupt the result.
		 * Requires b != 0 and |b| * BASE to be a finite long double
		 * (roughly 4900 decimal digits with x87 long double).
		 */
		static void absDivMod(const Bignum &a, const Bignum &b, Bignum *q, Bignum &r)
		{
			assert(!b.isZero());
			long double db = b.toLongDouble();
			if (db < 0) db = -db;
			assert(db * BASE <= std::numeric_limits<long double>::max());

			static Bignum prod;
			r.setZero();
			if (q) {
				q->nbits = a.nbits;
				q->sign = POSITIVE;
			}
			for (int i = a.nbits - 1; i >= 0; i --) {
				// r = r * BASE + a.bit[i]
				left_shift_in_BASE_system(r, 1);
				r.bit[0] = a.bit[i];

				// r < |b| * BASE here, so the quotient limb is < BASE
				Num_t digit = 0;
				while (absGreaterEqual(r, b)) {
					long double est = r.toLongDouble() / db;
					Num_t t;
					if (!(est >= 1))      // also catches NaN; true quotient >= 1
						t = 1;
					else if (est >= BASE)
						t = BASE - 1;
					else
						t = (Num_t)est;   // truncation == floor for est >= 1
					absMultiplySmall(b, t, prod);
					while (absGreater(prod, r)) { // estimate was too large
						t --;
						absMinus(prod, b, prod);
					}
					absMinus(r, prod, r);
					digit += t;
				}
				if (q)
					q->bit[i] = digit;
			}
			if (q)
				q->trim();
		}

	public:
		/* Constructs zero. */
		Bignum()
		{
			setZero();
		}

		Bignum(long long val)
		{
			// negate in unsigned arithmetic so that LLONG_MIN does not overflow
			unsigned long long mag = (unsigned long long)val;
			if (val < 0) {
				sign = NEGATIVE;
				mag = 0ULL - mag;
			} else
				sign = POSITIVE;
			nbits = 0;
			do {
				bit[nbits ++] = (Num_t)(mag % BASE);
				mag /= BASE;
			} while (mag);
		}

		double toDouble() const
		{
			double ret = 0;
			for (int i = nbits - 1; i >= 0; i --)
				ret = ret * BASE + bit[i];
			if (sign == NEGATIVE)
				ret = -ret;
			return ret;
		}

		long double toLongDouble() const
		{
			long double ret = 0;
			for (int i = nbits - 1; i >= 0; i --)
				ret = ret * BASE + bit[i];
			if (sign == NEGATIVE)
				ret = -ret;
			return ret;
		}

		/*
		 * Parse a decimal string. A non-digit first character is consumed as
		 * a sign ('-' means negative, anything else positive); non-digit
		 * characters elsewhere are ignored. The result is normalized (leading
		 * zeros dropped, "-0" becomes 0).
		 */
		Bignum& fromString(const char *str)
		{
			const char *begin = str;
			bool neg = false;
			if (*begin != '\0' && !isdigit((unsigned char)*begin)) {
				neg = (*begin == '-');
				begin ++;
			}
			nbits = 1;
			bit[0] = 0;
			// walk from the least significant digit, 9 digits per limb
			Num_t place = 1;
			for (const char *p = begin + strlen(begin); p != begin; ) {
				--p;
				if (!isdigit((unsigned char)*p))
					continue;
				if (place == BASE) {
					assert(nbits < N_BIT_MAX);
					bit[nbits ++] = 0;
					place = 1;
				}
				bit[nbits - 1] += (*p - '0') * place;
				place *= 10;
			}
			sign = neg ? NEGATIVE : POSITIVE;
			trim();
			if (isZero())
				sign = POSITIVE;
			return *this;
		}

		Bignum& operator = (const Bignum &n)
		{
			if (this != &n) {   // memcpy onto itself is undefined
				nbits = n.nbits;
				sign = n.sign;
				memcpy(bit, n.bit, sizeof(Num_t) * nbits);
			}
			return *this;
		}

		bool operator == (const Bignum &n) const
		{
			if (nbits != n.nbits || sign != n.sign)
				return false;
			for (int i = 0; i < nbits; i ++)
				if (bit[i] != n.bit[i])
					return false;
			return true;
		}

		bool operator != (const Bignum &n) const
		{
			return !(*this == n);
		}

		/* Signed comparison; relies on zero always carrying POSITIVE sign. */
		bool operator < (const Bignum &n) const
		{
			if (sign != n.sign)
				return sign == NEGATIVE;
			if (sign == POSITIVE)
				return absGreater(n, *this);
			return absGreater(*this, n);
		}

		bool operator > (const Bignum &n) const
		{
			return n < *this;
		}

		bool operator <= (const Bignum &n) const
		{
			return !(n < *this);
		}

		bool operator >= (const Bignum &n) const
		{
			return !(*this < n);
		}

		// TODO: bit shift is currently not provided.

		/* Returns a reference to a static result; see the file header. */
		Bignum& operator + (const Bignum &n) const
		{
			static Bignum ret;
			signedAdd(*this, n.sign, n, ret);
			return ret;
		}

		Bignum& operator += (const Bignum &n)
		{
			// TODO: don't do like below. that will take down the efficiency
			return *this = *this + n;
		}

		/* Returns a reference to a static result; see the file header. */
		Bignum& operator - (const Bignum &n) const
		{
			static Bignum ret;
			signedAdd(*this, !n.sign, n, ret);
			return ret;
		}

		Bignum& operator -= (const Bignum &n)
		{
			// TODO
			return *this = *this - n;
		}

		/* Returns a reference to a static result; see the file header. */
		Bignum& operator * (const Bignum &n) const
		{
			static Bignum ret;
			// capture before ret (which may be *this or n) is overwritten
			const bool rsign = (sign == n.sign) ? POSITIVE : NEGATIVE;
			absMultiply(*this, n, ret);
			ret.sign = rsign;
			if (ret.isZero())
				ret.sign = POSITIVE;
			return ret;
		}

		Bignum& operator *= (const Bignum &n)
		{
			// TODO
			return *this = *this * n;
		}

		/*
		 * Truncating division (rounds toward zero, like C++ integer '/').
		 * Requires n != 0 and |n| * BASE to fit in a long double
		 * (about 4900 decimal digits with x87 long double).
		 * Returns a reference to a static result; see the file header.
		 */
		Bignum& operator / (const Bignum &n) const
		{
			static Bignum ret, remainder, lhs, rhs;
			const bool qsign = (sign == n.sign) ? POSITIVE : NEGATIVE;
			const Bignum *pa = this, *pb = &n;
			if (pa == &ret) { lhs = *pa; pa = &lhs; }
			if (pb == &ret) { rhs = *pb; pb = &rhs; }
			absDivMod(*pa, *pb, &ret, remainder);
			ret.sign = qsign;
			if (ret.isZero())
				ret.sign = POSITIVE;
			return ret;
		}

		Bignum& operator /= (const Bignum &n)
		{
			// TODO
			return *this = *this / n;
		}

		/*
		 * Remainder of truncating division: takes the sign of the dividend,
		 * so that a == (a / n) * n + a % n (matches C++ '%').
		 * Same preconditions as operator/.
		 * Returns a reference to a static result; see the file header.
		 */
		Bignum& operator % (const Bignum &n) const
		{
			static Bignum remainder, lhs, rhs;
			const bool rsign = sign;
			const Bignum *pa = this, *pb = &n;
			if (pa == &remainder) { lhs = *pa; pa = &lhs; }
			if (pb == &remainder) { rhs = *pb; pb = &rhs; }
			absDivMod(*pa, *pb, NULL, remainder);
			remainder.sign = rsign;
			if (remainder.isZero())
				remainder.sign = POSITIVE;
			return remainder;
		}

		Bignum& operator %= (const Bignum &n)
		{
			// TODO
			return *this = *this % n;
		}

		/* Write the number in decimal to @fout, optionally followed by '\n'. */
		void print(bool newline = false, FILE *fout = stdout) const
		{
			int width = 0;  // decimal digits per limb (9 for BASE = 10^9)
			for (Num_t b = BASE; b > 1; b /= 10)
				width ++;
			if (sign == NEGATIVE)
				fputc('-', fout);
			fprintf(fout, "%d", bit[nbits - 1]);
			for (int i = nbits - 2; i >= 0; i --)
				fprintf(fout, "%0*d", width, bit[i]);
			if (newline)
				fputc('\n', fout);
		}

		/*
		 * Read one whitespace-delimited token from @fin and parse it with
		 * fromString(). Returns false on EOF / read failure.
		 */
		bool scan(FILE *fin = stdin)
		{
			int len = 0;
			for (Num_t base = BASE; base; base /= 10)
				len += N_BIT_MAX;
			std::vector<char> buf(len);
			// bound the read so an over-long token cannot overflow buf
			char fmt[32];
			snprintf(fmt, sizeof fmt, "%%%ds", len - 1);
			if (fscanf(fin, fmt, &buf[0]) != 1)
				return false;
			fromString(&buf[0]);
			return true;
		}
};

int main()
{
	Bignum a, b;
	while (a.scan() && b.scan())
		(a * b).print(true);
	return 0;
}