BZOJ 3796: Mushroom追妹纸 后缀自动机+拓扑排序+DP
就是喜欢后缀自动机,yy了一个只用后缀自动机解决的方法.
对 3 个串建立广义后缀自动机,然后建立后缀树.
标记出每个点在0/1/2个串中是否作为子串出现,然后将后缀树中 2 串结尾的所有子树都设为危险节点.
然后对于 SAM 来一个拓扑序DP,我们开始的时候默认危险节点的最大值是 -inf,然后 2 结尾节点最大值为 len(2)-1.
然后这么转移一下就好了.
最开始错误的原因误以为只要一个节点会从一个危险节点转移,那么该节点就是危险节点.
这个是错误的,而是要满足转移过来的节点全是危险节点,这个点才是危险节点.
code:
#include <cstdio> #include <string> #include <algorithm> #include <queue> #include <cstring> #define N 300006 #define inf 1000000000 using namespace std; void setIO(string s) { freopen((s+".in").c_str(),"r",stdin); // freopen((s+".out").c_str(),"w",stdout); } char S[N]; queue<int>que; int edges,dfn,ans,dp[N]; int ch[N][26],mx[N],pre[N],last,tot,tag[N][3],hd[N],to[N],nex[N],st[N],ed[N],is[N],deg[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } void extend(int c,int id) { if(ch[last][c]) { int p=last,q=ch[last][c]; if(mx[q]==mx[p]+1) last=q; else { int nq=++tot; mx[nq]=mx[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[nq])); pre[nq]=pre[q],pre[q]=nq; for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq; last=nq; } } else { int np=++tot,p=last; mx[np]=mx[p]+1,last=np; for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np; if(!p) pre[np]=1; else { int q=ch[p][c]; if(mx[q]==mx[p]+1) pre[np]=q; else { int nq=++tot; mx[nq]=mx[p]+1; memcpy(ch[nq],ch[q],sizeof(ch[q])); pre[nq]=pre[q],pre[q]=pre[np]=nq; for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq; } } } tag[last][id]=1; } void dfs(int u) { st[u]=++dfn; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; dfs(v); tag[u][0]|=tag[v][0]; tag[u][1]|=tag[v][1]; tag[u][2]|=tag[v][2]; } ed[u]=dfn; } int main() { // setIO("input"); last=tot=1; int i,j,len,cor; for(scanf("%s",S+1),len=strlen(S+1),last=1,i=1;i<=len;++i) extend(S[i]-'a',0); for(scanf("%s",S+1),len=strlen(S+1),last=1,i=1;i<=len;++i) extend(S[i]-'a',1); for(scanf("%s",S+1),len=strlen(S+1),last=1,i=1;i<=len;++i) extend(S[i]-'a',2); for(i=2;i<=tot;++i) add(pre[i],i); dfs(1); if(tag[last][0]&&tag[last][1]) ans=len-1; for(i=1;i<=tot;++i) if(st[i]>=st[last]&&ed[i]<=ed[last]) dp[i]=-inf,is[i]=1; dp[last]=len-1; for(i=1;i<=tot;++i) for(j=0;j<26;++j) if(ch[i][j]) ++deg[ch[i][j]]; for(i=1;i<=tot;++i) if(!deg[i]) que.push(i); while(!que.empty()) { int u=que.front();que.pop(); for(int j=0;j<26;++j) { int v=ch[u][j]; if(!v) continue; if(!is[v]) dp[v]=max(dp[v],dp[u]+1); --deg[v]; if(!deg[v]) que.push(v); } } for(i=1;i<=tot;++i) if(tag[i][0]&&tag[i][1]) ans=max(ans,dp[i]); printf("%d\n",ans); return 0; }