Matrix Power Series(矩阵快速幂)

C - Matrix Power Series
Time Limit:3000MS     Memory Limit:131072KB     64bit IO Format:%I64d & %I64u

Description

Given a n × n matrix A and a positive integer k, find the sum S = A + A2 + A3 + … + Ak.

Input

The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ 109) and m (m < 104). Then follow n lines each containing n nonnegative integers below 32,768, giving A’s elements in row-major order.

Output

Output the elements of S modulo m in the same way as A is given.

Sample Input

2 2 4
0 1
1 1

Sample Output

1 2
2 3

这题用了二分和快速幂。一道很经典的矩阵快速幂的题目。

首先是二分,求一个矩阵连续和可以分奇偶来讨论,使其规模减半。

比如当为奇数的时候

AK=(A1+A2+...A((K-1)/2))+A((K+1)/2)+A((K+1)/2)*(A1+A2+....A((K-1)/2))

当为偶数的时候

AK=(A1+A2+...A(K/2))+A(K/2)*(A1+A2+...A(K/2))

然后用两个函数模块来执行,一个是add(A,B),一个是mul(A,B)

设定一个当前减半后的函数的矩阵(用快速幂来求),node now;

设定一个用来储存连续和的结果的矩阵(用递归来调用本函数),node temp;

当为奇数时候

temp=add(temp,mul(temp,now));

temp=add(temp,now);

当为偶数的时候

temp=add(temp,mul(temp,now));

最后输出temp得到结果

 

然后是矩阵快速幂 cal(int k)(k为幂数)

矩阵快速幂的思想来自于a^b%m,同样都是转换成二进制,然后根据二进制的位数来计算

比如A^19  =>  (A^16)*(A^2)*(A^1),显然采取这样的方式计算时因子数将是log(n)级别的(原来的因子数是n),不仅这样,因子间也是存在某种联 系的,比如A^4能通过(A^2)*(A^2)得到,A^8又能通过(A^4)*(A^4)得到,这点也充分利用了现有的结果作为有利条件。下面举个例子 进行说明:

现在要求A^156,而156(10)=10011100(2) 

也就有A^156=>(A^4)*(A^8)*(A^16)*(A^128)  考虑到因子间的联系,我们从二进制10011100中的最右端开始计算到最左端。核心代码为

 1 node cal(int k)
 2 {
 3     node res=ori;//初始化成单位阵
 4     while(k)
 5     {
 6         if(k&1)//判断是奇数(1)的话
 7             res=res*A;//该二进制位上有数字,就乘以该幂次的矩阵
 8         k>>=1;//k右移一位
 9         A*=A;//矩阵以平方速度自乘
10     }
11     return res;
12 }

或者可以是

 1 node cal(int k)//矩阵快速幂
 2 {
 3     node p,q;
 4     p=init;//初始阵
 5     q=unit;//单位阵
 6     while(k!=1)
 7     {
 8         if(k&1)//幂为奇数
 9         {
10             k--;
11             q=mul(p,q);//把单独乘的那次空出来,在最后再乘
12         }
13         else
14         {
15             k>>=1;//除二
16             p=mul(p,p);//二分
17         }
18     }
19     p=mul(p,q);
20     return p;
21 }

该代码在十进制的角度,来看二进制的算法,奇数就减一,偶数的时候就除二。是以  ans=二的n次方*单独剩下凑不齐n次方的矩阵次数  的形式

 

这两个点是该题的关键,能下降计算的时间复杂度,下附完整代码

 1 #include<iostream>
 2 #include<stdio.h>
 3 #include<string.h>
 4 #include<math.h>
 5 #include<algorithm>
 6 using namespace std;
 7 int n,k,m;
 8 struct node
 9 {
10     int mapp[50][50];
11 }init,res,unit,ret;
12 node add(node a,node b)
13 {
14     int i,j;
15     node c;
16     for(i=0;i<n;i++)
17         for(j=0;j<n;j++)
18             {c.mapp[i][j]=a.mapp[i][j]+b.mapp[i][j];
19             c.mapp[i][j]%=m;}
20     return c;
21 }
22 node mul(node a,node b)
23 {
24     int i,j,k;
25     node c;
26     for(i=0;i<n;i++)
27         for(j=0;j<n;j++)
28         c.mapp[i][j]=0;
29     for(i=0;i<n;i++)
30         for(j=0;j<n;j++)
31             for(k=0;k<n;k++)
32     {
33         c.mapp[i][j]+=a.mapp[i][k]*b.mapp[k][j];
34         c.mapp[i][j]%=m;
35     }
36     return c;
37 }
38 node cal(int k)//矩阵快速幂
39 {
40     node p,q;
41     p=init;//初始阵
42     q=unit;//单位阵
43     while(k!=1)
44     {
45         if(k&1)//幂为奇数
46         {
47             k--;
48             q=mul(p,q);//把单独乘的那次空出来,在最后再乘
49         }
50         else
51         {
52             k>>=1;//除二
53             p=mul(p,p);//二分
54         }
55     }
56     p=mul(p,q);
57     return p;
58 }
59 node sum(int k)
60 {
61     if(k==1)
62         return init;
63     node temp,now;
64     temp=sum(k/2);//总和
65     if(k&1)//按二进制,按位来取,判断是否为奇数
66     {
67         now=cal(k/2+1);//当前一个矩阵
68         temp=add(temp,mul(temp,now));
69         temp=add(now,temp);
70     }
71     else
72     {
73         now=cal(k/2);
74         temp=add(temp,mul(temp,now));
75     }
76     return temp;
77 }
78 int main()
79 {
80     int i,j;
81     scanf("%d%d%d",&n,&k,&m);
82     for(i=0;i<n;i++)
83         for(j=0;j<n;j++)
84         {
85             scanf("%d",&init.mapp[i][j]);
86             init.mapp[i][j]%=m;
87             if(i==j)
88             unit.mapp[i][j]=1;
89             else
90             unit.mapp[i][j]=0;
91         }
92     ret=sum(k);
93     for(i=0;i<n;i++)
94         {for(j=0;j<n;j++)
95         printf("%d ",ret.mapp[i][j]);
96         printf("\n");
97         }
98     return 0;
99 }

 

posted @ 2015-04-04 11:23  kingofprank  阅读(260)  评论(0编辑  收藏  举报