Codeforces 1511F Chainword
如何判定一个字符串 $s$ 能否被拆为若干段字典中的单词呢?
考虑对字典建立 trie 树,当在 $s$ 的末尾增加一个字符的时候,就相当于是在 trie 树上从一个结点 $u$ 沿着某一条转移边走到另外一个结点 $v$。当然,你也可以在某个字符串的结束位置不继续往下走,而是回到根节点。
假如只有一道提示线,那么根据套路,可以用 $f_{i, u}$ 表示 $s$ 加入了 $i$ 个字符,当前在 trie 的结点 $u$ 上,这种情况下的方案数。转移是显然的。
但现在有两道提示线,仿照这样的思路,使用 $f_{i, u, v}$ 表示 $s$ 加入了 $i$ 个字符,上方的线在 trie 的结点 $u$ 上,下方的线在 trie 的结点 $v$ 上,这种情况下的方案数。
很容易发现这个 DP 与 $i$ 其实没什么关系,可以使用矩阵快速幂优化。但是总状态数是至多约 $1600$ 的。很显然,矩阵快速幂并不能跑得动。
现在考虑删减有效状态:
- 记 $root \to u$ 的路径的字符串为 $S$,$root \to v$ 的路径的字符串为 $T$,显然有 $S$ 是 $T$ 的后缀,或者 $T$ 是 $S$ 的后缀。
- $f_{i, u, v} = f_{i, v, u}$,所以 $(u, v)$ 和 $(v, u)$ 可以视作同一个状态。
加入这两个限制后,状态数就被删减到了约 $l = 160$,这时候套用矩阵快速幂优化即可。
一个实现上的细节是,如何找到满足第一种要求的状态呢?其实从 $(0,0)$ 开始 BFS / DFS,每次转移的时候枚举最后一个字符,所有能访问到的结点就是我们想要的结点。
时间复杂度 $O(l^3 \log m)$。
#include <bits/stdc++.h> using namespace std; const int N = 50, M = 161, MOD = 998244353; struct matrix { int val[M][M]; matrix() { memset(val, 0, sizeof(val)); } matrix operator * (const matrix &oth) const { matrix res; for(int i = 0; i < M; i++) for(int j = 0; j < M; j++) for(int k = 0; k < M; k++) res.val[i][j] = (res.val[i][j] + 1LL * val[i][k] * oth.val[k][j] % MOD) % MOD; return res; } } trans; matrix Qpow(matrix a, long long power) { matrix res; for(int i = 0; i < M; i++) res.val[i][i] = 1; while(power) { if(power & 1) res = res * a; a = a * a; power >>= 1; } return res; } struct trie { int siz; struct node { bool ext; int ch[26]; node() { ext = false; memset(ch, -1, sizeof(ch)); } } o[N]; trie() { siz = 0; } void Insert(char *str) { int cur = 0; for(int i = 1; str[i]; i++) { if(o[cur].ch[str[i] - 'a'] == -1) o[cur].ch[str[i] - 'a'] = ++siz; cur = o[cur].ch[str[i] - 'a']; } o[cur].ext = true; } } tr; int n; long long m; char s[N]; queue<pair<int, int>> que; map<pair<int, int>, int> idx; int GetID(pair<int, int> p) { if(p.first > p.second) swap(p.first, p.second); if(!idx.count(p)) { int k = idx.size(); idx[p] = k; que.push(p); } return idx[p]; } int main() { ios::sync_with_stdio(false); cin >> n >> m; for(int i = 1; i <= n; i++) { cin >> (s + 1); tr.Insert(s); } que.push(make_pair(0, 0)); idx[make_pair(0, 0)] = 0; while(!que.empty()) { pair<int, int> cur = que.front(); que.pop(); int x = GetID(cur); for(int c = 0; c < 26; c++) { pair<int, int> nxt = make_pair(tr.o[cur.first].ch[c], tr.o[cur.second].ch[c]); if(nxt.first == -1 || nxt.second == -1) continue; trans.val[x][GetID(nxt)]++; if(tr.o[nxt.first].ext) trans.val[x][GetID(make_pair(0, nxt.second))]++; if(tr.o[nxt.second].ext) trans.val[x][GetID(make_pair(nxt.first, 0))]++; if(tr.o[nxt.first].ext && tr.o[nxt.second].ext) trans.val[x][0]++; } } cout << Qpow(trans, m).val[0][0] << endl; return 0; }