Matrix Power Series - 矩阵快速幂对分块矩阵加速

题目
其中\(A\)是一个\(n \times n\)的矩阵,\(S_k = A + A^2 + A^3 + \cdots + A^k\),求\(S_k\)(每个元素对\(m\)取模)
遇到数列求和时,常构造一个把前缀和与通项一起递推的式子:\(S_k = S_{k - 1} + A^k\)
先假设\(A\)不是矩阵而是一个普通的数,则有

\[\left[\begin{array}{cc} 1 & 1\\ 0 & A \end{array}\right] \times \left[\begin{array}{c} S_{k - 1}\\ A^k \end{array}\right] = \left[\begin{array}{c} S_k\\ A^{k + 1} \end{array}\right]\]

用单位矩阵\(E\)代替1,用零矩阵代替0,用矩阵\(A\)代替数字,转换一下就是

\[\left[\begin{array}{cc} E & E\\ 0 & A \end{array}\right] \times \left[\begin{array}{c} S_{k - 1}\\ A^k \end{array}\right] = \left[\begin{array}{c} S_k\\ A^{k + 1} \end{array}\right] \]

构建了个分块矩阵,大小是\(2n \times 2n\)
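加速矩阵\(\left[\begin{array}{cc} E & E\\ 0 & A \end{array}\right]\),初始矩阵\(\left[\begin{array}{c} S_1\\ A^2 \end{array}\right]\)(其中\(S_1 = A\)),于是\(\left[\begin{array}{cc} E & E\\ 0 & A \end{array}\right]^{k - 1} \times \left[\begin{array}{c} S_1\\ A^2 \end{array}\right] = \left[\begin{array}{c} S_k\\ A^{k + 1} \end{array}\right]\),结果的上半块就是\(S_k\)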

#include <iostream>
#include <cstdio>
#include <cstring>
#define ll long long
using namespace std;
// Capacity of each dimension of the fixed-size storage inside Matrix.
const int N = 61;
// Modulus applied by Matrix::operator*; read from input in main().
int mod;
// Side length of the input matrix A (the write-up above calls this n).
int k;
// Exponent: number of terms in A + A^2 + ... (the write-up above calls this k).
int n;

// Fixed-capacity integer matrix. Only the top-left n x m region is meaningful;
// the rest of the storage stays zero.
struct Matrix{
    int n, m;       // rows / columns actually in use
    int a[N][N];    // entries, zero-filled by the constructor

    // Builds an x-by-y zero matrix. There is no bounds check: x and y must
    // not exceed N.
    Matrix(int x, int y) : n(x), m(y) { memset(a, 0, sizeof(a)); }

    // Returns (*this) * b with every entry reduced modulo the global `mod`.
    // The inner dimension is this->m; b.n is not checked against it.
    Matrix operator * (const Matrix &b) const {
        Matrix ans(n, b.m);
        for(int i = 0; i < n; i++){
            for(int j = 0; j < b.m; j++){
                for(int t = 0; t < m; t++){
                    // Widen before multiplying: the product of two residues
                    // overflows int as soon as mod exceeds ~46341, and the
                    // running sum can overflow too for mod above 2^30.
                    const long long term =
                        static_cast<long long>(a[i][t]) * b.a[t][j] % mod;
                    ans.a[i][j] = static_cast<int>((ans.a[i][j] + term) % mod);
                }
            }
        }
        return ans;
    }
};
// Fast (binary) exponentiation: returns a raised to the power b, using
// Matrix::operator* for every product (so results are reduced modulo `mod`).
//   a - the base; expected to be square, since it is multiplied by itself.
//   b - the exponent. For b <= 0 the identity matrix is returned as-is
//       (entries 0/1, not reduced).
Matrix ksm(Matrix a, ll b){
    Matrix ans(a.n, a.m);
    // Identity: touch only diagonal cells that lie inside the matrix.
    for(int i = 0; i < a.n && i < a.m; i++)
        ans.a[i][i] = 1;

    // `b > 0` rather than `b`: right-shifting a negative value never
    // reaches zero, so a negative exponent would otherwise loop forever.
    while(b > 0){
        if(b & 1)ans = ans * a;
        b >>= 1;
        if(b > 0)a = a * a;   // skip the squaring nobody will consume
    }
    return ans;
}
// Input matrix A. main() fills its top-left k x k corner; solve() reads it
// and then overwrites it with A * A.
Matrix a(35, 35);
void solve(){
    Matrix base(k * 2, k * 2);
    for(int i = 0; i < k; i++)
        base.a[i][i] = base.a[i][i + k] = 1;

    for(int i = k; i < 2 * k; i++)
        for(int j = k; j < 2 * k; j++)
            base.a[i][j] = a.a[i - k][j - k];

    base = ksm(base, n - 1);
    

    Matrix ans(k * 2, k);
    for(int i = 0; i < k; i++)
        for(int j = 0; j < k; j++)
            ans.a[i][j] = a.a[i][j];

    a = a * a;
    for(int i = k; i < 2 * k; i++)
        for(int j = 0; j < k; j++)
            ans.a[i][j] = a.a[i - k][j];

    ans = base * ans;
    for(int i = 0; i < k; i++){
        for(int j = 0; j < k; j++)
            printf("%d ", ans.a[i][j]);
        putchar('\n');
    }
}
// Reads "k n mod" followed by the k x k entries of A, then prints
// A + A^2 + ... + A^n via solve().
// Returns 0 on success, 1 if the input is malformed or out of range.
int main(){
    if(scanf("%d%d%d", &k, &n, &mod) != 3)
        return 1;

    // solve() works on a 2k x 2k block matrix, so 2k must stay strictly
    // below the storage capacity for every index used while building and
    // exponentiating it to be in range. n < 1 has no valid power to
    // compute, and mod < 1 would divide by zero (or yield negative results).
    const int cap = static_cast<int>(sizeof(a.a) / sizeof(a.a[0]));
    if(k < 1 || 2 * k >= cap || n < 1 || mod < 1)
        return 1;

    for(int i = 0; i < k; i++)
        for(int j = 0; j < k; j++)
            if(scanf("%d", &a.a[i][j]) != 1)
                return 1;
    solve();
    return 0;
}

posted @ 2020-05-14 21:42  Emcikem  阅读(131)  评论(0)  编辑  收藏  举报