NOI2018 你的名字 后缀自动机 + 线段树合并 + 可持久化

这个是满分做法, 68pts 做法在上一篇博客中 
会 68 pts 做法后就十分简单了,只要来一遍线段树合并 right 集合并在匹配的时候判一下是否在 $[l,r]$ 区间中即可
#include <cstdio>
#include <algorithm>
#include <cstring>
#define maxn 5000000 
#define N 30 
#define ll long long 
#define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout) 
using namespace std;

char str[maxn],ss[maxn]; 
int str_len,ss_len; 
int nodes;
int C[maxn],rk[maxn]; 
int rt[maxn];        
int pos[maxn],mx[maxn]; 

struct Segment_Tree{ int l,r,sumv;  }node[maxn<<2];  
int newnode(){ return ++nodes; } 
void modify(int p,int l,int r,int &o){
    if(!o) o=newnode();  
    ++node[o].sumv;    
    if(l==r) return;           
    int mid=(l+r)>>1;
    if(p<=mid) modify(p,l,mid,node[o].l);
    else modify(p,mid+1,r,node[o].r);  
}
int merge(int x,int y){
    if(!x||!y) return x+y; 
    int o=newnode();
    node[o].sumv=node[x].sumv+node[y].sumv;                              
    node[o].l=merge(node[x].l,node[y].l);
    node[o].r=merge(node[x].r,node[y].r); 
    return o; 
}
int query(int l,int r,int L,int R,int o){
    if(l>r||r<L||l>R) return 0;     
    if(l>=L&&r<=R) return node[o].sumv; 
    int mid=(l+r)>>1,res=0;
    if(node[o].l) res+=query(l,mid,L,R,node[o].l); 
    if(node[o].r) res+=query(mid+1,r,L,R,node[o].r); 
    return res; 
}

struct SAM{  
    int last,tot,dis[maxn],ch[maxn][N],f[maxn],pos[maxn];      
    void init() { last=++tot;  }
    void ins(int c,int y,int z,int rot){
        int p=last,np=++tot; last=np; dis[np]=dis[p]+1; pos[np] = z; 
        while(p&&!ch[p][c])ch[p][c]=np,p=f[p];        
        if(!p) f[np]=rot;
        else{      
            int q=ch[p][c],nq; 
            if(dis[q]==dis[p]+1) f[np]=q;
            else{
                nq=++tot;
                dis[nq]=dis[p]+1; 
                pos[nq] = pos[q];                    
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                f[nq]=f[q],f[q]=f[np]=nq;
                while(p&&ch[p][c]==q) ch[p][c]=nq,p=f[p];
            }
        }
        if(y) modify(y,1,str_len,rt[np]);               
    }   
    void build_S1(){
        for(int i=1;i<=tot;++i) C[dis[i]]++;  
        for(int i=1;i<=tot;++i) C[i]+=C[i-1]; 
        for(int i=1;i<=tot;++i) rk[C[dis[i]]--]=i;     
        for(int i=tot;i>=1;--i) {                    
            int p=rk[i];            
            rt[f[p]] = merge(rt[f[p]],rt[p]);        
        } 
    } 
}S1,S2;

int main(){
    //setIO("input");
    scanf("%s",str+1),str_len=strlen(str+1),S1.init();
    for(int i=1;i<=str_len;++i) S1.ins(str[i]-'a',i,0,1);         
    S1.build_S1(); 
    int queries,l,r,st,ed; 
    scanf("%d",&queries); 
    while(queries--) {

        S2.init(),st=S2.tot,scanf("%s%d%d",ss+1,&l,&r),ss_len=strlen(ss+1); 

        for(int i=1;i<=ss_len;++i) S2.ins(ss[i]-'a',0,i,st);

        ed=S2.tot;  


        int cnt=0,p=1;                      
        long long ans = 0;

        for(int i=1;i<=ss_len;++i) {
            int c=ss[i]-'a';
            while(1){
                if(S1.ch[p][c] && query(1,str_len,l+cnt,r,rt[S1.ch[p][c]]))
                {
                    ++cnt,p=S1.ch[p][c]; break; 
                }
                else {
                    if(!cnt) {p=1;break;       } 
                    --cnt; 
                    if(cnt==S1.dis[S1.f[p]]) p=S1.f[p]; 
                }
            }
            mx[i]=cnt; 
        }

        for(int i=st;i<=ed;++i)   
            ans+=max(0,S2.dis[i]-max(mx[S2.pos[i]],S2.dis[S2.f[i]]));             
        printf("%lld\n",ans); 
    }
    return 0; 
}

  

posted @ 2019-02-15 13:20  EM-LGH  阅读(400)  评论(0编辑  收藏  举报