POJ 3233
描述
Given an n × n matrix A and a positive integer k, find the sum \(S = A + A^{2} + A^{3} + \cdots + A^{k}.\)
输入
The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ \(10^{9}\)) and m (m < \(10^{4}\)). Then follow n lines each containing n nonnegative integers below 32768, giving A’s elements in row-major order.
输出
Output the elements of S modulo m in the same way as A is given.
样例输入
2 2 4
0 1
1 1
样例输出
1 2
2 3
思路
起初我直接矩阵快速幂,然后把结果相加,可是TLE了,然后去百度看看别人的思路,找到一个神奇的公式。
\[ B =
\left[
\begin{matrix}
A & E \\
0 & E
\end{matrix}
\right]
\]
\[ B^{k+1} =
\left[
\begin{matrix}
A^{k+1} & E + A + A^{2} + A^{3} + … + A^{k} \\
0 & E
\end{matrix}
\right]
\]
有了这个公式之后,写了个分块矩阵就好,顺便重构了一下我那惨不忍睹的矩阵快速幂模版。
虽然被大佬嫌弃代码啰嗦,但还是放上来吧。
代码
#include <cstdio>
#include<cstring>
// 64-bit integer type for matrix entries, wide enough for the product of two entries.
typedef long long ll;
// Storage dimension of every Mat; all operators work on the full maxn x maxn array.
const int maxn = 32;
using namespace std;
// n: matrix order, k: exponent consumed by quick_pow, mod: modulus used by the
// Mat operators. All three are filled in by main from the input header.
int n, k, mod;

/**
 * Dense square matrix with fixed maxn x maxn storage and modular arithmetic.
 *
 * The arithmetic operators reduce with the global `mod`, so `mod` must be
 * positive before any of them is called. Entries outside the part of the
 * matrix actually in use are simply left at zero.
 */
struct Mat
{
    ll f[maxn][maxn];

    /// Sets every entry to zero.
    void cls() { memset(f, 0, sizeof(f)); }

    /// Constructs the zero matrix.
    Mat() { cls(); }

    /**
     * Prints the top-left n x n corner, one row per line.
     * Each entry is reduced modulo `mod` and followed by a single space.
     */
    void myprint(int n) const
    {
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
                printf("%lld ", f[i][j] % mod);  // entries are long long: %lld, not %d
            printf("\n");
        }
    }

    /// Entry-wise sum, reduced modulo `mod`.
    friend Mat operator + (const Mat& a, const Mat& b)
    {
        Mat res;
        for (int i = 0; i < maxn; i++)
            for (int j = 0; j < maxn; j++)
                res.f[i][j] = (a.f[i][j] + b.f[i][j]) % mod;
        return res;
    }

    /// Entry-wise difference, normalised into [0, mod).
    friend Mat operator - (const Mat& a, const Mat& b)
    {
        Mat res;
        for (int i = 0; i < maxn; i++)
            for (int j = 0; j < maxn; j++)
            {
                ll diff = (a.f[i][j] - b.f[i][j]) % mod;
                if (diff < 0) diff += mod;  // C++ % keeps the sign of the dividend
                res.f[i][j] = diff;
            }
        return res;
    }

    /// Matrix product, with every partial sum reduced modulo `mod`.
    friend Mat operator * (const Mat& a, const Mat& b)
    {
        Mat res;
        // i-k-j order walks rows of b and res contiguously; for each (i, j) the
        // terms are still accumulated in increasing k, so results are unchanged.
        for (int i = 0; i < maxn; i++)
            for (int k = 0; k < maxn; k++)
            {
                const ll aik = a.f[i][k];
                if (aik == 0) continue;  // a zero factor contributes nothing to the row
                for (int j = 0; j < maxn; j++)
                    res.f[i][j] = (res.f[i][j] + aik * b.f[k][j]) % mod;
            }
        return res;
    }
} E, I;  // E: main puts ones on its diagonal (identity); I is never written and stays zero
struct MatDiv
{
Mat s[2][2];
void set(Mat x1, Mat x2, Mat x3, Mat x4)
{
s[0][0] = x1; s[0][1] = x2;
s[1][0] = x3; s[1][1] = x4;
}
friend MatDiv operator * (MatDiv a, MatDiv b)
{
MatDiv res;
for(int i = 0; i < 2; i++) for(int j = 0; j < 2;j++)
for(int k = 0; k < 2; k++)
res.s[i][j] = a.s[i][k] * b.s[k][j] + res.s[i][j];
return res;
}
};
/**
 * Raises the block matrix `a` to the power stored in the global `k` using
 * binary (square-and-multiply) exponentiation.
 *
 * Relies on the globals E and I holding the identity and the zero matrix
 * respectively, which is how main prepares them.
 *
 * @param a  base; taken by value because it is squared in place.
 * @return   a^k; the block identity when k is not positive.
 */
MatDiv quick_pow(MatDiv a)
{
    MatDiv ans;
    // Block identity [[E, 0], [0, E]]. The accumulator used to start as
    // [[E, 0], [E, 0]], which corrupted the bottom block row of the result.
    ans.set(E, I, I, E);
    ll b = k;
    while (b > 0)
    {
        if (b & 1) ans = ans * a;
        b >>= 1;
        if (b > 0) a = a * a;  // skip the squaring that nothing would consume
    }
    return ans;
}
/**
 * Reads test cases until the header can no longer be read and, for each one,
 * prints S = A + A^2 + ... + A^k modulo `mod`.
 *
 * Method: with B = [[A, E], [0, E]], the top-right block of B^(k+1) equals
 * E + A + ... + A^k, so S is that block minus E.
 */
int main()
{
    // Turn the global E into the identity matrix (globals start zeroed).
    for (int i = 0; i < maxn; i++)
        E.f[i][i] = 1;

    // Require all three header fields; a partial read must not start a case.
    while (scanf("%d %d %d", &n, &k, &mod) == 3)
    {
        // Reject sizes that would overrun Mat storage and a modulus that
        // would make the % operations invalid.
        if (n < 1 || n > maxn || mod < 1)
            break;

        Mat A;
        k++;  // quick_pow reads the global k; the identity above needs B^(k+1)
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                scanf("%lld", &A.f[i][j]);  // entries are long long: %lld, not %d

        MatDiv B;
        B.set(A, E, I, E);
        B = quick_pow(B);
        A = B.s[0][1] - E;  // drop the A^0 term from the accumulated sum
        A.myprint(n);
    }
    return 0;
}