洛谷P5115 Check,Check,Check one two! 边分治+虚树+SAM

绝对是我写过最长的一份代码了.   

这个快敲吐了. 

通过这道题能 get 到一个套路:

两颗树同时统计信息的题可以考虑在个树上跑边分治,把点扔到另一颗树的虚树上,然后跑虚树DP.

具体地,这道题中我们发现 $LCP$ 长度是反串后缀树 $LCA$ 深度,$LCS$ 是正串后缀树 $LCA$ 深度.

我们建出正反两串后缀树后,将长度大于 K1/K2 的点的深度置为 0,然后跑一个边分+虚树即可.

code: 

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
#include <map>      
#define N 200007
#define inf 0x3f3f3f3f 
#define ull unsigned long long
  
// 代码已写完,人已阵亡.
      
using namespace std;
  
int bug;  
int K1,K2;
ull ans,W;
char S[N];   
  
namespace IO { 
  
    void setIO(string s)
    {
        string in=s+".in";
        string out=s+".out";
        freopen(in.c_str(),"r",stdin);
        // freopen(out.c_str(),"w",stdout);
    }
  
};
  
struct SAM {      
  
    #define M N<<1 
  
    int tot,last;   
    struct Edge { 
        int to,w;
        Edge(int to=0,int w=0):to(to),w(w){}  
    }; 
    vector<Edge>G[M];
    int pre[M],ch[M][26],mx[M],str_sam[M],sam_str[M],depth[M];             
  
    void Initialize() { tot=last=1; }   
  
    void extend(int c)
    {       
        int np=++tot,p=last; 
        mx[np]=mx[p]+1,last=np;
        for(;p&&!ch[p][c];p=pre[p]) ch[p][c]=np; 
        if(!p) pre[np]=1;
        else
        {
            int q=ch[p][c];
            if(mx[q]==mx[p]+1) pre[np]=q;
            else
            {
                int nq=++tot;   
                mx[nq]=mx[p]+1;  
                memcpy(ch[nq],ch[q],sizeof(ch[q]));  
                pre[nq]=pre[q],pre[np]=pre[q]=nq;  
                for(;p&&ch[p][c]==q;p=pre[p])  ch[p][c]=nq;
            }
        }    
    }
      
  
    void Build_LCP()
    {     
        int n=strlen(S+1),i,j,p=1;
        for(i=1;i<=n;++i)
        {
            p=ch[p][S[n-i+1]-'a']; 
            sam_str[p]=n-i+1;
            str_sam[n-i+1]=p;                   
        }
        for(i=2;i<=tot;++i)
        {
            if(mx[i]>K1) depth[i]=0;         
            else depth[i]=mx[i]; 
        }
        for(i=2;i<=tot;++i)  G[pre[i]].push_back(Edge(i,depth[i]-depth[pre[i]]));       
    } 
      
    void Build_LCS()
    {
        int n=strlen(S+1),i,j,p=1; 
        for(i=1;i<=n;++i)
        {
            p=ch[p][S[i]-'a']; 
            sam_str[p]=i;
            str_sam[i]=p; 
        }  
        for(i=2;i<=tot;++i)
        {
            if(mx[i]>K2) depth[i]=0;
            else  depth[i]=mx[i];
        }
        for(i=2;i<=tot;++i)  G[pre[i]].push_back(Edge(i,depth[i]-depth[pre[i]]));   
    }   
  
    #undef M
  
}lcp,lcs;
  
namespace vir {    
  
    vector<int>G[N<<2];
    vector<int>clr;  

    int t,sta,tot;
    int is1[N<<2],is2[N<<2];
    int dfn[N<<2],dep[N<<2],size[N<<2],son[N<<2],top[N<<2],f[N<<2];     
    int S[N<<2],val[N<<2],re[N<<2];
    ull size1[N<<2],size2[N<<2];   
    ull sum1[N<<2],sum2[N<<2];        
  
    bool cmp(int a,int b)
    {
        return dfn[a]<dfn[b]; 
    }
  
    void get_dfn(int x,int fa)
    {       
        dfn[x]=++t;         
        size[x]=1;
        f[x]=fa; 
        for(int i=0;i<lcp.G[x].size();++i)
        {
            int y=lcp.G[x][i].to; 
            if(y==fa) continue;
            dep[y]=dep[x]+1;  
            get_dfn(y,x);     
            size[x]+=size[y];
            if(size[y]>size[son[x]]) son[x]=y;
        }      
    } 
  
    void dfs2(int u,int tp)
    {
        top[u]=tp;
        if(son[u]) dfs2(son[u],tp);
        for(int i=0;i<lcp.G[u].size();++i)
        {
            int v=lcp.G[u][i].to;  
            if(v==son[u]||v==f[u]) continue;   
            dfs2(v,v);  
        }
    }
  
    int LCA(int x,int y)
    {
        while(top[x]!=top[y])
        {
            dep[top[x]]>dep[top[y]]?x=f[top[x]]:y=f[top[y]];
        }
        return dep[x]<dep[y]?x:y;
    }
  
    void _new(int x,int v,int c)
    {
        ++tot;
        re[tot]=x;
        val[x]=v;
        if(c==1) is1[x]=1;
        else is2[x]=1;
    }  
  
    void addvir(int x,int y)
    { 
        G[x].push_back(y);
    }
  
    void Initialize()
    {
        t=0; 
        get_dfn(1,0); 
        dfs2(1,1);          
    }     
  
    void Insert(int x)
    {
        if(sta<=1)
        {
            S[++sta]=x;    
            return;
        }
        int lca=LCA(S[sta],x);     
        if(lca==S[sta]) S[++sta]=x;
        else
        {
            while(sta>1&&dep[S[sta-1]]>=dep[lca]) addvir(S[sta-1],S[sta]),--sta;  
            if(S[sta]==lca)  S[++sta]=x;                         
            else
            {
                addvir(lca,S[sta]); 
                S[sta]=lca;
                S[++sta]=x; 
            } 
        }
    }
  
    void Build()
    {  
        sta=0;
        sort(re+1,re+1+tot,cmp);                         
        if(re[1]!=1) S[++sta]=1;         
        for(int i=1;i<=tot;++i) Insert(re[i]);
        while(sta>1)  addvir(S[sta-1],S[sta]),--sta;                       
    }  
  
    void DP(int x)
    {   
        clr.push_back(x); 
        for(int i=0;i<G[x].size();++i)
        {
            int y=G[x][i];
            DP(y);          
            size1[x]+=size1[y];
            size2[x]+=size2[y]; 
            sum1[x]+=sum1[y]; 
            sum2[x]+=sum2[y];             
        }
        ull tmp=0; 
        ull cntw=0;
        ull cur=0; 
        for(int i=0;i<G[x].size();++i)
        { 
            int y=G[x][i];   
            tmp+=(sum1[x]-sum1[y])*size2[y];      
            tmp+=(sum2[x]-sum2[y])*size1[y];            
            cntw+=(size1[x]-size1[y])*size2[y];
        }  
        cur+=tmp*lcp.depth[x];         
        cur-=cntw*W*lcp.depth[x]; 
        if(is1[x])
        {  
            cur+=size2[x]*val[x]*lcp.depth[x];  
            cur+=sum2[x]*lcp.depth[x];  
            cur-=size2[x]*W*lcp.depth[x]; 
        } 
        if(is2[x])
        {
            cur+=size1[x]*val[x]*lcp.depth[x]; 
            cur+=sum1[x]*lcp.depth[x];  
            cur-=size1[x]*W*lcp.depth[x]; 
        }     
        ans+=cur/2;  
        size1[x]+=is1[x];
        size2[x]+=is2[x]; 
        sum1[x]+=is1[x]*val[x];
        sum2[x]+=is2[x]*val[x]; 
        G[x].clear();       
    }
  
    void solve()
    {
        Build(); 
        DP(1); 
        for(int i=0;i<clr.size();++i) 
        {
            int x=clr[i];   
            val[x]=sum1[x]=sum2[x]=size1[x]=size2[x]=is1[x]=is2[x]=0;    
        }           
        for(int i=1;i<=tot;++i)
        {         
            re[i]=0; 
        }
        tot=0; 
        sta=0;      
        clr.clear(); 
    }
};  // 虚树
  
int tot,edges=1;
int totsize,rt1,rt2,mx,ed,lsc,rsc;
int hd[N<<2],vis[N<<3],size[N<<2];
  
struct Edge {
    int to,w,nex; 
}e[N<<3];  
   
struct Node { 
    int u,dis,val;      
    Node(int u=0,int dis=0,int val=0):u(u),dis(dis),val(val){}         
}L[N<<2],R[N<<2];
  
void add_div(int x,int y,int z)
{
    e[++edges].nex=hd[x],hd[x]=edges,e[edges].to=y,e[edges].w=z; 
}  
  
void Build_Tree(int x,int fa)
{
    int ff=0;             
    for(int i=0;i<lcs.G[x].size();++i)
    {
        int y=lcs.G[x][i].to; 
        if(y==fa)  continue;  
        if(!ff)
        {
            ff=x;  
            add_div(ff,y,lcs.G[x][i].w);
            add_div(y,ff,lcs.G[x][i].w);  
        }
        else
        {
            ++tot;        
            add_div(ff,tot,0);
            add_div(tot,ff,0);                         
            add_div(tot,y,lcs.G[x][i].w); 
            add_div(y,tot,lcs.G[x][i].w); 
            ff=tot;        
        }
        Build_Tree(y,x); 
    }
}
  
void find_edge(int x,int fa)
{  
    size[x]=1;
    for(int i=hd[x];i;i=e[i].nex)
    {
        int y=e[i].to; 
        if(y==fa||vis[i])  continue;  
        find_edge(y,x);
        int now=max(size[y],totsize-size[y]);  
        if(now<mx)
        {
            mx=now;
            ed=i; 
            rt1=y; 
            rt2=x;  
        }
        size[x]+=size[y];
    }
}
  
  
void get_node(int x,int fa,int dep,int ty)
{     
    if(ty==1)
    {           
        if(lcs.sam_str[x])  L[++lsc]=Node(lcs.sam_str[x],lcs.depth[x]-dep,dep); 
    }
    else
    {  
        if(lcs.sam_str[x])  R[++rsc]=Node(lcs.sam_str[x],lcs.depth[x]-dep,dep);
    }
    for(int i=hd[x];i;i=e[i].nex)
    {
        int y=e[i].to;   
        if(vis[i]||y==fa)  continue;  
        get_node(y,x,dep+e[i].w,ty); 
    }
} 

void Divide_And_conquer(int x)
{      
    if(totsize==1) return; 
    mx=inf;  
    rt1=rt2=ed=0;     
    find_edge(x,0);             
    vis[ed]=vis[ed^1]=1;    
    lsc=rsc=0;                                      
    get_node(rt1,0,0,1); 
    get_node(rt2,0,0,2);            
    W=(ull)e[ed].w;  
    ull tmp=ans;   
  
    for(int i=1;i<=lsc;++i) vir::_new(lcp.str_sam[L[i].u],L[i].dis,1); 
    for(int i=1;i<=rsc;++i) vir::_new(lcp.str_sam[R[i].u],R[i].dis,2); 
    vir::solve();                    

    int tmprt1=rt1,tmprt2=rt2;           
    int sizert1=size[rt1],sizert2=totsize-size[rt1];           
    totsize=sizert1;      
    Divide_And_conquer(tmprt1);    
    totsize=sizert2; 
    Divide_And_conquer(tmprt2); 
}
  
int main()
{
    // IO::setIO("input");
    int i,j,n; 
    scanf("%s%d%d",S+1,&K1,&K2);        
    n=strlen(S+1);    
    lcp.Initialize(); 
    lcs.Initialize(); 
    for(i=1;i<=n;++i)
    {    
        lcs.extend(S[i]-'a'); 
        lcp.extend(S[n-i+1]-'a'); 
    }
    lcs.Build_LCS(); 
    lcp.Build_LCP();
    tot=lcs.tot;     
    Build_Tree(1,0); 
    vir::Initialize();    
    totsize=tot;
    Divide_And_conquer(1);
    printf("%llu\n",ans); 
    return 0;
}

  

posted @ 2019-12-27 15:48  EM-LGH  阅读(236)  评论(0编辑  收藏  举报