HDU-4965 Fast Matrix Calculation

HDU-4965 Fast Matrix Calculation

题目链接:Problem - 4965 (hdu.edu.cn)


​ 正常算出矩阵 \(A\times B\) 然后对该矩阵做一个矩阵快速幂的话,时间复杂度是 \(O(n^3\times\log(n^2))\) 的。这个复杂度是肯定过不了的,所以我们要进行优化,我们将式子展开:

\[\begin{aligned} (A\times B)^{n^2}&=A\times B\times A\times B\times\dots\times A\times B \\ &=A\times(B\times A\times B\times\dots\times A)\times B \\ &= A\times(B\times A)^{n^2-1}\times B \end{aligned} \]

​ 我们注意到 \(A\times B\) 是一个 \(n\times n\) 的矩阵,而 \(B\times A\) 只是一个 \(k\times k\) 的矩阵,而 \(k\le 6\),所以我们可以对 \(B\times A\) 这个矩阵进行快速幂。那么时间复杂度就是 \(O(k^3\times\log(n^2)+n^2\times k)\) 。这个时间复杂度是可以通过这道题的。

​ 代码如下:

#include<bits/stdc++.h>
using namespace std;
const int MOD = 6;
const int MAXN = 1005;
int n,k;
struct matrix
{
	int v[10][10];
	matrix(){memset(v,0,sizeof v);}
	matrix operator * (const matrix &x)const
	{
		matrix ans;
		for(int i=1;i<=k;++i) for(int j=1;j<=k;++j) for(int l=1;l<=k;++l)
			ans.v[i][j]=(ans.v[i][j]+v[i][l]*x.v[l][j])%MOD;
		return ans;
	}
}t[31],res;
int a[MAXN][MAXN],b[MAXN][MAXN],ans[MAXN][MAXN];
int main()
{
	while(scanf("%d %d",&n,&k)!=EOF)
	{
		if(!n&&!k) break;
		for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) scanf("%d",&a[i][j]);
		for(int i=1;i<=k;++i) for(int j=1;j<=n;++j) scanf("%d",&b[i][j]);
		for(int i=1;i<=k;++i) for(int j=1;j<=k;++j) t[0].v[i][j]=0;
		for(int i=1;i<=k;++i) for(int j=1;j<=k;++j) for(int l=1;l<=n;++l)
				t[0].v[i][j]=(t[0].v[i][j]+b[i][l]*a[l][j])%MOD;
		int up=log2(n*n-1);
		for(int i=1;i<=up;++i)
			t[i]=t[i-1]*t[i-1];
		int target=n*n-1;
		for(int i=1;i<=k;++i)
		{
			for(int j=1;j<=k;++j)
			{
				if(i==j) res.v[i][j]=1;
				else res.v[i][j]=0;
			}
		}
		for(int i=up;i>=0;--i)
			if((1<<i)&target) res=res*t[i];
		for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) ans[i][j]=0;
		for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) for(int l=1;l<=k;++l)
			ans[i][j]=(ans[i][j]+a[i][l]*res.v[l][j])%MOD;
		for(int i=1;i<=n;++i) for(int j=1;j<=k;++j) a[i][j]=ans[i][j];
		for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) ans[i][j]=0;
		for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) for(int l=1;l<=k;++l)
			ans[i][j]=(ans[i][j]+a[i][l]*b[l][j])%MOD;
		int out=0;
		for(int i=1;i<=n;++i) for(int j=1;j<=n;++j) out+=ans[i][j];
		printf("%d\n",out);
	}
	return 0;
}
posted @ 2021-09-15 15:02  夜空之星  阅读(31)  评论(0编辑  收藏  举报