HDU-4965 Fast Matrix Calculation
HDU-4965 Fast Matrix Calculation
题目链接:Problem - 4965 (hdu.edu.cn)
正常算出矩阵 \(A\times B\) 然后对该矩阵做一个矩阵快速幂的话,时间复杂度是 \(O(n^3\times\log(n^2))\) 的。这个复杂度是肯定过不了的,所以我们要进行优化,我们将式子展开:
\[\begin{aligned}
(A\times B)^{n^2}&=A\times B\times A\times B\times\dots\times A\times B
\\
&=A\times(B\times A\times B\times\dots\times A)\times B
\\
&=
A\times(B\times A)^{n^2-1}\times B
\end{aligned}
\]
我们注意到 \(A\times B\) 是一个 \(n\times n\) 的矩阵,而 \(B\times A\) 只是一个 \(k\times k\) 的矩阵,而 \(k\le 6\),所以我们可以对 \(B\times A\) 这个矩阵进行快速幂。那么时间复杂度就是 \(O(k^3\times\log(n^2)+n^2\times k)\) 。这个时间复杂度是可以通过这道题的。
代码如下:
#include<bits/stdc++.h>
using namespace std;
const int MOD = 6;
const int MAXN = 1005;
int n,k;
struct matrix
{
int v[10][10];
matrix(){memset(v,0,sizeof v);}
matrix operator * (const matrix &x)const
{
matrix ans;
for(int i=1;i<=k;++i) for(int j=1;j<=k;++j) for(int l=1;l<=k;++l)
ans.v[i][j]=(ans.v[i][j]+v[i][l]*x.v[l][j])%MOD;
return ans;
}
}t[31],res;
int a[MAXN][MAXN],b[MAXN][MAXN],ans[MAXN][MAXN];
int main()
{
while(scanf("%d %d",&n,&k)!=EOF)
{
if(!n&&!k) break;
for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) scanf("%d",&a[i][j]);
for(int i=1;i<=k;++i) for(int j=1;j<=n;++j) scanf("%d",&b[i][j]);
for(int i=1;i<=k;++i) for(int j=1;j<=k;++j) t[0].v[i][j]=0;
for(int i=1;i<=k;++i) for(int j=1;j<=k;++j) for(int l=1;l<=n;++l)
t[0].v[i][j]=(t[0].v[i][j]+b[i][l]*a[l][j])%MOD;
int up=log2(n*n-1);
for(int i=1;i<=up;++i)
t[i]=t[i-1]*t[i-1];
int target=n*n-1;
for(int i=1;i<=k;++i)
{
for(int j=1;j<=k;++j)
{
if(i==j) res.v[i][j]=1;
else res.v[i][j]=0;
}
}
for(int i=up;i>=0;--i)
if((1<<i)&target) res=res*t[i];
for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) ans[i][j]=0;
for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) for(int l=1;l<=k;++l)
ans[i][j]=(ans[i][j]+a[i][l]*res.v[l][j])%MOD;
for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) a[i][j]=ans[i][j];
for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) ans[i][j]=0;
for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) for(int l=1;l<=k;++l)
ans[i][j]=(ans[i][j]+a[i][l]*b[l][j])%MOD;
int out=0;
for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) out+=ans[i][j];
printf("%d\n",out);
}
return 0;
}
路漫漫其修远兮,吾将上下而求索。