快速幂算法(SOJ 2984)
问题:给出一个非负整数$n$,输出第$n$个斐波那契数$F_{n}$对$10000$取模的结果。
分析:利用最直接的方法(从$F_{0},F_{1}$开始迭代)求$F_{n}$的时间复杂度为$O(n)$.如果要求出所有的$F_{n}$,这种方法已经是最优的了,但是如果只求某一个$F_{n}$,可以利用下面的表示形式
$\left[\begin{array}{c}F_{n}\\F_{n-1}\end{array}\right]=\left[\begin{array}{cc}1 &1\\ 1 &0 \end{array}\right]\left[\begin{array}{c}F_{n-1}\\F_{n-2}\end{array}\right]$,
$\left[\begin{array}{c}F_{n}\\F_{n-1}\end{array}\right]=\left[\begin{array}{cc}1 &1\\ 1 &0 \end{array}\right]^{n-2}\left[\begin{array}{c}F_{2}\\F_{1}\end{array}\right]$,
问题变成了求解矩阵$\left[\begin{array}{cc}1 &0\\1 &1\end{array}\right]$的$n-2$次幂。下面介绍矩阵的快速幂算法,在此之前先介绍一个数的快速幂算法。
一个数的快速幂算法:$a^{n}$.直接计算$a*a*\dots*a$时间复杂度为$O(n)$.现在考虑$n$的二进制形式,以$n=9$为例,
$9=(1001)_{2}=((1*2+0)*2+0)*2+1$,则
$a^{9}=a^{((1*2+0)*2+0)*2+1}$
$=(a^{(1*2+0)*2+0})^{2}*a$
$=((a^{1*2+0})^{2})^{2}*a$
$=(((a^{1})^{2})^{2})^{2}*a$
$=((((a^{0})^{2}*a)^{2})^{2})^{2}*a$.
我们注意到当遇到$1$时,对原有结果平方再乘以$a$;当遇到$0$时,对原有结果平方。至此我们给出一个数的快速幂算法:
(1)将$n$表示成二进制形式;
(2)初始化结果变量$ans=1$(即$a^{0}$);
(3)从$n$的二进制左边开始,遇到$1$令$ans=ans^{2}*a$,遇到$0$令$ans=ans^{2}$;
(4)返回$ans$.
这个算法的复杂度为$O(\log{n})$.
矩阵的快速幂算法:现在将上述算法拓展到矩阵上,$\mathbf{P}^{n}$.
原理是一样的,
初始化$ANS=\mathbf{I}$,即将$\mathbf{P}^{0}$当做单位矩阵$\mathbf{I}$;
遇到$1$令$ANS=ANS^{2}*\mathbf{P}$,遇到$0$令$ANS=ANS^{2}$.
模运算:取模运算有一些非常好的性质
$(x+y)\%c=(x\%c+y\%c)%c$,
$(x*y)\%c=((x\%c)*(y\%c))\%c$.
代码:
#include<iostream>
#include<cmath>
using namespace std;
int I[2][2];
int con = 10000;
void matMul(int flag)
{
int a, b, c, d;
a = (((I[0][0] % con)*(I[0][0] % con)) % con + ((I[0][1] % con)*(I[1][0] % con)) % con) % con;
b = (((I[0][0] % con)*(I[0][1] % con)) % con + ((I[0][1] % con)*(I[1][1] % con)) % con) % con;
c = (((I[1][0] % con)*(I[0][0] % con)) % con + ((I[1][1] % con)*(I[1][0] % con)) % con) % con;
d = (((I[1][0] % con)*(I[0][1] % con)) % con + ((I[1][1] % con)*(I[1][1] % con)) % con) % con;
I[0][0] = a;
I[0][1] = b;
I[1][0] = c;
I[1][1] = d;
if (flag)
{
a = (I[0][0] + I[0][1]) % con;
b = I[0][0];
c = (I[1][0]+ I[1][1]) % con;
d = I[1][0];
I[0][0] = a;
I[0][1] = b;
I[1][0] = c;
I[1][1] = d;
}
}
int main()
{
int N,n;
int i;
int digts;
int flag;
while (~scanf("%d", &N) && N >= 0)
{
if (N <= 2)
{
if (N < 2)
printf("%d\n", N);
else
printf("1\n");
continue;
}
I[0][0] = 1;
I[0][1] = 0;
I[1][0] = 0;
I[1][1] = 1;
n = N - 2;
digts = floor(log(n*1.0)/log(2.0)) + 1;
for (i = 1; i <= digts;i++)
{
flag = n / (1<<(digts-i));
n -= flag*(1 << (digts - i));
matMul(flag);
}
printf("%d\n", (I[0][0]+I[0][1])%con);
}
return 0;
}