#4730. 匹配
题目描述
树大小为 $n$, 第 $i$ 边有字符集 $S_i$. 给定 $m$ 个模式串 $t_1,t_2,\dots,t_m$。
$Q$ 次询问 $(u,v)$, 设 $u\to v$ 经过的边为 $e_1,e_2,\dots,e_k$,求串 $s$ 的方案数,满足:
- $|s|=k$
- $\forall i\in[1,k],s_i\in S_{e_i}$
- $\exist j,t_j \text{ is a substring of }s$
题解
考虑用总方案数减去不合法的方案数,建立 $\text{AC}$ 自动机,即 $\text{dp}$ : $f[i][j]$ 表示前 $i$ 个字符,目前在自动机上的 $j$ 号节点上的方案数,考虑每一步都不能走到结束节点上。然后发现可以写成矩阵的形式,考场上写的分块做法过不去似乎也优化不了,于是我们可以预处理 $nlogn$ 个矩阵,然后每次就往上跳,用向量乘上矩阵即可,这样效率是 $O(40^3nlogn+40^2Qlogn)$ ,过不去,瓶颈在于预处理部分。
考虑另类的跳 $\text{lca}$ 的方式:每次跳 $\le lowbit(dp[x])$ 步,这样跳的次数也不会超过 $O(logn)$ 而且对于每个点只需要预处理 $log(lowbit(dp[x]))+1$ 个矩阵即可。
于是我们对于 $\text{deep}$ 的每一位,如果第 $i$ 位上 $1$ 的个数比 $0$ 的来的少的话,我们就全体的 $\text{deep}$ 加上 $2^i$ 即可,这样预处理的最多次数就是 $\frac{n}{2} \times 1+\frac{n}{4} \times 2+ \frac{n}{8} \times 3+...<2n$ ,这样效率就是 $O(40^3n+40^2Qlogn)$ 。
代码
#include <bits/stdc++.h> using namespace std; const int N=5005,M=42,P=998244353; int n,m,q,hd[N],V[N],nx[N],t=1,e[M],tr[M][M],fi[M],fa[N][13],dp[N],Y,Z,f[2][M],su[2],c[N],b[N]; char g[N][M],h[M];queue<int>qu;bool F[N]; struct O{ int p[M][M]; }d[M],a[N],G,up[2501][13],dn[2501][13],S[100]; void add(int u,int v){ nx[++t]=hd[u];V[hd[u]=t]=v; } int X(int x){return x>=P?x-P:x;} void ins(){ int p=0,l=strlen(h); for (int k,i=0;i<l;i++){ k=h[i]-'a'; if (!tr[p][k]) tr[p][k]=++t; p=tr[p][k]; } e[p]=1; } void build(){ for (int i=0;i<26;i++) if (tr[0][i]) qu.push(tr[0][i]); while(!qu.empty()){ int k=qu.front();qu.pop(); for (int i=0;i<26;i++) if (tr[k][i]) fi[tr[k][i]]=tr[fi[k]][i], e[tr[k][i]]=e[tr[k][i]]|e[fi[tr[k][i]]], qu.push(tr[k][i]); else tr[k][i]=tr[fi[k]][i]; } } O Add(O A,O B){ for (int i=0;i<=t;i++) for (int j=0;j<=t;j++) A.p[i][j]=X(A.p[i][j]+B.p[i][j]); return A; } O Mul(O A,O B){ for (int i=0;i<=t;i++) for (int j=0;j<=t;j++) G.p[i][j]=0; for (int k=0;k<=t;k++) for (int j=0;j<=t;j++) if (A.p[k][j]) for (int i=0;i<=t;i++) if (B.p[i][k]) G.p[i][j]=X(G.p[i][j]+1ll*A.p[k][j]*B.p[i][k]%P); return G; } void dfs(int u,int fr){ dp[u]=dp[fa[u][0]=fr]+1; for (int v,j,i=hd[u];i;i=nx[i]){ if ((v=V[i])==fr) continue; j=i>>1;b[v]=strlen(g[j]); for (int k=0;k<b[v];k++) a[v]=Add(a[v],d[g[j][k]-97]); dfs(v,u); } } void dfs(int u){ for (int v,i=hd[u],w;i;i=nx[i]){ if ((v=V[i])==fa[u][0]) continue; w=c[dp[v]&-dp[v]]; up[v][0]=dn[v][0]=a[v]; for (int j=1;j<=w;j++) up[v][j]=Mul(up[v][j-1],up[fa[v][j-1]][j-1]), dn[v][j]=Mul(dn[fa[v][j-1]][j-1],dn[v][j-1]); dfs(v); } } int lca(int u,int v,int &w){ if (dp[u]<dp[v]) swap(u,v); while(dp[u]>dp[v]) w=1ll*w*b[u]%P,u=fa[u][0]; while(u!=v) w=1ll*w*b[u]%P, w=1ll*w*b[v]%P,u=fa[u][0],v=fa[v][0]; return u; } void Dp(O A){ for (int i=0;i<=t;i++){ f[Y][i]=0; for (int j=0;j<=t;j++) if (A.p[i][j]) f[Y][i]=X(f[Y][i]+1ll*A.p[i][j]*f[Z][j]%P); } Z^=1;Y^=1; } int main(){ cin>>n>>m>>q; for (int u,v,i=1;i<n;i++) scanf("%d%d%s",&u,&v,g[i]), add(u,v),add(v,u);t=0; for (int i=1;i<=m;i++) scanf("%s",h),ins();build(); for (int i=0;i<26;i++) for (int k,j=0;j<=t;j++) if (!e[k=tr[j][i]]) d[i].p[k][j]++; dfs(1,0);fa[1][0]=1;c[1<<12]=12; for (int i=0;i<12;i++){ su[0]=su[1]=0; for (int j=1;j<=n;j++) su[(dp[j]>>i)&1]++; if (su[0]>su[1]){ for (int j=1;j<=n;j++) dp[j]+=(1<<i); } c[1<<i]=i; } for (int i=0;i<13;i++) for (int j=0;j<=t;j++) up[1][i].p[j][j]=dn[1][i].p[j][j]=1; for (int i=1;i<13;i++) for (int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1]; dfs(1); for (int x,y,w,u,v,z,o;q--;){ scanf("%d%d",&x,&y); Y=f[o=Z=0][0]=w=1; z=lca(x,y,w); while(x!=z){ u=dp[x]&-dp[x]; for (int i=u;i;i>>=1) if (dp[x]-i>=dp[z]){ S[++o]=up[x][c[i]]; x=fa[x][c[i]];break; } } v=o; while(y!=z){ u=dp[y]&-dp[y]; for (int i=u;i;i>>=1) if (dp[y]-i>=dp[z]){ S[++o]=dn[y][c[i]]; y=fa[y][c[i]];break; } } for (int i=1;i<=v;i++) Dp(S[i]); for (int i=o;i>v;i--) Dp(S[i]); for (int i=0;i<=t;i++) w=X(w+P-f[Z][i]),f[0][i]=f[1][i]=0; printf("%d\n",w); } return 0; }