POJ 3233:Matrix Power Series 矩阵快速幂 乘积
Matrix Power Series
Time Limit: 3000MS | Memory Limit: 131072K | |
Total Submissions: 18450 | Accepted: 7802 |
Description
Given a n × n matrix A and a positive integer k, find the sum S = A + A2 + A3 + … + Ak.
Input
The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ 109) and m (m < 104). Then follow n lines each containing n nonnegative integers below 32,768, giving A’s elements in row-major order.
Output
Output the elements of S modulo m in the same way as A is given.
Sample Input
2 2 4 0 1 1 1
Sample Output
1 2 2 3
题意很简单,就是矩阵相乘,然后求和。自己做的时候快速幂,发现快速幂竟然还是TLE。
不知道怎么搞,看了网上的代码,发现这个求和的深搜sum2很经典,充分利用偶数求和,假设是求1到6的和,先将6除以2,求1到3的和,然后对1到3的和 乘以3就是4到6的和,再一相加就是1到6的和。这段代码的思想很巧妙,很喜欢。以后求1到n的和时候可以用得上~
代码:
#include <iostream> #include <algorithm> #include <cmath> #include <vector> #include <string> #include <cstring> #pragma warning(disable:4996) using namespace std; struct matrix { int m[35][35]; }; int n, mod; long long ko; matrix no; matrix mu(matrix no1, matrix no2) { matrix t; memset(t.m, 0, sizeof(t.m)); int i, j, k; for (i = 1; i <= n; i++) { for (j = 1; j <= n; j++) { for (k = 1; k <= n; k++) { t.m[i][j] += no1.m[i][k] * no2.m[k][j]; if (t.m[i][j] >= mod) { t.m[i][j] %= mod; } } } } return t; } matrix multi(matrix no, long long x) { matrix b; memset(b.m, 0, sizeof(b.m)); int i; for (i = 1; i <= n; i++) { b.m[i][i] = 1; } while (x) { if (x & 1) b = mu(b, no); x = x >> 1; no = mu(no, no); } return b; } matrix add(matrix no1, matrix no2) { matrix t; int i, j; for (i = 1; i <= n; i++) { for (j = 1; j <= n; j++) { t.m[i][j] = no1.m[i][j] + no2.m[i][j]; if (t.m[i][j] >= mod) { t.m[i][j] %= mod; } } } return t; } matrix sum2(long long i)//假设i为7 { if (i == 1)return no; if (i & 1) return add(multi(no, i), sum2(i - 1));//7+6+... else { long long k = i >> 1;//3 matrix s = sum2(k);//1 2 3 return add(s, mu(s, multi(no, k)));1 2 3 + 4 5 6 } } int main() { //freopen("i.txt","r",stdin); //freopen("o.txt","w",stdout); int i, j; cin >> n >> ko >> mod; for (i = 1; i <= n; i++) { for (j = 1; j <= n; j++) { scanf("%d", &no.m[i][j]); if (no.m[i][j] >= mod) { no.m[i][j] %= mod; } } } no = sum2(ko); for (i = 1; i <= n; i++) { for (j = 1; j <= n; j++) { if (j == 1) cout << no.m[i][j]; else cout << " " << no.m[i][j]; } cout << endl; } //system("pause"); return 0; }
版权声明:本文为博主原创文章,未经博主允许不得转载。