Codeforces 1511 F. Chainword 题解
Codeforces 1511 F. Chainword
题意
给出\(n\)个模式串(长度\(<= 5\))
定义\(chainword\)为满足下三个条件的字符串和一对划分
- 长度为\(m\)
- 第一个划分满足每段都是模式串
- 第二个划分满足每段都是模式串
要求\(chainword\)的个数。
\(n <= 8,m <= 10^9\)
题解
我先是想到了一个不太可做的\(dp\),记录后四位字符然后每次转移一个模式串,合法的后四位字符并不多,便考虑矩阵优化,但是问题在于,由于转移的是模式串,\(chainword\)的长度每次不是增加\(1\),不能用矩阵转移。所以我们要考虑设计一个可以每次转移一个字符,状态数不多,且能保证最后答案是一整个一整个模式串转移来的状态。
题解给出了一个巧妙的状态,设\(dp_{i,u,v}\)表示\(chainword\)长度为\(i\),第一个划分走到了\(u\),第二个走到了\(v\),其中\(u,v\)是模式串字典树上的节点。每次枚举字符转移,最终答案是\(dp_{m,rt,rt}\),\(rt\)是字典树根节点。\(u,v\)到根形成的字符串,一定有一个是另一个的后缀,这样的\(u,v\)是不多的。关于个数的计算,我们考虑对反串建字典树,\(u,v\)就是树上的一对祖先和后代,及每个节点的深度之和。不失一般性地,我们假设\(u <= v\)可以算出这样的\(u,v\)只有\(161\)对。求出转移矩阵后快速幂求解即可。
#include<bits/stdc++.h>
#define ll long long
#define N
#define mod 998244353
#define rep(i,a,n) for (int i=a;i<=n;i++)
#define per(i,a,n) for (int i=n;i>=a;i--)
#define inf 0x3f3f3f3f
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define fi first
#define se second
#define lowbit(i) ((i)&(-i))
#define VI vector<int>
#define all(x) x.begin(),x.end()
#define SZ(x) ((int)x.size())
#define end qwq
using namespace std;
int n,m;
namespace Trie{
int tr[405][26],rt = 0,end[405],cnt;
void insert(char *s){
int cur = rt;
for(int i = 1;s[i];++i){
if(!tr[cur][s[i]-'a']) tr[cur][s[i]-'a'] = ++cnt;
cur = tr[cur][s[i]-'a'];
}
end[cur] = 1;
}
}
using namespace Trie;
struct matrix{
int a[205][205];
matrix(){memset(a,0,sizeof a);}
int* operator[](int i){return a[i];}
matrix operator*(matrix lhs) const{
matrix res;
rep(i,0,200) rep(j,0,200) rep(k,0,200){
res[i][k] = (res[i][k] + 1ll*a[i][j]*lhs[j][k])%mod;
}
return res;
}
}base;
queue<pii> Q;
map<pii,int> id;
int tot;
int get(pii x){
if(x.fi > x.se) swap(x.fi,x.se);
if(id.count(x) > 0) return id[x];
else{
id[x] = tot;
Q.push(x);
return tot++;
}
}
matrix qpow(matrix a,int b){
matrix res;
rep(i,0,200) res[i][i] = 1;
while(b){
if(b&1) res = res*a;
a = a*a; b >>= 1;
}
return res;
}
int main(){
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%d%d",&n,&m);
rep(i,1,n){
char s[105];
scanf("%s",s+1);
insert(s);
}
// printf("%d\n",cnt);
get(mp(0,0));
while(!Q.empty()){
pii u = Q.front(); Q.pop();
int x = u.fi,y = u.se,cid = get(mp(x,y));
// printf("%d %d %d\n",x,y,cid);
rep(i,0,25){
int tx = tr[x][i],ty = tr[y][i];
if(!tx || !ty) continue;
base[cid][get(mp(tx,ty))]++;
if(end[tx]) base[cid][get(mp(0,ty))]++;
if(end[ty]) base[cid][get(mp(0,tx))]++;
if(end[tx] && end[ty]) base[cid][get(mp(0,0))]++;
}
}
base = qpow(base,m);
printf("%d\n",base[0][0]);
return 0;
}