Matrix Power Series 矩阵套矩阵

题目:

查看原题点击 传送门

题意:已知一个n*n的矩阵A,和一个正整数k,求S = A + A2 + A3 + … + Ak

思路:

不妨设S (k)= A + A2 + A3 + … + Ak

则有S(k)=S(k-1)+AK;然后可以用矩阵快速幂进行地推

 

 

 

 

E 为单位矩阵,A为最开始输入的矩阵,0为元素全为0的矩阵;

 

我们只用维护好这个大矩阵的快速幂就行了。

 

 接下来就是上代码了

代码:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
using namespace std;
#define ll long long
int mod;
int n,x;
struct node{
    ll mm[31][31];
}sum;
struct node1{
    node m[3][3];
}ans,base;
node mul(node a,node b){    //小矩阵之间的矩阵乘 
    node tp;
    memset(tp.mm,0,sizeof(tp));
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            for(int k=1;k<=n;k++){
                tp.mm[i][j]=(tp.mm[i][j]+a.mm[i][k]*b.mm[k][j])%mod;
            }
        }
    }
    return tp;
}
node pluss(node a,node b){    //大矩阵之间的矩阵加 
    node tp;
    memset(tp.mm,0,sizeof(tp));
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            tp.mm[i][j]=(a.mm[i][j]+b.mm[i][j])%mod;
        }
    }
    return tp;
}
node1 nodemul(node1 a,node1 b){   //大矩阵之间的矩阵乘 
    node1 tp;
    memset(tp.m,0,sizeof(tp.m));
    for(int i=1;i<=2;i++){
        for(int j=1;j<=2;j++){
            for(int k=1;k<=2;k++){
                tp.m[i][j]=pluss(tp.m[i][j],mul(a.m[i][k],b.m[k][j]));
            }
        }
    }
    return tp;
}
void quick(int k){   //对大矩阵进行快速幂计算 
    memset(ans.m,0,sizeof(ans.m));
    for(int i=1;i<=n;i++)ans.m[1][1].mm[i][i]=ans.m[2][2].mm[i][i]=1;
    while(k){
        if(k&1)ans=nodemul(ans,base);
        base=nodemul(base,base);
        k>>=1;
    }
}
int main(){
    scanf("%d%d%d",&n,&x,&mod);
    node tp; 
    memset(base.m,0,sizeof(base.m));
    memset(sum.mm,0,sizeof(sum.mm));
    memset(tp.mm,0,sizeof(tp.mm));
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            scanf("%lld",&base.m[2][2].mm[i][j]);
            tp.mm[i][j]=base.m[2][2].mm[i][j];    //用来保存A 
        }
    }
    for(int i=1;i<=n;i++){
        base.m[1][1].mm[i][i]=base.m[2][1].mm[i][i]=1;
    }
    quick(x);
    sum=mul(tp,ans.m[2][1]);
    for(int i=1;i<=n;i++){
        for(int j=1;j<=n;j++){
            printf("%d ",sum.mm[i][j]);
        }
        printf("\n");
    }
    return 0;
}

2020-07-23  19:37:01

 

posted @ 2020-07-23 19:37  Rain_luo  阅读(229)  评论(0编辑  收藏  举报