大数模板 (big-integer template)

/*
 * $File: test.cpp
 * $Date: Wed Feb 09 13:22:29 2011 +0800
 * $Author: Zhou Xinyu <zxytim@gmail.com>
 *
 * A simple arbitrary-precision (big) integer implementation.
 */
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cstring>
/*
 * Bignum: signed arbitrary-precision integer.
 *
 * The magnitude is stored little-endian in base BASE: bit[0] is the least
 * significant limb and bit[nbits - 1] the most significant.  Every public
 * operation leaves the value normalised: no leading zero limbs (except the
 * single limb of 0), and zero is always POSITIVE.
 *
 * Capacity is fixed at N_BIT_MAX limbs; exceeding it trips an assert.
 *
 * WARNING: the arithmetic operators (+ - * / %) return a reference to a
 * function-local static object, which the next call of the SAME operator
 * overwrites.  Feeding an operator its own previous result, e.g.
 * (a + b) + c or (a * b) * c, is handled correctly.  An expression that needs
 * two live results of one operator at once, e.g. (a * b) + (c * d), is not:
 * copy the first result into a Bignum before computing the second.  For the
 * same reason the class is neither reentrant nor thread-safe.
 */
class Bignum
{
#define MULTIPLICATION_FASTER
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
	private:

		typedef int Num_t;              // one limb, always in [0, BASE) once normalised
		typedef long long Num_bigger_t; // holds a limb product plus a carry
		static const int N_BIT_MAX = 25000;   // capacity in limbs
		static const int ITER_DEPTH_MAX = 14; // log(N_BIT_MAX) / log(2.0)
		static const int BASE = 1000000000;   // must be a power of 10 (print/scan rely on it)
		//	static const int BASE = 10;
		//	static const int ITER_DEPTH_MAX = 100; // log(N_BIT_MAX) / log(2.0)

		int nbits;              // number of limbs in use, 1 <= nbits <= N_BIT_MAX
		Num_t bit[N_BIT_MAX];   // limbs, least significant first

		// Sign values.  An enum needs no out-of-class definition, even when the
		// constants are used inside a ?: expression.
		enum { POSITIVE = false, NEGATIVE = true };

		// when the number is 0, @sign is %POSITIVE
		bool sign;

		bool isZero()  const
		{
			return nbits == 1 && bit[0] == 0;
		}

		void setZero()
		{
			nbits = 1;
			bit[0] = 0;
			sign = POSITIVE;
		}

		// Drop leading zero limbs (keeping at least one) and make zero POSITIVE.
		void normalize()
		{
			while (nbits > 1 && bit[nbits - 1] == 0)
				nbits --;
			if (isZero())
				sign = POSITIVE;
		}

		// Number of decimal digits stored in one limb, i.e. log10(BASE).
		static int digitsPerLimb()
		{
			int digits = 0;
			for (Num_t base = BASE / 10; base; base /= 10)
				digits ++;
			return digits;
		}

		// |a| > |b| (signs are ignored; both operands must be normalised).
		static bool absGreater(const Bignum &a, const Bignum &b)
		{
			if (a.nbits != b.nbits)
				return a.nbits > b.nbits;
			for (int i = a.nbits - 1; i >= 0; i --)
				if (a.bit[i] != b.bit[i])
					return a.bit[i] > b.bit[i];
			return false;
		}

		// |a| >= |b| (signs are ignored; both operands must be normalised).
		static bool absGreaterEqual(const Bignum &a, const Bignum &b)
		{
			if (a.nbits != b.nbits)
				return a.nbits > b.nbits;
			for (int i = a.nbits - 1; i >= 0; i --)
				if (a.bit[i] != b.bit[i])
					return a.bit[i] > b.bit[i];
			return true;
		}

		/*
		 * ret = |a| + |b|.
		 * The sign of ret is left for the caller to set.  If one operand is 0,
		 * ret becomes a copy of the other one, including its sign.
		 * ret may alias a and/or b: limb i of the inputs is read before limb i
		 * of the output is written, and the lengths are captured up front.
		 */
		static void absPlus(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			if (a.isZero())
			{
				ret = b;
				return;
			}
			if (b.isZero())
			{
				ret = a;
				return;
			}

			const int la = a.nbits, lb = b.nbits;
			const int len = MAX(la, lb);
			Num_t carry = 0;
			for (int i = 0; i < len; i ++)
			{
				// Each term is < BASE, so the sum is < 2 * BASE and fits in a Num_t.
				Num_t sum = carry + (i < la ? a.bit[i] : 0) + (i < lb ? b.bit[i] : 0);
				carry = (sum >= BASE);
				ret.bit[i] = carry ? sum - BASE : sum;
			}
			ret.nbits = len;
			if (carry)
			{
				assert(len < N_BIT_MAX); // result would exceed the fixed capacity
				ret.bit[ret.nbits ++] = 1;
			}
		}

		// Kept for compatibility: absPlus itself now tolerates c aliasing a or b.
		static void absPlusSafe(const Bignum &a, const Bignum &b, Bignum &c)
		{
			absPlus(a, b, c);
		}

		/*
		 * ret = |a| - |b|.  Requires |a| >= |b|.
		 * The sign of ret is left for the caller to set (if b is 0, ret becomes a
		 * copy of a, including its sign).  ret may alias a and/or b.
		 */
		static void absMinus(const Bignum &a, const Bignum &b, Bignum &ret)
		{
#ifdef DEBUG
			assert(a.nbits>= b.nbits);
			assert(absGreaterEqual(a, b));
#endif
			if (b.isZero())
			{
				ret = a;
				return;
			}
			const int la = a.nbits, lb = b.nbits;
			Num_t borrow = 0;
			for (int i = 0; i < la; i ++)
			{
				Num_t diff = a.bit[i] - borrow - (i < lb ? b.bit[i] : 0);
				borrow = (diff < 0);
				ret.bit[i] = borrow ? diff + BASE : diff;
			}
			ret.nbits = la;
			while (ret.nbits > 1 && ret.bit[ret.nbits - 1] == 0)
				ret.nbits --;
		}

		/*---------------- algorithms of big integer multiplication --------------- */
		/*
		 * Schoolbook multiplication, O(n^2): ret = |a| * |b|.
		 * ret must not alias a or b (it is cleared before the operands are read).
		 * The sign of ret is not touched.
		 */
		static void absMultiply_square_n(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			if (a.isZero() || b.isZero())
			{
				ret.setZero();
				return;
			}
			assert(a.nbits + b.nbits <= N_BIT_MAX); // product may not fit
			int &len = ret.nbits = a.nbits + b.nbits;
			for (int i = 0; i < len; i ++)
				ret.bit[i] = 0;
			for (int i = 0; i < a.nbits; i ++)
				for (int j = 0, p = i + j; j < b.nbits; j ++, p ++)
				{
					Num_bigger_t now = a.bit[i];
					now *= b.bit[j];
					now += ret.bit[p];
					if (now >= BASE)
					{
						// The carry limb may temporarily reach 2 * BASE; it is
						// normalised when it becomes position p later on.
						Num_t v = now / BASE;
						ret.bit[p + 1] += v;
						ret.bit[p] = (Num_t)(now - (Num_bigger_t)v * BASE); // now % BASE
					} else ret.bit[p] = (Num_t)now;
				}
			if (ret.bit[len - 1] == 0)
				len --;
		}

		/*
		 * Karatsuba multiplication.
		 *
		 * Suppose the number of limbs of the longer of a and b is n; the shorter
		 * one is padded with leading zeros.  Let h = ceil(n / 2) and X = BASE^h:
		 *
		 *     a = A*X + B
		 *     b = C*X + D
		 *
		 * then
		 *     a * b = AC*X^2 + (BC + AD)*X + BD
		 *           = AC*X^2 + ((A - B)*(D - C) + AC + BD)*X + BD
		 *
		 * so only three products AC, BD and (A - B)(D - C) are needed:
		 *		T(n) = 3T(n/2) + O(n)  =>  T(n) = O(n^log2(3)) = O(n^1.58496)
		 *
		 * ret = |a| * |b|; ret must not alias a or b.  depth selects this call's
		 * slot in the static work arrays.
		 */
		static void absMultiply_n_power_1p58496_iter(const Bignum &a, const Bignum &b, Bignum &ret, int depth)
		{
			static Bignum A[ITER_DEPTH_MAX], B[ITER_DEPTH_MAX], C[ITER_DEPTH_MAX], D[ITER_DEPTH_MAX];

			// two below are product
			static Bignum AC[ITER_DEPTH_MAX], BD[ITER_DEPTH_MAX];

			// two below are difference
			static Bignum AB[ITER_DEPTH_MAX], DC[ITER_DEPTH_MAX];

			// below is product of differences
			static Bignum ABDC[ITER_DEPTH_MAX];

			int nbits = MAX(a.nbits, b.nbits),
				half_nbits = nbits - (nbits >> 1);

			ret.setZero();

			if (a.isZero() || b.isZero())
				return;

			// Small or very unbalanced operands: schoolbook is faster.
			static const int NBITS_TO_FORCE = 20;
			int min_nbits = MIN(a.nbits, b.nbits);
			if (min_nbits <= NBITS_TO_FORCE || min_nbits <= sqrt((double)nbits))
			{
				absMultiply_square_n(a, b, ret);
				return;
			}

			assert(depth < ITER_DEPTH_MAX); // recursion deeper than the work arrays

			const Bignum *pa = &a, *pb = &b;

			if (absGreater(a, b))
				swap(pa, pb);

			// *pb is the greater one

			partition(*pa, A[depth], B[depth], nbits);
			partition(*pb, C[depth], D[depth], nbits);


			absMultiply_n_power_1p58496_iter(A[depth], C[depth], AC[depth], depth + 1);
#ifdef DEBUG
			static Bignum tmp;
			absMultiply_square_n(A[depth], C[depth], tmp);
			assert(tmp == AC[depth]);
#endif

			absMultiply_n_power_1p58496_iter(B[depth], D[depth], BD[depth], depth + 1);
#ifdef DEBUG
			absMultiply_square_n(B[depth], D[depth], tmp);
			assert(tmp == BD[depth]);
#endif

			absPlus(AC[depth], BD[depth], ret);

			AB[depth] = A[depth] - B[depth];
			DC[depth] = D[depth] - C[depth];

			if (!(AB[depth].isZero() || DC[depth].isZero()))
			{
				absMultiply_n_power_1p58496_iter(AB[depth], DC[depth], ABDC[depth], depth + 1);
#ifdef DEBUG
				absMultiply_square_n(AB[depth], DC[depth], tmp);
				assert(tmp == ABDC[depth]);
#endif
				// (A - B)(D - C) is negative iff the two differences differ in sign.
				if (AB[depth].sign != DC[depth].sign)
					absMinus(ret, ABDC[depth], ret);
				else absPlusSafe(ret, ABDC[depth], ret);
			}

			left_shift_in_BASE_system(ret, half_nbits);

			left_shift_in_BASE_system(AC[depth], half_nbits << 1);

			absPlusSafe(ret, AC[depth], ret);
			absPlusSafe(ret, BD[depth], ret);

		}

		static void swap(const Bignum *&a, const Bignum *&b)
		{
			const Bignum *t = a;
			a = b;
			b = t;
		}

		/*
		 * Split |n| at limb half_nbits = ceil(nbits / 2):
		 *     |n| = a * BASE^half_nbits + b,   0 <= b < BASE^half_nbits.
		 * nbits is the length of the longer Karatsuba operand, so both operands
		 * are split at the same point.  a and b are POSITIVE and normalised.
		 * n itself may be negative; the parts are magnitudes that the caller
		 * subtracts, so they must neither inherit n's sign nor keep a stale one.
		 */
		static void partition(const Bignum &n, Bignum &a, Bignum &b, int nbits)
		{
			int half_nbits = (nbits + 1) >> 1;
			assert(half_nbits == nbits - (nbits >> 1));
			if (n.nbits <= half_nbits)
			{
				a.setZero();
				b = n;
				b.sign = POSITIVE;
				return;
			}

			const int high_nbits = n.nbits - half_nbits;
			for (int i = 0; i < high_nbits; i ++)
				a.bit[i] = n.bit[half_nbits + i];
			a.nbits = high_nbits;
			a.sign = POSITIVE;
			while (a.nbits > 1 && a.bit[a.nbits - 1] == 0)
				a.nbits --;

			for (int i = 0; i < half_nbits; i ++)
				b.bit[i] = n.bit[i];
			b.nbits = half_nbits;
			b.sign = POSITIVE;
			while (b.nbits > 1 && b.bit[b.nbits - 1] == 0)
				b.nbits --;
		}

		// ret *= BASE^nbits (prepend nbits zero limbs); no-op for 0 or nbits <= 0.
		static void left_shift_in_BASE_system(Bignum &ret, int nbits)
		{
			if (nbits <= 0)
				return;
			if (ret.isZero())
				return;
			assert(ret.nbits + nbits <= N_BIT_MAX);
			for (int i = ret.nbits - 1, j = ret.nbits + nbits - 1; i >= 0; i --, j --)
				ret.bit[j] = ret.bit[i];
			for (int i = nbits - 1; i >= 0; i --)
				ret.bit[i] = 0;
			ret.nbits += nbits;
		}

		static void absMultiply_n_power_1p58496(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			absMultiply_n_power_1p58496_iter(a, b, ret, 0);
		}

		/*
		 * Placeholder for FFT-based O(n log n) multiplication.  Not implemented
		 * yet, so it delegates to Karatsuba; MULTIPLICATION_FASTEST then still
		 * produces correct results.
		 */
		static void absMultiply_nlogn(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			absMultiply_n_power_1p58496(a, b, ret);
		}

		/*
		 * ret = |a| * |b| using the algorithm selected at compile time.
		 * ret may alias a or b: the algorithms clear ret before reading their
		 * operands, so the aliased case is computed in a private buffer first.
		 * The caller sets ret.sign.
		 */
		static void absMultiply(const Bignum &a, const Bignum &b, Bignum &ret)
		{
			if (a.isZero() || b.isZero())
			{
				ret.setZero();
				return;
			}

			if (&ret == &a || &ret == &b)
			{
				static Bignum scratch;
				absMultiply(a, b, scratch);
				ret = scratch;
				return;
			}

#if defined(MULTIPLICATION_SLOW)
			absMultiply_square_n(a, b, ret);
#elif defined(MULTIPLICATION_FASTER)
			absMultiply_n_power_1p58496(a, b, ret);
#elif defined(MULTIPLICATION_FASTEST)
			absMultiply_nlogn(a, b, ret);
#else
			absMultiply_square_n(a, b, ret);
#endif

#ifdef DEBUG
			assert(ret.bit[ret.nbits - 1] != 0);
			assert(!ret.isZero());
#endif

		}

		/*------------ end of big integer multiplication -------------*/

		/*
		 * Long division of magnitudes: *quot = |a| / |b| and *rem = |a| % |b|,
		 * both POSITIVE.  Either pointer may be NULL.  a and b may be any Bignum,
		 * including the objects quot/rem point to, because the work is done in
		 * private buffers that are copied out at the end.
		 *
		 * !!!IMPORTANT!!!
		 * Each quotient limb is estimated with long double arithmetic, so
		 * |b| * BASE must be a finite long double (about 1e4932 with an 80-bit
		 * x87 long double, much less where long double is a plain double).
		 * This is asserted.  Dividing by zero also trips an assert.
		 */
		static void absDivMod(const Bignum &a, const Bignum &b, Bignum *quot, Bignum *rem)
		{
			static Bignum q, r, prod;
			assert(!b.isZero()); // division by zero

			q.setZero();
			r.setZero();

			if (absGreaterEqual(a, b))
			{
				long double db = b.toLongDouble();
				if (db < 0)
					db = -db;
				assert(db * BASE <= LDBL_MAX);

				for (int i = a.nbits - 1; i >= 0; i --)
				{
					// r = r * BASE + a.bit[i]
					if (r.isZero())
						r.bit[0] = a.bit[i];
					else
					{
						for (int j = r.nbits - 1; j >= 0; j --)
							r.bit[j + 1] = r.bit[j];
						r.bit[0] = a.bit[i];
						r.nbits ++;
					}

					// Now r < |b| * BASE, so this quotient limb is < BASE.
					Num_t digit = 0;
					while (absGreaterEqual(r, b))
					{
						Num_t t = (Num_t)floor(r.toLongDouble() / db);
						// Rounding can push the estimate one step either way.
						if (t < 1)
							t = 1; // r >= |b|, so at least one more |b| fits
						if (t >= BASE)
							t = BASE - 1;
						absMultiply(b, Bignum(t), prod);
						while (absGreater(prod, r))
						{
							t --;
							absMultiply(b, Bignum(t), prod);
						}
						absMinus(r, prod, r);
						digit += t;
					}
					q.bit[i] = digit;
				}
				q.nbits = a.nbits;
				q.normalize();
			}
			else
			{
				r = a;
				r.sign = POSITIVE;
			}

			if (quot)
				*quot = q;
			if (rem)
				*rem = r;
		}

	public:
		// Default value is 0.
		Bignum()
		{
			setZero();
		}

		Bignum(long long val)
		{
			// Work on the magnitude as unsigned so that negating LLONG_MIN cannot overflow.
			unsigned long long mag = (unsigned long long)val;
			sign = POSITIVE;
			if (val < 0)
			{
				sign = NEGATIVE;
				mag = 0ULL - mag;
			}
			nbits = 0;
			do
			{
				bit[nbits ++] = (Num_t)(mag % BASE);
				mag /= BASE;
			} while (mag);
		}

		double toDouble() const
		{
			double ret = 0;
			for (int i = nbits - 1; i >= 0; i --)
				ret = ret * BASE + bit[i];
			if (sign == NEGATIVE)
				ret = -ret;
			return ret;
		}

		long double toLongDouble() const
		{
			long double ret = 0;
			for (int i = nbits - 1; i >= 0; i --)
				ret = ret * BASE + bit[i];
			if (sign == NEGATIVE)
				ret = -ret;
			return ret;
		}

		/*
		 * Parse a decimal string.  A leading non-digit character is taken as the
		 * sign ('-' means negative, anything else positive).  Any other non-digit
		 * characters are skipped.  The result is normalised ("-0" and "000" both
		 * become 0).
		 */
		Bignum& fromString(const char *str)
		{
			const char *begin = str;
			sign = POSITIVE;
			if (*begin != '\0' && !isdigit((unsigned char)*begin))
			{
				if (*begin == '-')
					sign = NEGATIVE;
				begin ++;
			}

			nbits = 1;
			bit[0] = 0;
			Num_t base = 1; // place value of the next digit inside the current limb
			for (const char *p = begin + strlen(begin); p != begin; )
			{
				-- p;
				if (!isdigit((unsigned char)*p))
					continue;
				if (base == BASE)
				{
					assert(nbits < N_BIT_MAX); // too many digits for the fixed capacity
					bit[nbits ++] = 0;
					base = 1;
				}
				bit[nbits - 1] += (*p - '0') * base;
				base *= 10;
			}
			normalize();
			return *this;
		}

		Bignum& operator = (const Bignum &n)
		{
			// memcpy on overlapping (identical) buffers is undefined.
			if (this != &n)
			{
				nbits = n.nbits;
				sign = n.sign;
				memcpy(bit, n.bit, sizeof(Num_t) * nbits);
			}
			return *this;
		}

		bool operator == (const Bignum &n) const
		{
			if (nbits != n.nbits)
				return false;
			if (sign != n.sign)
				return false;
			for (int i = 0; i < nbits; i ++)
				if (bit[i] != n.bit[i])
					return false;
			return true;
		}

		bool operator != (const Bignum &n) const
		{
			return !(*this == n);
		}

		// Signs are compared first; the limb count only orders magnitudes.
		bool operator < (const Bignum &n) const
		{
			if (sign != n.sign)
				return sign == NEGATIVE;
			if (sign == POSITIVE)
				return absGreater(n, *this);
			return absGreater(*this, n);
		}

		bool operator > (const Bignum &n) const
		{
			return n < *this;
		}

		bool operator <= (const Bignum &n) const
		{
			if (sign != n.sign)
				return sign == NEGATIVE;
			if (sign == POSITIVE)
				return absGreaterEqual(n, *this);
			return absGreaterEqual(*this, n);
		}

		bool operator >= (const Bignum &n) const
		{
			return n <= *this;
		}

		// TODO: bit shift is currently not provided.

		Bignum& operator + (const Bignum &n) const
		{
			static Bignum ret;
			// Either operand may be ret itself, so read the signs before writing ret.
			const bool lhs_sign = sign, rhs_sign = n.sign;
			if (lhs_sign == rhs_sign)
			{
				absPlus(*this, n, ret);
				ret.sign = lhs_sign;
			}
			else if (absGreaterEqual(*this, n))
			{
				absMinus(*this, n, ret);
				ret.sign = lhs_sign;  // the larger magnitude decides the sign
			}
			else
			{
				absMinus(n, *this, ret);
				ret.sign = rhs_sign;
			}
			ret.normalize();
			return ret;
		}

		Bignum& operator += (const Bignum &n)
		{
			// TODO: don't do like below. that will take down the efficiency
			return *this = *this + n;
		}

		Bignum& operator - (const Bignum &n) const
		{
			static Bignum ret;
			// Either operand may be ret itself, so read the signs before writing ret.
			const bool lhs_sign = sign, rhs_sign = n.sign;
			if (lhs_sign != rhs_sign)
			{
				absPlus(*this, n, ret);
				ret.sign = lhs_sign;
			}
			else if (absGreaterEqual(*this, n))
			{
				absMinus(*this, n, ret);
				ret.sign = lhs_sign;
			}
			else
			{
				absMinus(n, *this, ret);
				ret.sign = !lhs_sign;  // |n| dominates, so the sign flips
			}
			ret.normalize();
			return ret;
		}

		Bignum& operator -= (const Bignum &n)
		{
			// TODO
			return *this = *this - n;
		}

		Bignum& operator * (const Bignum &n) const
		{
			static Bignum ret;
			const bool negative = (sign != n.sign);
			absMultiply(*this, n, ret); // handles zero and aliasing
			ret.sign = negative;
			ret.normalize();
			return ret;
		}

		Bignum& operator *= (const Bignum &n)
		{
			// TODO
			return *this = *this * n;
		}

		// Quotient truncated toward zero (like C++ integer division).
		// See absDivMod for the size limit on the divisor.
		Bignum& operator / (const Bignum &n) const
		{
			static Bignum ret;
			const bool negative = (sign != n.sign);
			absDivMod(*this, n, &ret, NULL);
			ret.sign = negative;
			ret.normalize();
			return ret;
		}

		Bignum& operator /= (const Bignum &n)
		{
			// TODO
			return *this = *this / n;
		}

		// Remainder with the sign of the dividend, so that
		// (a / b) * b + a % b == a, as with C++ integer %.
		Bignum& operator % (const Bignum &n) const
		{
			static Bignum remainder;
			const bool negative = (sign == NEGATIVE);
			absDivMod(*this, n, NULL, &remainder);
			remainder.sign = negative;
			remainder.normalize();
			return remainder;
		}

		Bignum& operator %= (const Bignum &n)
		{
			// TODO
			return *this = *this % n;
		}

		// Write the value in decimal to fout, optionally followed by a newline.
		void print(bool newline = false, FILE *fout = stdout) const
		{
			if (sign == NEGATIVE)
				fprintf(fout, "-");
			fprintf(fout, "%d", bit[nbits - 1]);
			// Lower limbs are zero-padded to the full limb width.
			const int width = digitsPerLimb();
			for (int i = nbits - 2; i >= 0; i --)
				fprintf(fout, "%0*d", width, bit[i]);
			if (newline)
				fprintf(fout, "\n");
		}

		/*
		 * Read one whitespace-delimited token from fin and parse it with
		 * fromString.  Returns false when no token could be read (e.g. EOF).
		 * The read is bounded by the largest number that fits in N_BIT_MAX limbs
		 * (plus a sign); a longer token is split at that length.
		 */
		bool scan(FILE *fin = stdin)
		{
			const int max_token = N_BIT_MAX * digitsPerLimb() + 1;

			char format[32];
			snprintf(format, sizeof format, "%%%ds", max_token);

			char *str = new char[max_token + 1];
			const bool ok = (fscanf(fin, format, str) == 1);
			if (ok)
				this->fromString(str);
			delete [] str;
			return ok;
		}
#undef MAX
#undef MIN
};

int main()
{
	Bignum a, b;
	while (a.scan() && b.scan())
		(a * b).print(true);
	return 0;
}

posted on 2011-03-13 12:13  CrazyAC  阅读(229)  评论(0)  编辑  收藏  举报