BZOJ 3796: Mushroom追妹纸 后缀自动机+拓扑排序+DP

就是喜欢后缀自动机,yy了一个只用后缀自动机解决的方法.   

对 3 个串建立广义后缀自动机,然后建立后缀树.  

标记出每个点在0/1/2个串中是否作为子串出现,然后将后缀树中 2 串结尾的所有子树都设为危险节点.  

然后对于 SAM 来一个拓扑序DP,我们开始的时候默认危险节点的最大值是 -inf,然后 2 结尾节点最大值为 len(2)-1.  

然后这么转移一下就好了. 

最开始错误的原因误以为只要一个节点会从一个危险节点转移,那么该节点就是危险节点.    

这个是错误的,而是要满足转移过来的节点全是危险节点,这个点才是危险节点. 

code: 

#include <cstdio> 
#include <string>
#include <algorithm>  
#include <queue>
#include <cstring>  
#define N 300006 
#define inf 1000000000 
using namespace std;   
void setIO(string s) 
{
    freopen((s+".in").c_str(),"r",stdin);  
    // freopen((s+".out").c_str(),"w",stdout); 
}
char S[N];  
queue<int>que;  
int edges,dfn,ans,dp[N];  
int ch[N][26],mx[N],pre[N],last,tot,tag[N][3],hd[N],to[N],nex[N],st[N],ed[N],is[N],deg[N];   
void add(int u,int v) 
{
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; 
}
void extend(int c,int id) 
{ 
    if(ch[last][c]) 
    {
        int p=last,q=ch[last][c];   
        if(mx[q]==mx[p]+1)  last=q;   
        else 
        {
            int nq=++tot;   
            mx[nq]=mx[p]+1;   
            memcpy(ch[nq],ch[q],sizeof(ch[nq]));   
            pre[nq]=pre[q],pre[q]=nq;   
            for(;p&&ch[p][c]==q;p=pre[p]) ch[p][c]=nq;   
            last=nq; 
        }
    }  
    else 
    {
        int np=++tot,p=last; 
        mx[np]=mx[p]+1,last=np; 
        for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np;  
        if(!p) pre[np]=1; 
        else 
        {
            int q=ch[p][c]; 
            if(mx[q]==mx[p]+1)  pre[np]=q;  
            else 
            {
                int nq=++tot; 
                mx[nq]=mx[p]+1;   
                memcpy(ch[nq],ch[q],sizeof(ch[q]));    
                pre[nq]=pre[q],pre[q]=pre[np]=nq;     
                for(;p&&ch[p][c]==q;p=pre[p])  ch[p][c]=nq;  
            }
        }
    }      
    tag[last][id]=1;  
}    
void dfs(int u) 
{ 
    st[u]=++dfn; 
    for(int i=hd[u];i;i=nex[i]) 
    {
        int v=to[i]; 
        dfs(v);   
        tag[u][0]|=tag[v][0]; 
        tag[u][1]|=tag[v][1]; 
        tag[u][2]|=tag[v][2];  
    }
    ed[u]=dfn;  
}           
int main() 
{ 
    // setIO("input"); 
    last=tot=1;
    int i,j,len,cor; 
    for(scanf("%s",S+1),len=strlen(S+1),last=1,i=1;i<=len;++i) extend(S[i]-'a',0);   
    for(scanf("%s",S+1),len=strlen(S+1),last=1,i=1;i<=len;++i) extend(S[i]-'a',1);              
    for(scanf("%s",S+1),len=strlen(S+1),last=1,i=1;i<=len;++i) extend(S[i]-'a',2);    
    for(i=2;i<=tot;++i) add(pre[i],i);          
    dfs(1); 
    if(tag[last][0]&&tag[last][1])  
        ans=len-1;                                    
    for(i=1;i<=tot;++i)  if(st[i]>=st[last]&&ed[i]<=ed[last])  dp[i]=-inf,is[i]=1;  
    dp[last]=len-1;  
    for(i=1;i<=tot;++i)  for(j=0;j<26;++j) if(ch[i][j])  ++deg[ch[i][j]];     
    for(i=1;i<=tot;++i)  if(!deg[i]) que.push(i);       
    while(!que.empty()) 
    {  
        int u=que.front();que.pop();  
        for(int j=0;j<26;++j)  
        {
            int v=ch[u][j]; 
            if(!v) continue;                       
            if(!is[v]) dp[v]=max(dp[v],dp[u]+1);   
            --deg[v];  
            if(!deg[v]) que.push(v); 
        }
    }
    for(i=1;i<=tot;++i)  if(tag[i][0]&&tag[i][1]) ans=max(ans,dp[i]); 
    printf("%d\n",ans);  
    return 0;
}

  

posted @ 2019-12-31 08:52  EM-LGH  阅读(181)  评论(0编辑  收藏  举报