矩阵乘法优化 DP

矩阵乘法

我们有一个 DP 式,它的转移系数相对固定,不受 DP 值的变化而变化,可以递推

且它通常有一或两维状态,可以分为很多阶段,但每个阶段中状态数不多

现在我们要递推很多次,是线性复杂度接受不了的

便把 DP 的式子写为矩阵的形式,一般在 $O(w^3\log n)$ 复杂度内计算($w$ 为矩阵的大小)

它不仅应用于优化递推,还可以做到求某个区间的 递推/DP 值,此时再用类似线段树的数据结构维护每个区间的矩阵乘积

例题

1. P4007 小 Y 和恐怖的奴隶主

这个题,看 $m,k$ 小且不变,考虑加入状态中

以 $m=3$ 为例,剩下的把状态减小即可

设 $f(i,a,b,c)$ 表示当前攻击了 $i$ 次,此时随从中生命值为 $1$ 的有 $a$ 个,为 $2$ 的有 $b$ 个,为 $3$ 的有 $c$ 个

此时通过枚举 $(a,b,c)$ 发现合法的三元组不超过 170 个

但是 $n$ 很大,想到矩阵乘法

初始和最终求出的为行向量,初始局面对应三元组概率为 $1$

我们列出 $170\times 170$ 的矩阵,每次枚举两个三元组编号为 $i,j$,判断它们能否一步到达,分类讨论是攻击的谁,矩阵中 $a_{i,j}$ 乘以相应概率

但要求的是期望,只需把矩阵加一行一列,行向量加一格代表期望,当不攻击随从,即 $i=j$ 时贡献期望,为 $1\times P[\text{攻击boss}]$

但是多组询问中,每次都矩阵快速幂,$O(170^3T\log n)$ 显然过不去

用到一个技巧:

发现转移矩阵不变,且注意到转移矩阵乘行/列向量复杂度为 $O(w^2)$

考虑把所有转移矩阵的 $2^0,2^1\dots,2^{\log n}$ 预处理,矩阵乘法有结合律,把行向量与需要乘的矩阵相乘

复杂度 $O((w^3+Tw^2)\log n),w=170$,卡个常就过了

代码:

#include<bits/stdc++.h>
#define reg register
using namespace std;

typedef long long ll;
const ll mod = 998244353;
struct matrix
{
    ll a[170][170];
    inline matrix() {memset(a, 0, sizeof(a));}
}f, I, kong, b, pw[65], ans;
struct group
{
    ll p, q, r;
}lin[170];
inline ll read()
{
    reg char ch = getchar();    reg ll x = 0;
    while(ch < '0' || ch > '9') ch = getchar();
    while(ch >= '0' && ch <= '9')   x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
    return x;
}
inline void print(ll x)
{
    if(x / 10)  print(x / 10);
    putchar(x % 10 + '0');
}
ll t, n, m, k, cnt, targ, pr, invf[110];
inline ll qmi(ll a, ll b)
{
    reg ll res = 1;
    while(b)
    {
        if(b & 1)   res = res * a % mod;
        b >>= 1, a = a * a % mod; 
    }
    return res;
}
inline ll inv(ll x) {return qmi(x, mod - 2);}
inline ll add(ll a, ll b)   {return (a + b >= mod) ? (a + b - mod) : (a + b);} 
inline matrix mul(matrix x, matrix y)
{
    matrix z;
    for(reg ll i = 1; i <= cnt; ++i)
        for(reg ll u = 1; u <= cnt; ++u)
            if(x.a[i][u])   
                for(reg ll j = 1; j <= cnt; ++j)    
                    z.a[i][j] = add(z.a[i][j], x.a[i][u] * y.a[u][j] % mod);
    return z;
}
inline matrix hmul(matrix x, matrix y)
{
    matrix res;
    for(reg ll i = 1; i <= cnt; ++i)
        for(reg ll j = 1; j <= cnt; ++j)    
            res.a[1][i] = add(res.a[1][i], x.a[1][j] * y.a[j][i] % mod);
    return res;
}
inline void work1()
{
    for(reg ll i = 1; i < cnt; ++i)
        for(reg ll j = 1; j < cnt; ++j)
        {
            if(lin[i].p == lin[j].p + 1)
                f.a[i][j] = invf[lin[i].p + 1] * lin[i].p % mod;
            else if(i == j) f.a[i][i] = f.a[i][cnt] = invf[lin[i].p + 1] % mod;
        }
}
inline void work2()
{
    for(reg ll i = 1; i < cnt; ++i)
        for(reg ll j = 1; j < cnt; ++j)
        {
            if(lin[i].p == lin[j].p + 1 && lin[i].q == lin[j].q)
                f.a[i][j] = invf[lin[i].p + lin[i].q + 1] * lin[i].p % mod;
            else if(lin[i].p + lin[i].q < k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q)
                f.a[i][j] = invf[lin[i].p + lin[i].q + 1] * lin[i].q % mod; 
            else if(lin[i].p + lin[i].q == k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q + 1)
                f.a[i][j] = invf[lin[i].p + lin[i].q + 1] * lin[i].q % mod;
            else if(i == j) f.a[i][i] = f.a[i][cnt] = invf[lin[i].p + lin[i].q + 1] % mod;
        }
}
inline void work3()
{
    for(reg ll i = 1; i < cnt; ++i)
        for(reg ll j = 1; j < cnt; ++j)
        { // j 是下一步的可能状态,i 是上一步,a[i][j] 为 i->j 的概率 
            if(lin[i].p == lin[j].p + 1 && lin[i].q == lin[j].q && lin[i].r == lin[j].r) // 打生命值为 1 
                f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].p % mod;
            else if(lin[i].p + lin[i].q + lin[i].r < k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q + 1 && lin[i].r == lin[j].r - 1)
                f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].q % mod;  // 打生命值为 2 
            else if(lin[i].p + lin[i].q + lin[i].r == k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q + 1 && lin[i].r == lin[j].r)
                f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].q % mod;
            else if(lin[i].p + lin[i].q + lin[i].r < k && lin[i].p == lin[j].p && lin[i].q == lin[j].q - 1 && lin[i].r == lin[j].r)
                f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].r % mod;  // 打生命值为 3
            else if(lin[i].p + lin[i].q + lin[i].r == k && lin[i].p == lin[j].p && lin[i].q == lin[j].q - 1 && lin[i].r == lin[j].r + 1)
                f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].r % mod; 
            else if(i == j) // 打 boss 
                f.a[i][i] = f.a[i][cnt] = invf[lin[i].p + lin[i].q + lin[i].r + 1];
        }
}
int main()
{
    t = read(), m = read(), k = read();
    for(reg ll x = 0; x <= 8; ++x)
        for(reg ll y = 0; y <= 8; ++y)
            for(reg ll z = 0; z <= 8; ++z)
                if(x + y + z <= k && (m > 2 || !z) && (m > 1 || (!y && !z)))    
                {
                    ++cnt, lin[cnt] = (group){x, y, z};
                    if(m == 1 && x == 1)    targ = cnt;
                    else if(m == 2 && x == 0 && y == 1) targ = cnt;
                    else if(m == 3 && x == 0 && y == 0 && z == 1)   targ = cnt;
                }
//  printf("%lld", cnt); 最多共 165 种合法状态
    ++cnt;
    for(reg int i = 1; i <= 100; ++i)   invf[i] = qmi(i, mod - 2);
    for(reg ll i = 1; i <= cnt; ++i)    I.a[i][i] = 1;
    b.a[1][targ] = f.a[cnt][cnt] = 1;
    if(m == 1)  work1();
    else if(m == 2) work2();
    else    work3();
    pw[0] = f;
    for(reg ll i = 1; i <= 61; ++i) pw[i] = mul(pw[i - 1], pw[i - 1]); // 预处理矩阵的 2^i 次方 
    while(t--)
    {
        n = read(), ans = b;
//      printf("%lld \n", f.a[cnt][2]);
        for(reg ll i = 0; i <= 61; ++i)
            if((n >> i) & 1)    ans = hmul(ans, pw[i]); // 行向量乘矩阵, O(166^2)(不会 T) 
        print(ans.a[1][cnt] % mod), putchar('\n'); // 最后一格表示期望 
//      printf("%lld %lld\n", ans.a[][], ans.a[][]);
    }
    return 0;
}
posted @ 2023-03-19 23:52  KellyWLJ  阅读(7)  评论(0编辑  收藏  举报  来源