关于矩阵快速幂的若干优化

首先,我们复习一下矩阵乘法。

我们记3个矩阵A(a行b列),B(b行c列),C(a行c列)。我们要计算A*B,并把答案存到矩阵C中。

C[i][j]+=A[i][k]*B[k][j](1<=i<=a,1<=j<=c,1<=k<=b),即新矩阵的第i行第j个元素是原1矩阵的第i行*原2矩阵的第j列得来的。

一般来说,我们的计算方法是for(int i=1;i<=a;i++)for(int j=1;j<=c;j++)for(int k=1;k<=b;k++)C[i][j]+=A[i][k]*B[k][j];

其次,让我们复习一下快速幂。

举个例子吧,计算a^101。

我们知道,于是:

a^101=a^(1*2^6)*a^(1*2^5)*a^(0*2^4)*a^(1*2^3)*a^(1*2^2)*a^(0*2^1)*a^(1*2^0)。

我们把101转成2进制:1101101。每个2^x前的系数就是二进制第x位的数。

a^ (2^x)=a^(2^(x-1))^2。我们可以通过a^(2^(x-1))来求得a^(2^x)。

这样,对于二进制下的第x位,该位如果为1,就把ans*=a(更新答案,初始化为1)。然后每次a*=a(用a^(2^x)更新出a^(2^(x+1)),准备处理下一位)。

我们便可以在O(logp)(p为指数)的时间复杂度内出解。

最后,让我们来复习一下矩阵快速幂。

我们要求A^B^B^B^B^B^B^B......(A,B为矩阵),即A^(B^p)的值。

就像ans初值=1一样,记一个单位矩阵(主对角线为1)Ans,结合上面两种做法,我们就可以求出A^(B^p)的值。

(1)对于稀疏矩阵的优化

稀疏矩阵,即为矩阵中有很多元素为0。

优化方法:改变循环顺序。改为for(int i=1;i<=a;i++)for(int k=1;k<=b;k++)for(int j=1;j<=c;j++)C[i][j]+=A[i][k]*B[k][j];

这样有什么好处呢?

我们可以发现,只要A[i][k]==0,那么对答案矩阵(C)不会有任何贡献。

所以我们可以进行优化,在第二个循环到第三个循环直接加一个if,若A[i][k]!=0,才进入第三个循环。

for(int i=1;i<=a;i++)for(int k=1;k<=b;k++)if(A[i][k])for(int j=1;j<=c;j++)C[i][j]+=A[i][k]*B[k][j];

题目:POJ 3735 Training little cats。

(2)预处理优化矩阵快速幂

主要针对多组数据。求A*B^k。给出A,B,T个询问k

在通常情况下,A是一个n行1列的矩阵,B是一个n行n列的矩阵。这样,我们的矩阵快速幂(求A^(B^k))的复杂度就是O((n^3logk+Tn^2logk))。

具体来说,我们先用O(n^3log(maxk))预处理出B^(2^p),再A*B^k=A*B^(2^a1)+A*B^(2^a2)+...算答案。复杂度O(Tn^2logk)

(3)优化快速幂过程

主要针对多组数据。

正常的快速幂的当次复杂度为O(log2(n))。看到那个2了吗,我们的工作就是要把这个2变大。

考虑一般的快速幂,一般的快速幂是以2进制为基础的,我们考虑用3进制为基础会怎么样。

对于每一个3进制位,如果该位是0,ans*=x^0,如果该位是1,ans*=x^1,如果该位是2,ans*=x^2

与2进制快速幂同理,每次x=x^3,p=p/3

所以复杂度是O(klogk(n)),k为进制

但是虽然这个2变大了,复杂度却一点也没变小

但是这并不能阻挡我们优化的决心,如果每次询问的底数都相同,我们是能优化的

预处理mi[a][b]=(x^(k^a))^b即可,每次ans*=mi[a][b],a是当前做到第几位,b是当前这位的数

mi[a][1]=mi[a-1][k-1]*mi[a-1][1]

mi[a][b]=mi[a][b-1]*mi[a][1]

这样复杂度变为(klogk(n)+logk(n))。

(4)常数优化

ikj循环,循环展开 for(int i = 1; i <= n; i++) for(int k = 1; k <= n; k++) for(int j = 1; j <= n; j++) c[i][j] += a[i][k] * b[k][j];

这样能保证b数组的内存访问是连续的

 

拥有上面全部优化的模板题:https://www.luogu.org/problemnew/show/P5107

#include <cstdio>
#include <cstring>
#define mod 998244353
#define T 256
#include <algorithm>

struct xxx{
    int a[52][52];
};
struct xx{
    int a[52];
};
struct QQ{
    int x, id;
}q[50100];
int n, d[55];
xxx mi[4][T + 1];
xx ans;
long long Ans[50100];

bool cmp(QQ a, QQ b) {return a.x < b.x;}
 
xxx operator * (xxx a, xxx b)
{
    xxx c; memset(c.a, 0, sizeof(c.a));
    for(int i = 1; i <= n; i++)
        for(int k = 1; k <= n; k++)
            if(a.a[i][k])
            for(int j = 1; j <= n; j++)
                c.a[i][j] = (c.a[i][j] + 1ll * a.a[i][k] * b.a[k][j]) % mod;
    return c;
}

xx operator * (xx a, xxx b) 
{
    xx c; memset(c.a, 0, sizeof(c.a));
    for(int j = 1; j <= n; j++)
        for(int i = 1; i <= n; i++)
            c.a[j] = (c.a[j] + 1ll * a.a[i] * b.a[i][j]) % mod;
    return c;
}

int qpow(int x, int p)
{
    int ans = 1;
    while(p)
    {
        if(p & 1) ans = 1ll * ans * x % mod;
        x = 1ll * x * x % mod; p >>= 1;
    }
    return ans;
}

xx operator ^ (xx a, int p)
{
    int j = 0;
    while(p)
    {
        ans = ans * mi[j][p & 255];
        j++; p >>= 8;
    }
    return ans;
}

int main()
{
    int m, Q; scanf("%d%d%d", &n, &m, &Q);
    for(int i = 1; i <= n; i++) scanf("%d", &ans.a[i]), mi[0][1].a[i][i] = 1, d[i] = 1;
    for(int i = 1; i <= m; i++)
    {
        int u, v; scanf("%d%d", &u, &v);
        mi[0][1].a[u][v]++; d[u]++;
    }
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            mi[0][1].a[i][j] = 1ll * mi[0][1].a[i][j] * qpow(d[i], mod - 2) % mod;
    for(int i = 0; i <= 3; i++)
    {
        for(int j = 0; j < T; j++)
        {
            if(i == 0 && j == 1) continue;
            if(j == 0) for(int k = 1; k <= n; k++) mi[i][j].a[k][k] = 1;
            else if(j == 1) mi[i][j] = mi[i - 1][T - 1] * mi[i - 1][1];
            else mi[i][j] = mi[i][j - 1] * mi[i][1];
        }
    }
    for(int i = 1; i <= Q; i++)
    {
        scanf("%d", &q[i].x);
        q[i].id = i;
    }
    std::sort(q + 1, q + Q + 1, cmp);
    for(int i = 1; i <= Q; i++)
    {
        ans = ans ^ (q[i].x - q[i - 1].x);
        for(int j = 1; j <= n; j++) Ans[q[i].id] = Ans[q[i].id] ^ ans.a[j];
        Ans[q[i].id] %= mod;
    }
    for(int i = 1; i <= Q; i++) printf("%lld\n", Ans[i]);
}

 

posted @ 2017-12-11 20:44  lher  阅读(1336)  评论(0编辑  收藏  举报