HDU 4965 - Fast Matrix Calculation ( 矩阵快速幂 )

题意

给出一个 n * k 的矩阵A, 一个 k * n 的矩阵B ( 4 <= n <= 1000 ) (2 <= k<= 6)
进行以下操作 :
1. 计算n * n的矩阵C = A * B
2. 计算矩阵 M = Cnn
3. 矩阵M中每个元素模6得到M’
4. 计算M’中每个元素的和

思路

比赛期间没想到矩阵M的计算可以这样处理:
M=(AB)(AB)(AB)...(AB)
根据矩阵乘法的结合律
M=A(BA)(BA)...(BA)B
令C = (B*A)
中间的C矩阵至多是6 * 6矩阵, 就把原来至多1000 * 1000的快速幂简化成为 6 * 6的快速幂
最终的结果即 A * C * B


矩阵快速幂矩阵快速幂模板

AC代码

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <cmath>
#define mst(a) memset(a, 0, sizeof a)
using namespace std;
const int maxn = 1e3+5;
const int mmaxn = 6;
int mod = 6;
typedef long long ll;
int a[maxn][10], b[10][maxn], M[maxn][10], M2[maxn][maxn];
int n, kk;

struct mat{
    int s[mmaxn][mmaxn];
    mat(){
        mst(s);
    };
    mat operator * (const mat& c) {
        mat ans;
        for (int i = 0; i < mmaxn; i++) //矩阵乘法
            for (int j = 0; j < mmaxn; j++)
                for (int k = 0; k < mmaxn; k++)
                    ans.s[i][j] = (ans.s[i][j] + s[i][k] * c.s[k][j]) % mod;
        return ans;
    }
}str, c;

mat pow_mod(ll k) {
    if (k == 1)
        return str;
    mat a = pow_mod(k/2);//不能改
    mat ans = a * a;
    if (k & 1)
        ans = ans * str;
    return ans;
}

void _baplus(){
    for( int k = 0; k < n; k++ )
        for( int i = 0; i < kk; i++ )
            if( b[i][k] )
                for( int j = 0; j < kk; j++ )
                    str.s[i][j] += b[i][k]*a[k][j];
}

void _acbplus(){
    mst(M);
    mst(M2);
    // m = a*c ( 1000*6*6*6 --> 1000*6 )
    for( int k = 0; k < kk; k++ )
        for( int i = 0; i < n; i++ )
            if( a[i][k] )
                for( int j = 0; j < kk; j++ )
                    M[i][j] += a[i][k]*c.s[k][j];
    // m = m*b ( 1000*6*6*1000 --> 1000*1000 )
    for( int k = 0; k < kk; k++ )
        for( int i = 0; i < n; i++ )
            if( M[i][k] )
                for( int j = 0; j < n; j++ )
                    M2[i][j] += M[i][k]*b[k][j];
}

ll getans(){
    ll ans = 0;
    for( int i = 0; i < n; i++ )
        for( int j = 0; j < n; j++ )
            ans += M2[i][j]%mod;
    return ans;
}

int main()
{
    while( scanf("%d%d",&n, &kk) == 2 && n ){
        mst(str.s);
        mst(c.s);
        for( int i = 0; i < n; i++ )
            for( int j = 0; j < kk; j++ )
                scanf("%d",&a[i][j]);
        for( int i = 0; i < kk; i++ )
            for( int j = 0; j < n; j++ )
                scanf("%d",&b[i][j]);
        _baplus();
        ll m = n*n-1;
        c = pow_mod(m);
        _acbplus();
        printf("%lld\n",getans());
    }
    return 0;
}
posted @ 2018-05-01 19:07  JinxiSui  阅读(128)  评论(0编辑  收藏  举报