P4770 [NOI2018] 你的名字 题解

一句话题意:给定一个字符串 \(S\) ,询问 \(Q\) 次,每次询问字符串 \(T\) 有多少个本质不同子串在字符串 \(S_{l...r}\) 中没有出现。

一步转化:发现求没出现的子串不好求,那就求出现过的子串,最后用 \(T\) 的本质不同子串数减去本质不同公共子串数即为答案。

无限制情况

有区间的限制不好做,先考虑没有区间限制,即 \(l=1,r=|S|\) 的情况。

单串的本质不同子串很好求,不用多说,难点就是在求本质不同公共子串上。

我们记 \(fir[i]\) 表示 \(T_{1...i}\) 的后缀不是第一次出现在 \(i\) 位置的后缀数。(后缀第一次出现时的右端点不是 \(i\) 的后缀数)

显然 \(\sum_{i=1}^{|T|}i-fir[i]\) 就是单串的本质不同子串数。

而求出 \(fir[i]\) 很简单,每次将 \(t[i]\) 插入进 \(SAM\) 中,\(fir[i]=len[fa[lst]]\) ,意思就是 \(t[i]\)\(SAM\) 中的节点的父亲(后缀链接)表示的子串的最大长度。(这里应该好好想想)

考虑求出公共子串数,套路就是把 \(T\) 串放在 \(S\) 串的 \(SAM\) 上跑匹配。因为要求本质不同,所以每次匹配完的贡献是 \(cnt-fir[i]\) 。( \(cnt\) 是匹配的最长子串长度)

于是第一部分的分的拿到手了。

代码:

#include<bits/stdc++.h>
#define pc(x) putchar(x)
#define ll long long
using namespace std;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){f=ch=='-'?-1:f;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
void write(ll x)
{
    if(x<0){x=-x;putchar('-');}
    if(x>9)write(x/10);
    putchar(x%10+48);
}
int n,m;
char s[500005],t[500005];
int fir[500005];
struct SAM
{
    int ch[1000005][26],fa[1000005],len[1000005];
    int lst,cnt;
    void init(){lst=cnt=1;memset(ch[1],0,sizeof ch[1]);fa[1]=len[1]=0;}
    int newnode(){++cnt;memset(ch[cnt],0,sizeof ch[cnt]);fa[cnt]=len[cnt]=0;return cnt;}
    void copy(int x,int y)
    {
        for(int i=0;i<26;++i)
            ch[x][i]=ch[y][i];
        fa[x]=fa[y];len[x]=len[y];
    }
    void insert(int c)
    {
        int p=lst,np=lst=newnode();len[np]=len[p]+1;
        for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
        if(!p){fa[np]=1;return;}int q=ch[p][c];
        if(len[q]==len[p]+1){fa[np]=q;return;}
        int nq=newnode();copy(nq,q);len[nq]=len[p]+1;
        fa[np]=fa[q]=nq;
        for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
    }
    ll solve()//这里就是求解本质不同公共子串数
    {
        int p=1,cnt=0;ll res=0;
        for(int i=1;i<=m;++i)
        {
            while(p&&!ch[p][t[i]-'a'])p=fa[p],cnt=len[p];
            if(!p)p=1,cnt=0;
            else{p=ch[p][t[i]-'a'];++cnt;}
            if(cnt>fir[i])res+=cnt-fir[i];
        }return res;
    }
}tr1,tr2;
int main()
{
    scanf("%s",s+1);n=strlen(s+1);tr1.init();
    for(int i=1;i<=n;++i)tr1.insert(s[i]-'a');
    int q=read();
    while(q-->0)
    {
        scanf("%s",t+1);m=strlen(t+1);
        int l=read(),r=read();ll res=0;tr2.init();
        for(int i=1;i<=m;++i)
            tr2.insert(t[i]-'a'),fir[i]=tr2.len[tr2.fa[tr2.lst]];
        for(int i=1;i<=m;++i)res+=i-fir[i];//这里是求解单串本质不同子串数
        write(res-tr1.solve()),pc('\n');
    }return 0;
}

正解

现在多了一个区间的限制条件,很容易想到要处理节点的 \(endpos\) 集合,每次匹配转移时要判断下一个节点的 \(endpos\) 集合中有没有元素在 \(L...R\) 之间,我们为了保证匹配的长度尽可能的大,选择匹配的子串右端点要尽量的大。这一部分可以用线段树合并来做,线段树维护区间 \(endpos\) 最大值,时间复杂度是 \(O(n)\) 的。

这样直接处理答案其实不是对的,因为我们查找 \(endpos\) 的时候,是在能匹配的最长长度的节点上的线段树查询的,而长度越长,\(endpos\) 就越少,也就说明,我们在较长的匹配上找到的在区间中的匹配长度可能会很小,导致剩下的部分比 \(fir[i]\) 小,无法对答案产生贡献。但是如果我们在匹配长度较短(它的父亲节点)的线段树上查找的话,反而能找到一个 \(endpos\) ,在 \(L\) 的限制之内匹配的长度比 \(fir[i]\) 大,或者所有的匹配长度干脆都不会超出 \(L\) 。这时候就需要把这些长度加到答案里。

所以需要跳匹配节点的父亲,更新答案,直到匹配长度比 \(L\)\(endpos\) 最大值的长度还小,这时候就不再有新公共子串产生了。

代码:

#include<bits/stdc++.h>
#define pc(x) putchar(x)
#define ls(x) (seg[x].l)
#define rs(x) (seg[x].r)
#define mx(x) (seg[x].mx)
#define ll long long
using namespace std;
inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){f=ch=='-'?-1:f;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
void write(ll x)
{
    if(x<0){x=-x;putchar('-');}
    if(x>9)write(x/10);
    putchar(x%10+48);
}
int n,m;
char s[500005],t[500005];
int fir[500005];
struct seg_val_tree
{
    int l,r,mx;
}seg[20000005];
int tot,rt[1000005],cur[1000005];
void insert(int &pos,int l,int r,int k)
{
    pos=++tot;if(l==r){mx(pos)=l;return;}
    int mid=(l+r)>>1;
    if(k<=mid)insert(ls(pos),l,mid,k);
    else    insert(rs(pos),mid+1,r,k);
    mx(pos)=max(mx(ls(pos)),mx(rs(pos)));
}
int merge(int x,int y,int l,int r)
{
    if(!x||!y)return x+y;
    if(l==r){mx(x)=max(mx(x),mx(y));return x;}
    int mid=(l+r)>>1,pos=++tot;
    ls(pos)=merge(ls(x),ls(y),l,mid);
    rs(pos)=merge(rs(x),rs(y),mid+1,r);
    mx(pos)=max(mx(ls(pos)),mx(rs(pos)));
    return pos;
}
int query(int pos,int l,int r,int L,int R)
{
    if(!pos)return 0;
    if(L<=l&&r<=R)return mx(pos);
    int mid=(l+r)>>1,res=0;
    if(L<=mid)res=max(res,query(ls(pos),l,mid,L,R));
    if(R>mid)res=max(res,query(rs(pos),mid+1,r,L,R));
    return res;
}
int c[1000005],a[1000005];
struct SAM
{
    int ch[1000005][26],fa[1000005],len[1000005];
    int lst,cnt;
    void init(){lst=cnt=1;memset(ch[1],0,sizeof ch[1]);fa[1]=len[1]=0;}
    int newnode(){++cnt;memset(ch[cnt],0,sizeof ch[cnt]);fa[cnt]=len[cnt]=0;return cnt;}
    void copy(int x,int y)
    {
        for(int i=0;i<26;++i)ch[x][i]=ch[y][i];
        fa[x]=fa[y];len[x]=len[y];
    }
    void insert(int c)
    {
        int p=lst,np=lst=newnode();len[np]=len[p]+1;
        for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
        if(!p){fa[np]=1;return;}int q=ch[p][c];
        if(len[q]==len[p]+1){fa[np]=q;return;}
        int nq=newnode();copy(nq,q);len[nq]=len[p]+1;
        fa[np]=fa[q]=nq;
        for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
    }
    void pre()//预处理endpos集合,线段树合并
    {
        for(int i=1;i<=cnt;++i)c[len[i]]++;
        for(int i=1;i<=cnt;++i)c[i]+=c[i-1];
        for(int i=cnt;i>1;--i)a[c[len[i]]--]=i;
        for(int i=cnt;i>1;--i)
        {
            int pos=a[i],f=fa[pos];
            rt[f]=merge(rt[f],rt[pos],1,n);
        }
    }
    ll calc(int L,int R)//统计本质不同公共子串
    {
        int p=1,pos=0,cnt=0;ll res=0;
        for(int i=1;i<=m;++i)
        {
            while(p&&(!ch[p][t[i]-'a']||(pos=query(rt[ch[p][t[i]-'a']],1,n,L,R))==0))
                p=fa[p],cnt=len[p];
            if(!p)p=1,cnt=0;
            else
            {
                p=ch[p][t[i]-'a'];++cnt;
                int length=min(cnt,pos-L+1);
                if(length>fir[i])res+=length-fir[i];
                if(cnt>pos-L+1)
                {
                    int length3=max(length,fir[i]),q=fa[p],cnt2=len[q];
                    while(q)
                    {
                        int pos2=query(rt[q],1,n,L,R),length2=min(cnt2,pos2-L+1);
                        if(length2>length3)res+=length2-length3;
                        else if(cnt2<=pos2-L+1)break;
                        q=fa[q];cnt2=len[q];length3=max(length2,fir[i]);
                    }
                }
            }
        }return res;
    }
}tr1,tr2;
int main()
{
    scanf("%s",s+1);n=strlen(s+1);tr1.init();
    for(int i=1;i<=n;++i)
    {
        tr1.insert(s[i]-'a');cur[i]=tr1.lst;
        insert(rt[cur[i]],1,n,i);
    }int q=read();tr1.pre();
    while(q-->0)
    {
        scanf("%s",t+1);m=strlen(t+1);
        int L=read(),R=read();ll res=0;tr2.init();
        for(int i=1;i<=m;++i)
            tr2.insert(t[i]-'a'),fir[i]=tr2.len[tr2.fa[tr2.lst]];
        for(int i=1;i<=m;++i)res+=i-fir[i];
        write(res-tr1.calc(L,R)),pc('\n');
    }return 0;
}

大牛逼题!正解部分不看看题解根本想不到这样去统计答案

posted @ 2022-03-24 11:45  violetctl39  阅读(102)  评论(0编辑  收藏  举报