Bzoj5296: [Cqoi2018]破解D-H协议
破解D-H协议
Description
Diffie-Hellman密钥交换协议是一种简单有效的密钥交换方法。它可以让通讯双方在没有事先约定密钥(密码)的情况下
通过不安全的信道(可能被窃听)建立一个安全的密钥K,用于加密之后的通讯内容。
假定通讯双方名为Alice和Bob,协议的工作过程描述如下(其中mod表示取模运算):
1.协议规定一个固定的质数P,以及模P的一个原根g。P和g的数值都是公开的,无需保密。
2.Alice生成一个随机数a,并计算A=g^a mod P,将A通过不安全信道发送给Bob。
3.Bob生成一个随机数b,并计算B=g^b mod P,将B通过不安全信道发送给Alice。
4.Bob根据收到的A计算出K=A^b mod P,而Alice根据收到的B计算出K=B^a mod P。
5.双方得到了相同的K,即g^(a*b) mod P。K可以用于之后通讯的加密密钥。
可见,这个过程中可能被窃听的只有A、B,而a、b、K是保密的。并且根据A、B、P、g这4个数,不能轻易计算出
K,因此K可以作为一个安全的密钥。
当然安全是相对的,该协议的安全性取决于数值的大小,通常a、b、P都选取数百位以上的大整数以避免被破解。然而如
果Alice和Bob编程时偷懒,为了避免实现大数运算,选择的数值都小于2^31,那么破解他们的密钥就比较容易了。
Input
输入文件第一行包含两个空格分开的正整数g和P。
第二行为一个正整数n,表示Alice和Bob共进行了n次连接(即运行了n次协议)。
接下来n行,每行包含两个空格分开的正整数A和B,表示某次连接中,被窃听的A、B数值。
2≤A,B<P<231,2≤g<20, n<=20
Output
输出包含n行,每行1个正整数K,为每次连接你破解得到的密钥。
Sample Input
3 31
3
27 16
21 3
9 26
3
27 16
21 3
9 26
Sample Output
4
21
25
21
25
题解:g^x = a (mod p), 求出最小的x ,b^x (mod p)就是答案
对于离散对数问题 可以用大步小步法
对于p 是质数,令 x = A*ceil(sqrt(p)) + B, g^(A*ceil(sqrt(p)) + B) = a (mod p)
因为gcd(g,p) = 1,所以 g^(A*ceil(sqrt(p))) = a*g^(-B) (% p) ( 0 <= A <= ceil(sqrt(p)), 0 <= B <= ceil(sqrt(p)) )
即建立hash 存储 右边,枚举左边,判断是否在hash中存在
但这样需要求逆元,可以优化 令 x = A*ceil(sqrt(p)) - B ( 1 <= A <= ceil(sqrt(p)) + 1 , 0 <= B < ceil(sqrt(p)) )
g^(A*ceil(sqrt(p))) = a*g^B (% p) 这样的话,就可以解决逆元
但对于 p 不是质数的情况,则需要扩展的大步小步法,这里就不介绍了 ,发一个链接,讲解很好的博客
bsgs code:
#include <bits/stdc++.h> using namespace std; typedef long long ll; ll pow_mod(ll a,ll n,ll mod) { ll ans = 1; while(n) { if(n&1) ans = (ans*a)%mod; a = (a*a)%mod; n >>= 1; } return ans; } ll bsgs(ll a,ll b,ll p) { map<ll,int>Hash; int m = ceil(sqrt(p)); ll v = pow_mod(a, m, p),k = v; for(int i = 0;i < m;i++) { if(!Hash.count(b)) Hash[b] = i; b = (b*a)%p; } for(int i = 1; i <= m+1;i++) { if(Hash.count(k)) return i*m - Hash[k]; k = (k*v)%p; } return -1; } int main(){ ll g,n,a,b,p; scanf("%lld%lld%lld",&g,&p,&n); while(n--) { scanf("%lld%lld",&a,&b); ll ans = bsgs(g,a,p); if(ans == -1) continue; printf("%lld\n",pow_mod(b,ans,p)); } }
exbsgs code:
#include <bits/stdc++.h> using namespace std; typedef long long ll; ll pow_mod(ll a,ll n,ll mod = LONG_LONG_MAX) { ll ans = 1; while(n) { if(n&1) ans = (ans*a)%mod; a = (a*a)%mod; n >>= 1; } return ans; } int gcd(int a,int b){ return b?gcd(b,a%b):a;} int bsgs(int a, int b, int p) { int cnt = 0; ll t = 1; a %= p, b %= p; map<ll, int>H; if(b == 1) return 0; for(int g = gcd(a, p); g != 1; g = gcd(a, p)) { if(b % g) return -1; p /= g; b /= g; t = t * a / g % p; ++cnt; if(b == t) return cnt; } int m = int(sqrt(p) + 1); ll base = b; for(int i = 0; i < m; ++i) { H[base] = i; base = base * a % p; } base = pow_mod(a, m, p); ll now = t; for(int i = 1; i <= m + 1; ++i) { now = now * base % p; if(H.count(now)) return i * m - H[now] + cnt; } return -1; } int main() { int g,p; int t,a,b; scanf("%d%d%d",&g,&p,&t); while(t--) { scanf("%d%d",&a,&b); int ans = bsgs(g,a,p); if(ans == -1) continue; //cout<<ans<<endl; printf("%lld\n",pow_mod(1ll*b,1ll*ans,1ll*p)); } return 0; }