【快速幂运算与矩阵快速幂专题】

Posted on 2015-08-01 21:07  LLGemini  阅读(490)  评论(0)  编辑  收藏  举报

今天练了不少快速幂的手,这一直是之前的一个漏洞吧,现在把洞补上。东西是挺简单的东西,当然题目多变,做起来也问题多多。

首先放一下核心代码:

// Matrix multiplication modulo `mod`.
const int mod = 10000; // modulus applied to every entry of a product
const int maxn = 2;    // matrices are maxn x maxn

// Square maxn x maxn matrix of ints.
struct matrix
{
    int a[maxn][maxn];
};

// Returns the product A * B with every entry reduced into [0, mod).
//
// Each term is formed in long long from operands that were first reduced
// modulo `mod`, so the result is correct even when the inputs hold large,
// unreduced or negative values (a plain int product would overflow there).
// For inputs already in [0, mod) the result matches the plain schoolbook
// product taken modulo `mod`.
matrix mul(const matrix& A, const matrix& B)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for(int i = 0; i < maxn; ++i)
        for(int k = 0; k < maxn; ++k)
        {
            const long long lhs = A.a[i][k] % mod;
            // Sparse-matrix shortcut: a factor that is 0 modulo `mod` adds
            // nothing to row i, so the whole inner loop is skipped. The general
            // cost stays O(n^3) but approaches O(n^2) when most entries are 0.
            if(lhs == 0)
                continue;
            for(int j = 0; j < maxn; ++j)
            {
                long long cell = (ret.a[i][j] + lhs * (B.a[k][j] % mod)) % mod;
                if(cell < 0) // C++ '%' keeps the sign of the dividend
                    cell += mod;
                ret.a[i][j] = static_cast<int>(cell);
            }
        }
    return ret;
}
// Fast exponentiation by repeated squaring: returns p raised to the k-th
// power, using mul() for every product.
// Idea: a^k = (a^2)^(k/2) = ((a^2)^2)^(k/4) ... The exponent is consumed one
// bit at a time, lowest bit first; the running base is squared every round
// and multiplied into the result whenever the current bit is set.
matrix expo(matrix p, int k)
{
    // First power: hand the base back as-is, without passing it through mul().
    if(k == 1)
        return p;

    // The accumulator starts out as the identity matrix.
    matrix acc;
    for(int r = 0; r < maxn; ++r)
        for(int c = 0; c < maxn; ++c)
            acc.a[r][c] = (r == c) ? 1 : 0;

    // With k == 0 the loop body never runs and the identity is returned.
    for(; k != 0; k >>= 1)
    {
        if(k & 1)
            acc = mul(p, acc);
        p = mul(p, p);
    }
    return acc;
}

对于此处代码可以将幂转化成二进制形式来理解:例如当 k = 156 时,$156 = (10011100)_2 = 128 + 16 + 8 + 4$,于是 $ans = a^{156} = a^{128} \cdot a^{16} \cdot a^{8} \cdot a^{4}$

从右向左扫描 k 的二进制位,第 i 位(i >= 0)对应 $a^{2^i}$;碰见一个 1 就把对应的 $a^{2^i}$ 乘到 ans 里;

1  while(k)
2 {
3     if(k & 1)
4         ret = mul(p, ret);
5     p = mul(p, p);
6     k >>= 1;
7 }

 结构体形式模板:

 1 #include <iostream>
 2 #include <cstdio>
 3 #include <cstring>
 4 using namespace std;
 5 typedef long long ll;
 // Run-time parameters shared with the matrix type below:
 //   n   - side length of the matrices actually in use; must not exceed maxn
 //   mod - modulus applied to matrix entries; must be positive
 //   k   - declared alongside them but not read by the matrix type itself
 int n, k, mod;
 const int maxn = 100; // capacity of the backing array in each dimension

 // Square matrix with modular arithmetic. Only the top-left n x n corner is
 // read or written by the member functions.
 struct matrix
 {
     int a[maxn][maxn];

     // Prints the n x n corner to stdout, one row per line, entries taken
     // modulo `mod` and separated by single spaces.
     void print()
     {
         for(int i = 0; i < n; i++)
         {
             for(int j = 0; j < n; j++)
             {
                 if(j) printf(" ");
                 printf("%d", a[i][j] % mod);
             }
             printf("\n");
         }
     }

     // Element-wise modular addition. Entries whose right-hand operand is 0
     // are left untouched. The sum is formed in long long so that two large
     // entries cannot overflow int before the reduction.
     matrix& operator += (const matrix& rhs)
     {
         for(int i = 0; i < n; ++i)
             for(int j = 0; j < n; ++j)
                 if(rhs.a[i][j])
                 {
                     long long sum = (static_cast<long long>(a[i][j]) + rhs.a[i][j]) % mod;
                     if(sum < 0)
                         sum += mod;
                     a[i][j] = static_cast<int>(sum);
                 }
         return *this;
     }

     // Modular matrix product, *this = (*this) * rhs, entries in [0, mod).
     // The product is built in a temporary and copied back at the end, so
     // `m *= m` (rhs aliasing *this) is safe. Terms are computed in long long
     // from pre-reduced operands: `mod` is only known at run time, and an int
     // product overflows as soon as entries exceed roughly 46341.
     matrix& operator *= (const matrix& rhs)
     {
         matrix ret;
         memset(ret.a, 0, sizeof(ret.a));
         for(int i = 0; i < n; ++i)
             for(int t = 0; t < n; ++t)
             {
                 const long long lhs = a[i][t] % mod;
                 if(lhs == 0) // zero factor contributes nothing to row i
                     continue;
                 for(int j = 0; j < n; ++j)
                 {
                     long long cell = (ret.a[i][j] + lhs * (rhs.a[t][j] % mod)) % mod;
                     if(cell < 0)
                         cell += mod;
                     ret.a[i][j] = static_cast<int>(cell);
                 }
             }
         memcpy(a, ret.a, sizeof(a));
         return *this;
     }
 };
51 matrix expo(matrix p, int k)
52 {
53     if(k == 1) return p;
54     matrix ret;
55     memset(ret.a, 0, sizeof(ret.a));
56     for(int i = 0; i < n; ++i)
57         ret.a[i][i] = 1;
58     if(k == 0) return ret;
59     while(k)
60     {
61         if(k & 1)
62             ret *= p;
63         p *= p;
64         k >>= 1;
65     }
66     return ret;
67 }
结构体形式模板

 

练手:

1、经典入门题:

 1 /*
 2 Problem: Fibonacci
 3 Tips: 矩阵快速幂
 4 Date: 2015.8.1
 5 */
 6 #include <iostream>
 7 #include <cstdio>
 8 #include <cstring>
 9 using namespace std;
10 typedef long long ll;
const int mod = 10000; // results are reported modulo 10000
const int maxn = 2;    // matrices are 2 x 2

// 2 x 2 matrix of ints.
struct matrix
{
    int a[maxn][maxn];
};

// Returns A * B; an accumulated entry is reduced with % mod whenever it
// reaches mod or more.
matrix mul(matrix A, matrix B)
{
    matrix prod;
    for(int r = 0; r < maxn; ++r)
        for(int c = 0; c < maxn; ++c)
            prod.a[r][c] = 0;

    for(int r = 0; r < maxn; ++r)
    {
        for(int t = 0; t < maxn; ++t)
        {
            const int lhs = A.a[r][t];
            if(lhs == 0) // a zero factor adds nothing to this row
                continue;
            for(int c = 0; c < maxn; ++c)
            {
                int& cell = prod.a[r][c];
                cell += lhs * B.a[t][c];
                if(cell >= mod)
                    cell %= mod;
            }
        }
    }
    return prod;
}
32 matrix expo(matrix p, int k)
33 {
34     if(k == 1) return p;
35     matrix ret;
36     memset(ret.a, 0, sizeof(ret.a));
37     for(int i = 0; i < maxn; ++i)
38     {
39         ret.a[i][i] = 1;
40     }
41     if(k == 0) return ret;
42     while(k)
43     {
44         if(k & 1)
45             ret = mul(p, ret);
46         p = mul(p, p);
47         k >>= 1;
48     }
49     return ret;
50 }
51 int main()
52 {
53     int k;
54     matrix m;
55     while(~scanf("%d", &k))
56     {
57         if(k == -1) break;
58         if(!k) printf("0\n");
59         else if(k == 1) printf("1\n");
60         else
61         {
62             m.a[0][0] = m.a[0][1] = m.a[1][0] = 1;
63             m.a[1][1] = 0;
64 
65             matrix ans = expo(m, k-1);
66             printf("%d\n", ans.a[0][0]%mod);
67         }
68     }
69     return 0;
70 }
POJ 3070

2、学习构造矩阵:

 1 /*
 2 Problem: NYoj 301
 3 Tips:    矩阵快速幂 构造矩阵
 4 递推式: f(x)=a*f(x-2)+b*f(x-1)+c
 5 | f(n-1) f(n-2) 1 |   | b  1  0 |   | f(n) f(n-1) 1 |
 6 |    0      0    0 | * | a  0  0 | = |   0     0    0 |
 7 |    0      0    0 |   | c  0  1 |   |   0     0    0 |
 8 Date: 2015.8.1
 9 */
10 #include <iostream>
11 #include <cstdio>
12 #include <cstring>
13 using namespace std;
typedef long long ll;
const int mod = 1000007; // modulus of the answer
const int maxn = 3;      // matrices are 3 x 3

// 3 x 3 matrix of 64-bit integers.
struct matrix
{
    ll a[maxn][maxn];
};

// Returns the product A * B with every entry reduced into [0, mod).
//
// Both operands of each term are reduced modulo `mod` before multiplying and
// the running sum is reduced after every addition, so intermediate values
// stay far below the 64-bit limit. Negative inputs are handled as well: the
// sign-preserving remainder of C++ '%' is shifted back into [0, mod) instead
// of merely adding `mod` once, which is not enough for values below -mod.
matrix mul(matrix A, matrix B)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for(int i = 0; i < maxn; ++i)
        for(int k = 0; k < maxn; ++k)
        {
            const ll lhs = A.a[i][k] % mod;
            if(lhs == 0) // a factor that is 0 modulo `mod` adds nothing to row i
                continue;
            for(int j = 0; j < maxn; ++j)
            {
                ll cell = (ret.a[i][j] + lhs * (B.a[k][j] % mod)) % mod;
                if(cell < 0)
                    cell += mod;
                ret.a[i][j] = cell;
            }
        }
    return ret;
}
42 matrix expo(matrix p, int k)
43 {
44     if(k == 1) return p;
45     matrix ret;
46     memset(ret.a, 0, sizeof(ret.a));
47     for(int i = 0; i < maxn; ++i)
48     {
49         ret.a[i][i] = 1;
50     }
51     if(k == 0) return ret;
52     while(k)
53     {
54         if(k & 1)
55             ret = mul(p, ret);
56         p = mul(p, p);
57         k >>= 1;
58     }
59     return ret;
60 }
61 
62 int main()
63 {
64     ll f1, f2, a, b, c, n;
65     matrix m1, m2;
66     int T; scanf("%d", &T);
67     while(T--)
68     {
69         scanf("%lld%lld%lld%lld%lld%lld", &f1, &f2, &a, &b, &c, &n);
70         if(n == 1) printf("%lld\n", (f1 + mod) % mod);
71         else if(n == 2) printf("%lld\n", (f2 + mod) % mod);
72         else
73         {
74             memset(m1.a, 0, sizeof(m1.a));
75             memset(m2.a, 0, sizeof(m2.a));
76             m1.a[0][0] = f2, m1.a[0][1] = f1, m1.a[0][2] = 1;
77             m2.a[0][0] = b, m2.a[1][0] = a, m2.a[2][0] = c;
78             m2.a[0][1] = m2.a[2][2] = 1;
79             matrix ans = expo(m2, n-2);
80             ans = mul(m1, ans);
81             if(ans.a[0][0] >= mod)
82             {
83                 ans.a[0][0] %= mod;
84             }
85             else if(ans.a[0][0] < 0)
86             {
87                 ans.a[0][0] += mod;
88             }
89             printf("%lld\n", ans.a[0][0]);
90         }
91     }
92     return 0;
93 }
NYoj 301

3、等比矩阵:构造矩阵(110MS) $S = A + A^2 + A^3 + \dots + A^k = A(I + A(I + A(\dots(I + A))))$;

                 即构造矩阵每次相乘得到A+I.

  1 /*
  2 Tips: 矩阵快速幂 矩阵构造
  3 | A  I |   | A  I |     | A^2  I+A |   | A  I |   | A^3  I+A+(A^2) |
  4 | 0  I | * | 0  I |  =  |  0    I  | * | 0  I | = |  0       I     |
  5 
  6 | A  I |^(k+1)    | A^(k+1)  I+A+(A^2)+...+(A^k) |
  7 | 0  I |       =  |    0              I          |
  8 
  9 Date: 2015.8.1
 10 */
 11 #include <iostream>
 12 #include <cstdio>
 13 #include <cstring>
 14 using namespace std;
 15 typedef long long ll;
 int mod;              // modulus, assigned at run time
 const int maxn = 100; // capacity of a matrix in each dimension

 // Square matrix of ints with room for maxn x maxn entries.
 struct matrix
 {
     int a[maxn][maxn];
 };

 // Returns the element-wise sum of the top-left n x n corners of A and B.
 // A sum that reaches `mod` or more is reduced with % mod; every entry
 // outside the n x n corner of the result is 0.
 matrix add(matrix A, matrix B, int n)
 {
     matrix sum;
     memset(sum.a, 0, sizeof(sum.a));
     for(int r = 0; r < n; ++r)
     {
         for(int c = 0; c < n; ++c)
         {
             int& cell = sum.a[r][c];
             cell = A.a[r][c] + B.a[r][c];
             if(cell >= mod)
                 cell %= mod;
         }
     }
     return sum;
 }
 35 matrix mul(matrix A, matrix B, int n)
 36 {
 37     matrix ret;
 38     memset(ret.a, 0, sizeof(ret.a));
 39     for(int i = 0; i < n; ++i)
 40         for(int k = 0; k < n; ++k)
 41             if(A.a[i][k])
 42                 for(int j = 0; j < n; ++j)
 43                 {
 44                     ret.a[i][j] += A.a[i][k] * B.a[k][j];
 45                     if(ret.a[i][j] >= mod)
 46                         ret.a[i][j] %= mod;
 47                 }
 48     return ret;
 49 }
 50 matrix expo(matrix p, int k, int n)
 51 {
 52     if(k == 1) return p;
 53     matrix ret;
 54     memset(ret.a, 0, sizeof(ret.a));
 55     for(int i = 0; i < n; ++i)
 56         ret.a[i][i] = 1;
 57     if(k == 0) return ret;
 58     while(k)
 59     {
 60         if(k & 1)
 61             ret = mul(p, ret, n);
 62         p = mul(p, p, n);
 63         k >>= 1;
 64     }
 65     return ret;
 66 }
 67 void print(matrix ans, int n)
 68 {
 69     for(int i = 0; i < n; i++)
 70     {
 71         for(int j = 0; j < n; j++)
 72         {
 73             if(j) printf(" ");
 74             printf("%d", ans.a[i][j] % mod);
 75         }
 76         printf("\n");
 77     }
 78 }
 79 int main()
 80 {
 81     int n, k, m;
 82     matrix A, res;
 83     scanf("%d%d%d", &n, &k, &m);
 84     mod = m;
 85     for(int i = 0; i < n; i++)
 86         for(int j = 0; j < n; j++)
 87             scanf("%d", &A.a[i][j]);
 88     if(k == 1)
 89         print(A, n);
 90     else
 91     {
 92         for(int i = 0; i < n; i++)
 93             A.a[i][i+n] = A.a[i+n][i+n] = 1;
 94 
 95         res = expo(A, k+1, 2*n);
 96         for(int i = 0; i < n; i++)
 97         {
 98             for(int j = 0; j < n; j++)
 99             {
100                 if(j) printf(" ");
101                 if(i != j) printf("%d", res.a[i][j+n] % mod);
102                 else printf("%d", (res.a[i][j+n] - 1 + mod) % mod);
103             }
104             printf("\n");
105         }
106     }
107     return 0;
108 }
poj 3233

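此题还有其他解法,记 $h = \lfloor k/2 \rfloor$,可视 $S_k = (A + A^2 + \dots + A^{h}) + (A^{h+1} + A^{h+2} + \dots + A^{2h}) + [A^k]$

            $= (A + A^2 + \dots + A^{h}) + A^{h}(A + A^2 + \dots + A^{h}) + [A^k]$

            $= (I + A^{h})(A + A^2 + \dots + A^{h}) + [A^k]$

            (其中 $[A^k]$ 仅当 k 为奇数时才需要加上),递归求解。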

这种方法效率当然不如上种,自己没有再写,此处贴上别人家的代码

 1     //800MS  
 2     #include <iostream>  
 3     #include <cstdio>  
 4     #include <cstring>  
 5     using namespace std;  
 // m: modulus, n: matrix side length, K: number of terms in the series.
 int m,n,K;
 int a[30][30];

 // n x n integer matrix (n <= 30) with arithmetic modulo m.
 class Matrix
 {
 public:
     int num[30][30];

     // Zero-fills the storage; when `is` is true the n x n identity matrix
     // is produced instead of the zero matrix.
     Matrix(bool is=true)
     {
         memset(num,0,sizeof(num));
         if(is)
             for(int i=0;i<n;i++)
                 num[i][i]=1;
     }

     // Prints the n x n matrix, one row per line, entries separated by spaces.
     void print()
     {
         for(int i=0;i<n;++i)
         {
             printf("%d",num[i][0]);
             for(int j=1;j<n;++j)
                 printf(" %d",num[i][j]);
             printf("\n");
         }
     }

     // Matrix product modulo m.
     // The result is returned BY VALUE: it is built in a local object, and
     // returning a reference to that object would leave the caller with a
     // dangling reference.
     friend Matrix operator *(const Matrix& max1,const Matrix& max2)
     {
         Matrix tmp(false);          // start from the zero matrix, not from I
         for(int i=0;i<n;++i)
             for(int j=0;j<n;++j)
             {
                 for(int k=0;k<n;++k)
                     tmp.num[i][j]+=(max1.num[i][k]*max2.num[k][j])%m;
                 tmp.num[i][j]%=m;
             }
         return tmp;
     }

     // Element-wise addition modulo m.
     Matrix& operator +=(const Matrix& max)
     {
         for(int i=0;i<n;++i)
             for(int j=0;j<n;++j)
                 num[i][j]=(num[i][j]+max.num[i][j])%m;

         return *this;
     }
 }ans;
52     Matrix mul(Matrix A,int k)      //求A^K  
53     {  
54         if(k==1)  
55             return A;  
56         Matrix tmp ;  
57         while(k)  
58         {  
59             if(k&1)  
60                 tmp = tmp * A;  
61             k>>=1;  
62             A = A*A;  
63         }  
64         return tmp;  
65     }  
66     Matrix S(Matrix A,int k)        //求 S[k]  
67     {  
68         if(k==1)  
69             return A;  
70           
71         Matrix tmp ;  
72         tmp += mul(A,k>>1);           //求 (I + A^(k/2) )  
73         tmp = tmp*S(A,k>>1);      //求 (I + A^(k/2) )*S[k/2]  
74         if(k&1)                     //判断时候要加上 A^k  
75             tmp+= mul(A,k);         //S[k] = (I+A^(k/2)) * S[k/2] + {A^k}  
76         return tmp;  
77     }   
78       
79     int main()  
80     {  
81         int i,j,k;  
82         scanf("%d %d %d",&n,&K,&m);  
83         for( i=0;i<n;++i)  
84             for( j=0;j<n;++j)  
85                 scanf("%d",&ans.num[i][j]);  
86         ans = S(ans,K);  
87         ans.print();  
88       
89         return 0;  
90     }  
别人家的poj 3233

4、被这题坑大半天简直是想咬舌。

 1 /*
 2 Problem: UVa 10006
 3 Tips:    快速幂练手题
 4 Date:    2015.8.2
 5 TLE原因:理解能力简直是作死。
 6 题意:   合数+对于任意的(!!全称量词啊不是存在T T)a满足 (a^n)%n == a
 7 */
 8 
 9 #include <iostream>
10 #include <cstdio>
11 #include <cmath>
12 #include <cstring>
13 using namespace std;
14 typedef long long ll;
15 //const int mod = 1000007;
const int maxn = 65100; // size of the sieve table
int n;                  // number currently being examined
bool pri[maxn];         // pri[v] stays true when v was not crossed out

// Sieve of Eratosthenes over [0, maxn): crosses out every multiple of each
// surviving p, starting at p*p. Indices 0 and 1 are never crossed out and
// therefore remain true.
void get_pri()
{
    for(int v = 0; v < maxn; ++v)
        pri[v] = true;

    // p*p <= maxn selects the same candidates as p <= floor(sqrt(maxn + 0.5)).
    for(int p = 2; p * p <= maxn; ++p)
    {
        if(!pri[p])
            continue;
        for(int q = p * p; q < maxn; q += p)
            pri[q] = false;
    }
}
// Returns x^k modulo mod by binary exponentiation. Products are formed in
// 64-bit arithmetic before the reduction.
int expo(int x, int k, int mod)
{
    // Exponents 0 and 1 are answered directly.
    switch(k)
    {
        case 0: return 1%mod;
        case 1: return x%mod;
        default: break;
    }

    long long acc = 1;
    for(; k != 0; k >>= 1)
    {
        if(k & 1)
            acc = (acc*x)%mod;              // current bit set: fold the base in
        x = ((long long)x*x)%mod;           // square the base for the next bit
    }
    return acc;
}
41 
42 int main()
43 {
44     get_pri();
45     while(~scanf("%d", &n) && n)
46     {
47         if(pri[n])
48         {
49             printf("%d is normal.\n", n);
50             continue ;
51         }
52         bool flag = true;
53         for(int a = 2; a < n; a++)
54             if(expo(a, n, n) != a)
55             {
56                 flag = false;
57                 break;
58             }
59 
60         if(flag == true) printf("The number %d is a Carmichael number.\n", n);
61         else printf("%d is normal.\n", n);
62     }
63     return 0;
64 }
被坑的快速幂练手题