P4770 [NOI2018]你的名字(SAM+线段树合并)

题目链接

https://www.luogu.com.cn/problem/P4770

题意

给出一个字符串\(S\)\(Q\)组询问,每次询问字符串\(T\)中有多少个本质不同的子串在\(S[l..r]\)中没有出现。

思路

问题转化为有几个\(T\)的本质不同的子串满足这个子串在\(S\)的给定区间当中出现了,然后我们求出这个东西之后拿\(T\)的所有本质不同子串去减就能得到答案了
先考虑一下\(l=1,r=|S|\)的情况
\(lim[i]\)表示字符串\(T[1..i]\)能在\(S\)中匹配到的最长后缀(即\(T[i−lim[i]+1,i]\)\(S\)的子串且\(lim[i]\)最大)
这个\(lim[i]\)可以不断地在\(S\)的后缀自动机上跳来求出。当无法向下匹配时,一直跳\(parent\)树直到可以匹配为止。
再对\(T\)建后缀自动机
那么答案就是:\(ans=\sum_{i=2}^{cnt}max(0,len[i]−max(len[fa[i]],lim[pos[i]]))\)
其中\(pos[i]\)表示该集合字符串出现的位置(随便一个都行,这里取最大的位置)
然后\(l\)\(r\)任意的情况用线段树维护\(S\)后缀自动机的\(endpos\)集合就行了
在求\(lim[i]\)时要注意一点,匹配失败不能直接跳\(parent\),节点集合中的字符串要逐个匹配。

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxx = 2*1e6+10;
struct SAM
{
    int last,tot,fa[maxx],ch[maxx][26],len[maxx],pos[maxx];
    void init()
    {
        last=tot=1;
        memset(ch[1],0,sizeof(ch[1]));
    }
    void add(int x)
    {
        int pre=last,now=last=++tot;
        memset(ch[now],0,sizeof(ch[now]));
        len[now]=len[pre]+1;
        pos[now]=len[now];
        for(;pre&&!ch[pre][x];pre=fa[pre])ch[pre][x]=now;
        if(!pre)fa[now]=1;
        else
        {
            int q=ch[pre][x];
            if(len[q]==len[pre]+1)fa[now]=q;
            else
            {
                int nows=++tot;
                memset(ch[nows],0,sizeof(ch[nows]));
                len[nows]=len[pre]+1;
                pos[nows]=pos[now];
                memcpy(ch[nows],ch[q],sizeof(ch[q]));
                fa[nows]=fa[q];fa[q]=fa[now]=nows;
                for(;pre&&ch[pre][x]==q;pre=fa[pre])ch[pre][x]=nows;
            }
        }
    }
}s1,s2;

int sum[50*maxx],ls[50*maxx],rs[50*maxx],rt[maxx],cnt;
int head[maxx],to[maxx],ne[maxx],num;
int n;
void update(int &u,int l,int r,int x)
{
    if(!u)u=++cnt;
    if(l==r)
    {
        sum[u]=1;
        return;
    }
    int mid=(l+r)/2;
    if(x<=mid)update(ls[u],l,mid,x);
    else update(rs[u],mid+1,r,x);
    sum[u]=sum[ls[u]]+sum[rs[u]];
}
int query(int u,int l,int r,int p,int q)
{
    if(p>q)return 0;
    if(!u)return 0;
    if(p<=l&&r<=q)return sum[u];
    int mid=(l+r)/2;
    int ans=0;
    if(p<=mid)ans+=query(ls[u],l,mid,p,q);
    if(q>mid)ans+=query(rs[u],mid+1,r,p,q);
    return ans;
}
int merge(int a,int b,int l,int r)
{
    if(!a)return b;
    if(!b)return a;
    int u=++cnt;
    if(l==r)
    {
        sum[u]=sum[a]|sum[b];
        return u;
    }
    int mid=(l+r)/2;
    ls[u]=merge(ls[a],ls[b],l,mid);
    rs[u]=merge(rs[a],rs[b],mid+1,r);
    sum[u]=sum[ls[u]]+sum[rs[u]];
    return u;
}

void addm(int u,int v)
{
    to[++num]=v,ne[num]=head[u],head[u]=num;
}
void dfs(int u)
{
    for(int i=head[u];i;i=ne[i])
    {
        dfs(to[i]);
        rt[u]=merge(rt[u],rt[to[i]],1,n);
    }
}

char s[maxx],t[maxx];
int lim[maxx];
int main()
{
    scanf("%s",s+1);
    n=strlen(s+1);
    s1.init();
    for(int i=1;i<=n;i++)
    {
        update(rt[s1.tot+1],1,n,i);
        s1.add(s[i]-'a');
    }
    for(int i=2;i<=s1.tot;i++)addm(s1.fa[i],i);
    dfs(1);
    int q,l,r;
    scanf("%d",&q);
    while(q--)
    {
        scanf("%s%d%d",t+1,&l,&r);
        int m=strlen(t+1),now=1,len=0;
        for(int i=1;i<=m;i++)
        {
            int x=t[i]-'a';
            while(now&&!(s1.ch[now][x]&&query(rt[s1.ch[now][x]],1,n,l+len,r)))
            {
                if(len>s1.len[s1.fa[now]])len--;
                else now=s1.fa[now];
            }
            if(!now)now=1,len=0;
            else len++,now=s1.ch[now][x];
            lim[i]=len;
        }
        s2.init();
        for(int i=1;i<=m;i++)s2.add(t[i]-'a');
        LL ans=0;
        for(int i=2;i<=s2.tot;i++)
            ans+=max(0,s2.len[i]-max(s2.len[s2.fa[i]],lim[s2.pos[i]]));
        printf("%lld\n",ans);
    }
    return 0;
}
posted @ 2020-06-01 23:21  灰灰烟影  阅读(180)  评论(0编辑  收藏  举报