Codeforces 666E Forensic Examination 广义后缀自动机+线段树合并+树上倍增

Codeforces 666E Forensic Examination

题意

给出一个字符串\(s\)\(m\)个字符串\(t_1,t_2,\dots,t_m\)\(q\)次询问,每次询问给出四个整数\(l,r,pl,pr\),问\(t_l,t_{l+1},\dots,t_r\)中哪个字符串中\(s[pl;pr]\)作为子串出现次数最多,输出该\(t_i\)的下标和\(s[pl;pr]\)的出现次数。

\(|s| \le 5\cdot 10^5,m\le 5 \cdot 10^4,\sum|t_i| \le 5 \cdot 10^4,1\le l \le r \le m,1\le pl \le pr \le |s|\)

分析

把所有\(t_i\)后加一个字符'#',然后连起来变成\(t_1\#t_2\#\dots t_n\)建后缀自动机,对自动机上每个状态建线段树,\(cnt[v][i]\)表示状态\(v\)的结束位置属于字符串\(t_i\)的个数,线段树维护区间最大值。

把询问离线,然后用字符串\(s\)直接在自动机上跑,对\(s\)的每个前缀\([1;i]\)在自动机上找到能匹配上的最长后缀所在的状态\(v\),遍历所有询问中\(pr=i\)的,按\(pl\)升序遍历,对于每个询问让\(v\)跳后缀链接,跳到一个\(longest(v)\)刚好大于等于\(i-pl+1\)的状态,然后在这个状态的线段树上区间查询\([l,r]\)的最大值即可。

Code

#include<bits/stdc++.h>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,rs[p]
#define pii pair<int,int>
#define lson l,mid,ls[p]
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=1e6+10;
const int M=2e5+10;
const int inf=1e9;
int n,m,q;
char s[N],t[N];
vector<int>g[N];
int ed[N],now[N];
pii cmax(pii a,pii b){
    if(a.se==b.se){
        if(a.fi<b.fi) return a;
        return b;
    }
    if(a.se>b.se) return a;
    return b;
}
struct SegmentTree{
    pii tr[M*40];
    int ls[M*40],rs[M*40],tot;
    void up(int x,int l,int r,int &p){
        if(!p) p=++tot;
        if(l==r){
            tr[p].se++;
            tr[p].fi=l;
            return;
        }
        int mid=l+r>>1;
        if(x<=mid) up(x,lson);
        else up(x,rson);
        tr[p]=cmax(tr[ls[p]],tr[rs[p]]);
    }
    int merge(int x,int y,int l,int r){
        if(!x||!y) return x+y;
        int p=++tot,mid=l+r>>1;
        if(l==r){
            tr[p]=mp(l,tr[x].se+tr[y].se);
        }else{
            ls[p]=merge(ls[x],ls[y],l,mid);
            rs[p]=merge(rs[x],rs[y],mid+1,r);
            tr[p]=cmax(tr[ls[p]],tr[rs[p]]);
        }
        return p;
    }
    pii qy(int dl,int dr,int l,int r,int p){
        if(!p||l>r) return mp(dl,0);
        if(l==dl&&r==dr) return tr[p];
        int mid=l+r>>1;
        if(dr<=mid) return qy(dl,dr,lson);
        else if(dl>mid) return qy(dl,dr,rson);
        else return cmax(qy(dl,mid,lson),qy(mid+1,dr,rson));
    }
}seg;
struct SAM{
    int last,cnt;int ch[M][27],fa[M],len[M],rt[M],d[M],f[M][22];
    void insert(int c,int pos){
        int p=last,np=++cnt;last=np;len[np]=len[p]+1;
        for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
        if(!p) fa[np]=1;
        else {
            int q=ch[p][c];
            if(len[q]==len[p]+1) fa[np]=q;
            else  {
                int nq=++cnt;len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof ch[q]);
                fa[nq]=fa[q],fa[q]=fa[np]=nq;
                for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
            }
        }
        seg.up(pos,1,m,rt[np]);
    }
    void init(){
        last=cnt=1;
    }
    void dfs(int u){
        f[u][0]=fa[u];d[u]=d[fa[u]]+1;
        for(int i=1;(1<<i)<=d[u];i++){
            f[u][i]=f[f[u][i-1]][i-1];
        }
        for(int x:g[u]){
            dfs(x);
            rt[u]=seg.merge(rt[u],rt[x],1,m);
        }
    }
    void gao(){
        for(int i=2;i<=cnt;i++) g[fa[i]].pb(i);
        d[0]=-1;
        dfs(1);
        int u=1,l=0;
        for(int i=1;i<=n;i++){
            while(!ch[u][s[i]-'a']&&u!=1) u=fa[u],l=len[u];
            if(ch[u][s[i]-'a']) u=ch[u][s[i]-'a'],l++,
            ed[i]=u,now[i]=l;
        }
    }
    pii solve(int l,int r,int pl,int pr){
        int u=ed[pr],dl=now[pr];
        if(dl<pr-pl+1) return mp(l,0);
        for(int i=20;i>=0;i--){
            if((1<<i)<=d[u]&&len[f[u][i]]>=pr-pl+1) u=f[u][i];
        }
        return seg.qy(l,r,1,m,rt[u]);
    }
}sam;
int main(){
    scanf("%s",s+1);
    n=strlen(s+1);
    sam.init();
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        scanf("%s",t);
        int l=strlen(t);
        for(int j=0;j<l;j++) sam.insert(t[j]-'a',i);
        if(i!=m) sam.insert(26,i);
    }
    sam.gao();
    scanf("%d",&q);
    for(int i=1,l,r,pl,pr;i<=q;i++){
        scanf("%d%d%d%d",&l,&r,&pl,&pr);
        pii ans=sam.solve(l,r,pl,pr);
        printf("%d %d\n",ans.fi,ans.se);
    }
    return 0;
}
posted @ 2020-10-31 01:16  xyq0220  阅读(103)  评论(0编辑  收藏  举报