浅谈矩阵加速——以时间复杂度为O(log n)的算法实现裴波那契数列第n项及前n之和使用矩阵加速法的优化求法
首先请连矩阵乘法乘法都还没有了解的同学简单看一下这篇博客:
https://blog.csdn.net/weixin_44049566/article/details/88945949
首先直接暴力求使用O(n)的时间复杂度肯定是不行的,所以我们应该使用更优的时间复杂度。
设f(n)为裴波那契数列第n项。让我们来构造两个矩阵:
和.
现在我们不妨将两个矩阵相乘,化简过后可以得到:,也就是.
如果再将得到的新矩阵乘以,便可以得到。
也就是我们想得到第n项,就可以这么实现:,也就是。
看到幂我们就可以想到快速幂,所以最后程序的时间复杂度便是O(log n)。
代码
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define N 10
#define LL long long
int n,mod;
struct Matrix {
LL n,m,c[N][N];
Matrix() { memset(c,0,sizeof(c)); };
void _read() {
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
scanf("%lld",&c[i][j]);
}
Matrix operator * (const Matrix& a) {
Matrix r;
r.n=n;r.m=a.m;
for(int i=1;i<=r.n;i++)
for(int j=1;j<=r.m;j++)
for(int k=1;k<=m;k++)
r.c[i][j]= (r.c[i][j]+ (c[i][k] * a.c[k][j])%mod)%mod;
return r;
}
void _print() {
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
if(j!=1) cout<<" ";
cout<<c[i][j];
}
if(i!=n) puts("");
}
}
Matrix _power(int indexx) {
Matrix tmp,sum;tmp._pre1();sum._pre1();
while(indexx>0) {
if(indexx&1) sum=sum*tmp;
tmp=tmp*tmp;
indexx/=2;
}
return sum;
}
void _pre1() {
n=m=2;
c[1][1]=0;
c[1][2]=c[2][1]=c[2][2]=1;
}
void _pre2() {
n=1,m=2;
c[1][1]=c[1][2]=1;
}
}T,ans;
int main() {
cin>>n>>mod;
ans._pre2();
T._pre1();
if(n<=2) { cout<<1; return 0; }
T=T._power(n-3);
ans=ans*T;
cout<<ans.c[1][2];
}
那么前n项之和呢?
这里我们可以这么构造两个矩阵:(S(n)为前n项之和)
和将两个矩阵相乘,剩下的留给读者思考。
这里给出代码:
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
#define N 10
#define LL long long
int n,mod;
struct Matrix {
LL n,m,c[N][N];
Matrix() { memset(c,0,sizeof(c)); };
void _read() {
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
scanf("%lld",&c[i][j]);
}
Matrix operator * (const Matrix& a) {
Matrix r;
r.n=n;r.m=a.m;
for(int i=1;i<=r.n;i++)
for(int j=1;j<=r.m;j++)
for(int k=1;k<=m;k++)
r.c[i][j]= (r.c[i][j]+ (c[i][k] * a.c[k][j])%mod)%mod;
return r;
}
void _print() {
for(int i=1;i<=n;i++) {
for(int j=1;j<=m;j++) {
if(j!=1) cout<<" ";
cout<<c[i][j];
}
if(i!=n) puts("");
}
}
Matrix _power(int indexx) {
Matrix tmp,sum;tmp._pre1();sum._pre1();
while(indexx>0) {
if(indexx&1) sum=sum*tmp;
tmp=tmp*tmp;
indexx/=2;
}
return sum;
}
void _pre1() {
n=m=3;
c[3][2]=c[3][3]=c[2][3]=c[1][1]=c[2][1]=c[3][1]=1;
}
void _pre2() {
n=1,m=3;
c[1][1]=2;c[1][3]=c[1][2]=1;
}
}T,ans;
int main() {
cin>>n>>mod;
if(n==2) return printf("%d",2%mod),0;
if(n==1) return printf("%d",1%mod),0;
ans._pre2();
T._pre1();
if(n<=2) { cout<<1; return 0; }
T=T._power(n-3);
ans=ans*T;
cout<<ans.c[1][1];
}