算法总结之求解模线性方程组
算法总结之求解模线性方程组
1)求解模线性方程 ax = b(mod n)
方程ax = b(mod n) -> ax = b + ny ->ax - ny = b
-> ax + n (-y) =b 其中a,n,b已知。 可用扩展欧几里得来求解该方程的一组特解。
这里给出下列几个定理用来求解方程:
1.当且仅当d|b时,方程ax = b(mod n)有解。d=gcd(a,n)
2.ax = b(mod n) 或者有d个不同解,或者无解。
3.令d=gcd(a,n) 假定对整数x', y', 有d = ax' + ny', 如果d | b, 则方程ax = b(mod n)有一个解的值为x0, 满足:
x0=x‘(b/d)(mod n)
4.假设方程ax = b(mod n)有解, x0是方程的任意一个解, 则方程对模n恰有d个不同的解, 分别为:
xi = x0 + i * (n / d), 其中 i = 1,2,3......d - 1
根据这4个定理,运用扩展欧几里得算法就能轻易的求出模线性方程的所有解了。
伪代码如下:
1 MODULAR_LINEAR_EQUATION_SOLVER(a,b,n) 2 (d,x',y')=EXTENDED_EUCLID(a,n) 3 if (d|b) 4 x0=x'(b/d) mod n 5 for i=0 to d-1 6 print (x0+i(n/d)) mod n 7 else 8 print "no solutions"
2)求解模线性方程组
x = a1(mod m1)
x = a2(mod m2)
x = a3(mod m3)
先求解方程组前两项。 x=m1*k1+a1=m2*k2+a2
-> m1*k1+m2*(-k2)=a2-a1
这个方程可以通过欧几里得求解出最小正整数的k1 则x=m1*k1+a1 显然x为两个方程的最小正整数解。
则这两个方程的通解为 X=x+k*LCM(m1,m2) -> X=x(mod LCM(m1,m2)) 就转换成了一个形式相同方程了
在通过这个方程和后面的其他方程求解。最终的结果就出来了。
以POJ2891为例 贴上代码:
Code:
1 /************************************************************************* 2 > File Name: poj2891.cpp 3 > Author: Enumz 4 > Mail: 369372123@qq.com 5 > Created Time: 2014年10月28日 星期二 02时50分07秒 6 ************************************************************************/ 7 8 #include<iostream> 9 #include<cstdio> 10 #include<cstdlib> 11 #include<string> 12 #include<cstring> 13 #include<list> 14 #include<queue> 15 #include<stack> 16 #include<map> 17 #include<set> 18 #include<algorithm> 19 #include<cmath> 20 #include<bitset> 21 #include<climits> 22 #define MAXN 100000 23 #define LL long long 24 using namespace std; 25 LL extended_gcd(LL a,LL b,LL &x,LL &y) //返回值为gcd(a,b) 26 { 27 LL ret,tmp; 28 if (b==0) 29 { 30 x=1,y=0; 31 return a; 32 } 33 ret=extended_gcd(b,a%b,x,y); 34 tmp=x; 35 x=y; 36 y=tmp-a/b*y; 37 return ret; 38 } 39 int main() 40 { 41 LL N; 42 while (cin>>N) 43 { 44 long long a1,m1; 45 long long a2,m2; 46 cin>>a1>>m1; 47 if (N==1) 48 printf("%lld\n",m1); 49 else 50 { 51 bool flag=0; 52 for (int i=2;i<=N;i++) 53 { 54 cin>>a2>>m2; 55 if (flag==1) continue; 56 long long x,y; 57 LL ret=extended_gcd(a1,a2,x,y); 58 if ((m2-m1)%ret!=0) 59 flag=1; 60 else 61 { 62 long long ans1=(m2-m1)/ret*x; 63 ans1=ans1%(a2/ret); 64 if (ans1<0) ans1+=(a2/ret); 65 m1=ans1*a1+m1; 66 a1=a1*a2/ret; 67 } 68 } 69 if (!flag) 70 cout<<m1<<endl; 71 else 72 cout<<-1<<endl; 73 } 74 } 75 return 0; 76 }