[BZOJ2242][SDOI2011]计算器

一道非常经典的同余方程入门题

题意

  1. 计算pow(a, b)
  2. 线性同余方程
  3. 高次同余方程

分析

pow(a, b)

我们可以使用快速幂完成这项操作

qword qpow(qword a, qword b, qword p) {
	a %= p; qword res = 1 % p; for (; b; b >>= 1, a = a * a % p) if (b & 1) res = res * a % p; return res;
}

线性同余方程

求解线性同余方程我们一般使用拓展欧几里得算法。
这里稍微证明一下这种算法的正确性并描述一下这种算法

Bezout(贝祖)定理:对于任意整数 \(a, b\),存在一对整数 \(x, y\),满足 \(ax + by = gcd(a, b)\)
证明
\(b = 0\)时, 显然有一对整数 \(x = 1, y = 0\) 使得 \(a * 1 + 0 * 0 = gcd(a, 0)\)
\(b > 0\),则 \(gcd(a, b) = gcd(b, a \bmod b)\)。假设存在一对整数 \(x, y\) 满足 \(b * x + (a \bmod b) * y = gcd(b, a \bmod b)\).
因为 $$bx + (a \bmod b)y = bx + (a - b * [a / b])y = ay - b(x - [a / b]y)$$
所以令 \(\mathop{{x}'} = y, \mathop{{y}'} = x - [a / b]y\) 就得到了 \(a\mathop{{x}'} + b\mathop{{y}'} = gcd(a, b)\)
应用数学归纳法,可知定理成立。

Bezout定理是按照欧几里得算法的流程进行证明的,所以这种能同时计算 \(x, y\) 的算法叫做扩展欧几里得算法
例如在本题中,由于 \(y, z\) 是给定的参数,我们不妨将其设为 \(a, b\)
原式变成了 \(ax = b + py\) 我们令 \(p = -p\) 就有 \(ax + py = b\)。(这里的 \(y\) 是同余方程中设出的另一个解)
易知,当 \(gcd(a, p) | b\) 时有解。我们最终只需要解出 \(ax + py = gcd(a, p)\) 的值,并且扩大 \(\frac{b}{gcd(a,p)}\)即可。

qword exgcd(qword a, qword b, qword &x, qword &y) {
	if (b == 0) { x = 1, y = 0; return a; }
	qword d = exgcd(b, a % b, x, y);
	qword z = x; x = y; y = z - y * (a / b);
	return d;
}

高次同余方程

求解高次同余方程我们一般使用Baby Step Gaint Step算法
求解形式:\(a^x = b \pmod p\), 要求 \(a, p\)互质
算法复杂度:$O(\sqrt{p})
因为 \(a, p\) 互质,所以可以在模 \(p\) 意义下执行关于 \(a\) 的乘、除运算。
\(x = i * t - j\) 其中 \(t = [\sqrt{p}], 0 \le j \le t - 1\),则方程变为 \(a ^{i * t - j} = b \pmod p\),即 \((a^t)^i = b * a ^ j \pmod p\)
对于所有的 \(j \in [0, t - 1]\) ,把 \(b * a^j \pmod p\) 插入一个hash表中
枚举 \(i\) 的所有可能取值,计算出 \((a^t)^i\) 在hash表中是否存在对应的 \(j\),更新答案即可。

qword baby_step_gaint_step(qword a, qword b, qword p) {
	mp.clear();
	b %= p;
	qword t = (qword)sqrt(p) + 1;
	for (int j = 0; j < t; ++ j) {
		qword val = (qword)b * qpow(a, j, p) % p;
		mp[val] = j;
	}
	a = qpow(a, t, p);
	if (a == 0) return b == 0 ? 1 : -1;
	for (int i = 0; i <= t; ++ i) {
		qword val = qpow(a, i, p); 
		qword j = mp.find(val) == mp.end() ? -1 : mp[val];
		if (j >= 0 && i * t - j >= 0) return i * t - j;
	}
	return -1;
}

完整代码

#include <bits/stdc++.h>
using namespace std;

typedef long long qword;

qword qpow(qword a, qword b, qword p) {
	a %= p; qword res = 1 % p; for (; b; b >>= 1, a = a * a % p) if (b & 1) res = res * a % p; return res;
}

qword exgcd(qword a, qword b, qword &x, qword &y) {
	if (b == 0) { x = 1, y = 0; return a; }
	qword d = exgcd(b, a % b, x, y);
	qword z = x; x = y; y = z - y * (a / b);
	return d;
}

map<qword, qword> mp;

qword baby_step_gaint_step(qword a, qword b, qword p) {
	mp.clear();
	b %= p;
	qword t = (qword)sqrt(p) + 1;
	for (int j = 0; j < t; ++ j) {
		qword val = (qword)b * qpow(a, j, p) % p;
		mp[val] = j;
	}
	a = qpow(a, t, p);
	if (a == 0) return b == 0 ? 1 : -1;
	for (int i = 0; i <= t; ++ i) {
		qword val = qpow(a, i, p); 
		qword j = mp.find(val) == mp.end() ? -1 : mp[val];
		if (j >= 0 && i * t - j >= 0) return i * t - j;
	}
	return -1;
}

qword t, op, x, y;

int main() {
	cin >> t >> op;
	for (int i = 1, a, b, p; i <= t; ++ i) {
		cin >> a >> b >> p;
		if (op == 1) cout << qpow(a, b, p) << endl;
		else if (op == 2) {
			qword d = exgcd(a, p, x, y);
			if (b % d) cout << "Orz, I cannot find x!" << endl;
			else {
				x *= b / d;
				cout << (x % p + p) % p << endl;
			}
		} else if (op == 3){
			x = baby_step_gaint_step(a, b, p);
			if (x == -1) cout << "Orz, I cannot find x!" << endl;
			else cout << (x % p + p) % p << endl;
		}
	}
	return 0;
}
posted @ 2018-10-05 11:35  AlessandroChen  阅读(409)  评论(0编辑  收藏  举报