矩阵乘法优化 DP
矩阵乘法
我们有一个 DP 式,它的转移系数相对固定,不受 DP 值的变化而变化,可以递推
且它通常有一或两维状态,可以分为很多阶段,但每个阶段中状态数不多
现在我们要递推很多次,是线性复杂度接受不了的
便把 DP 的式子写为矩阵的形式,一般在 $O(w^3\log n)$ 复杂度内计算($w$ 为矩阵的大小)
它不仅应用于优化递推,还可以做到求某个区间的 递推/DP 值,此时再用类似线段树的数据结构维护每个区间的矩阵乘积
例题
1. P4007 小 Y 和恐怖的奴隶主
这个题,看 $m,k$ 小且不变,考虑加入状态中
以 $m=3$ 为例,剩下的把状态减小即可
设 $f(i,a,b,c)$ 表示当前攻击了 $i$ 次,此时随从中生命值为 $1$ 的有 $a$ 个,为 $2$ 的有 $b$ 个,为 $3$ 的有 $c$ 个
此时通过枚举 $(a,b,c)$ 发现合法的三元组不超过 170 个
但是 $n$ 很大,想到矩阵乘法
初始和最终求出的为行向量,初始局面对应三元组概率为 $1$
我们列出 $170\times 170$ 的矩阵,每次枚举两个三元组编号为 $i,j$,判断它们能否一步到达,分类讨论是攻击的谁,矩阵中 $a_{i,j}$ 乘以相应概率
但要求的是期望,只需把矩阵加一行一列,行向量加一格代表期望,当不攻击随从,即 $i=j$ 时贡献期望,为 $1\times P[\text{攻击boss}]$
但是多组询问中,每次都矩阵快速幂,$O(170^3T\log n)$ 显然过不去
用到一个技巧:
发现转移矩阵不变,且注意到转移矩阵乘行/列向量复杂度为 $O(w^2)$
考虑把所有转移矩阵的 $2^0,2^1\dots,2^{\log n}$ 预处理,矩阵乘法有结合律,把行向量与需要乘的矩阵相乘
复杂度 $O((w^3+Tw^2)\log n),w=170$,卡个常就过了
代码:
#include<bits/stdc++.h>
#define reg register
using namespace std;
typedef long long ll;
const ll mod = 998244353;
struct matrix
{
ll a[170][170];
inline matrix() {memset(a, 0, sizeof(a));}
}f, I, kong, b, pw[65], ans;
struct group
{
ll p, q, r;
}lin[170];
inline ll read()
{
reg char ch = getchar(); reg ll x = 0;
while(ch < '0' || ch > '9') ch = getchar();
while(ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
return x;
}
inline void print(ll x)
{
if(x / 10) print(x / 10);
putchar(x % 10 + '0');
}
ll t, n, m, k, cnt, targ, pr, invf[110];
inline ll qmi(ll a, ll b)
{
reg ll res = 1;
while(b)
{
if(b & 1) res = res * a % mod;
b >>= 1, a = a * a % mod;
}
return res;
}
inline ll inv(ll x) {return qmi(x, mod - 2);}
inline ll add(ll a, ll b) {return (a + b >= mod) ? (a + b - mod) : (a + b);}
inline matrix mul(matrix x, matrix y)
{
matrix z;
for(reg ll i = 1; i <= cnt; ++i)
for(reg ll u = 1; u <= cnt; ++u)
if(x.a[i][u])
for(reg ll j = 1; j <= cnt; ++j)
z.a[i][j] = add(z.a[i][j], x.a[i][u] * y.a[u][j] % mod);
return z;
}
inline matrix hmul(matrix x, matrix y)
{
matrix res;
for(reg ll i = 1; i <= cnt; ++i)
for(reg ll j = 1; j <= cnt; ++j)
res.a[1][i] = add(res.a[1][i], x.a[1][j] * y.a[j][i] % mod);
return res;
}
inline void work1()
{
for(reg ll i = 1; i < cnt; ++i)
for(reg ll j = 1; j < cnt; ++j)
{
if(lin[i].p == lin[j].p + 1)
f.a[i][j] = invf[lin[i].p + 1] * lin[i].p % mod;
else if(i == j) f.a[i][i] = f.a[i][cnt] = invf[lin[i].p + 1] % mod;
}
}
inline void work2()
{
for(reg ll i = 1; i < cnt; ++i)
for(reg ll j = 1; j < cnt; ++j)
{
if(lin[i].p == lin[j].p + 1 && lin[i].q == lin[j].q)
f.a[i][j] = invf[lin[i].p + lin[i].q + 1] * lin[i].p % mod;
else if(lin[i].p + lin[i].q < k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q)
f.a[i][j] = invf[lin[i].p + lin[i].q + 1] * lin[i].q % mod;
else if(lin[i].p + lin[i].q == k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q + 1)
f.a[i][j] = invf[lin[i].p + lin[i].q + 1] * lin[i].q % mod;
else if(i == j) f.a[i][i] = f.a[i][cnt] = invf[lin[i].p + lin[i].q + 1] % mod;
}
}
inline void work3()
{
for(reg ll i = 1; i < cnt; ++i)
for(reg ll j = 1; j < cnt; ++j)
{ // j 是下一步的可能状态,i 是上一步,a[i][j] 为 i->j 的概率
if(lin[i].p == lin[j].p + 1 && lin[i].q == lin[j].q && lin[i].r == lin[j].r) // 打生命值为 1
f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].p % mod;
else if(lin[i].p + lin[i].q + lin[i].r < k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q + 1 && lin[i].r == lin[j].r - 1)
f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].q % mod; // 打生命值为 2
else if(lin[i].p + lin[i].q + lin[i].r == k && lin[i].p == lin[j].p - 1 && lin[i].q == lin[j].q + 1 && lin[i].r == lin[j].r)
f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].q % mod;
else if(lin[i].p + lin[i].q + lin[i].r < k && lin[i].p == lin[j].p && lin[i].q == lin[j].q - 1 && lin[i].r == lin[j].r)
f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].r % mod; // 打生命值为 3
else if(lin[i].p + lin[i].q + lin[i].r == k && lin[i].p == lin[j].p && lin[i].q == lin[j].q - 1 && lin[i].r == lin[j].r + 1)
f.a[i][j] = invf[lin[i].p + lin[i].q + lin[i].r + 1] * lin[i].r % mod;
else if(i == j) // 打 boss
f.a[i][i] = f.a[i][cnt] = invf[lin[i].p + lin[i].q + lin[i].r + 1];
}
}
int main()
{
t = read(), m = read(), k = read();
for(reg ll x = 0; x <= 8; ++x)
for(reg ll y = 0; y <= 8; ++y)
for(reg ll z = 0; z <= 8; ++z)
if(x + y + z <= k && (m > 2 || !z) && (m > 1 || (!y && !z)))
{
++cnt, lin[cnt] = (group){x, y, z};
if(m == 1 && x == 1) targ = cnt;
else if(m == 2 && x == 0 && y == 1) targ = cnt;
else if(m == 3 && x == 0 && y == 0 && z == 1) targ = cnt;
}
// printf("%lld", cnt); 最多共 165 种合法状态
++cnt;
for(reg int i = 1; i <= 100; ++i) invf[i] = qmi(i, mod - 2);
for(reg ll i = 1; i <= cnt; ++i) I.a[i][i] = 1;
b.a[1][targ] = f.a[cnt][cnt] = 1;
if(m == 1) work1();
else if(m == 2) work2();
else work3();
pw[0] = f;
for(reg ll i = 1; i <= 61; ++i) pw[i] = mul(pw[i - 1], pw[i - 1]); // 预处理矩阵的 2^i 次方
while(t--)
{
n = read(), ans = b;
// printf("%lld \n", f.a[cnt][2]);
for(reg ll i = 0; i <= 61; ++i)
if((n >> i) & 1) ans = hmul(ans, pw[i]); // 行向量乘矩阵, O(166^2)(不会 T)
print(ans.a[1][cnt] % mod), putchar('\n'); // 最后一格表示期望
// printf("%lld %lld\n", ans.a[][], ans.a[][]);
}
return 0;
}