P4770 [NOI2018] 你的名字 题解
一句话题意:给定一个字符串 \(S\) ,询问 \(Q\) 次,每次询问字符串 \(T\) 有多少个本质不同子串在字符串 \(S_{l...r}\) 中没有出现。
一步转化:发现求没出现的子串不好求,那就求出现过的子串,最后用 \(T\) 的本质不同子串数减去本质不同公共子串数即为答案。
无限制情况
有区间的限制不好做,先考虑没有区间限制,即 \(l=1,r=|S|\) 的情况。
单串的本质不同子串很好求,不用多说,难点就是在求本质不同公共子串上。
我们记 \(fir[i]\) 表示 \(T_{1...i}\) 的后缀不是第一次出现在 \(i\) 位置的后缀数。(后缀第一次出现时的右端点不是 \(i\) 的后缀数)
显然 \(\sum_{i=1}^{|T|}i-fir[i]\) 就是单串的本质不同子串数。
而求出 \(fir[i]\) 很简单,每次将 \(t[i]\) 插入进 \(SAM\) 中,\(fir[i]=len[fa[lst]]\) ,意思就是 \(t[i]\) 在 \(SAM\) 中的节点的父亲(后缀链接)表示的子串的最大长度。(这里应该好好想想)
考虑求出公共子串数,套路就是把 \(T\) 串放在 \(S\) 串的 \(SAM\) 上跑匹配。因为要求本质不同,所以每次匹配完的贡献是 \(cnt-fir[i]\) 。( \(cnt\) 是匹配的最长子串长度)
于是第一部分的分的拿到手了。
代码:
#include<bits/stdc++.h>
#define pc(x) putchar(x)
#define ll long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){f=ch=='-'?-1:f;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void write(ll x)
{
if(x<0){x=-x;putchar('-');}
if(x>9)write(x/10);
putchar(x%10+48);
}
int n,m;
char s[500005],t[500005];
int fir[500005];
struct SAM
{
int ch[1000005][26],fa[1000005],len[1000005];
int lst,cnt;
void init(){lst=cnt=1;memset(ch[1],0,sizeof ch[1]);fa[1]=len[1]=0;}
int newnode(){++cnt;memset(ch[cnt],0,sizeof ch[cnt]);fa[cnt]=len[cnt]=0;return cnt;}
void copy(int x,int y)
{
for(int i=0;i<26;++i)
ch[x][i]=ch[y][i];
fa[x]=fa[y];len[x]=len[y];
}
void insert(int c)
{
int p=lst,np=lst=newnode();len[np]=len[p]+1;
for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
if(!p){fa[np]=1;return;}int q=ch[p][c];
if(len[q]==len[p]+1){fa[np]=q;return;}
int nq=newnode();copy(nq,q);len[nq]=len[p]+1;
fa[np]=fa[q]=nq;
for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
}
ll solve()//这里就是求解本质不同公共子串数
{
int p=1,cnt=0;ll res=0;
for(int i=1;i<=m;++i)
{
while(p&&!ch[p][t[i]-'a'])p=fa[p],cnt=len[p];
if(!p)p=1,cnt=0;
else{p=ch[p][t[i]-'a'];++cnt;}
if(cnt>fir[i])res+=cnt-fir[i];
}return res;
}
}tr1,tr2;
int main()
{
scanf("%s",s+1);n=strlen(s+1);tr1.init();
for(int i=1;i<=n;++i)tr1.insert(s[i]-'a');
int q=read();
while(q-->0)
{
scanf("%s",t+1);m=strlen(t+1);
int l=read(),r=read();ll res=0;tr2.init();
for(int i=1;i<=m;++i)
tr2.insert(t[i]-'a'),fir[i]=tr2.len[tr2.fa[tr2.lst]];
for(int i=1;i<=m;++i)res+=i-fir[i];//这里是求解单串本质不同子串数
write(res-tr1.solve()),pc('\n');
}return 0;
}
正解
现在多了一个区间的限制条件,很容易想到要处理节点的 \(endpos\) 集合,每次匹配转移时要判断下一个节点的 \(endpos\) 集合中有没有元素在 \(L...R\) 之间,我们为了保证匹配的长度尽可能的大,选择匹配的子串右端点要尽量的大。这一部分可以用线段树合并来做,线段树维护区间 \(endpos\) 最大值,时间复杂度是 \(O(n)\) 的。
这样直接处理答案其实不是对的,因为我们查找 \(endpos\) 的时候,是在能匹配的最长长度的节点上的线段树查询的,而长度越长,\(endpos\) 就越少,也就说明,我们在较长的匹配上找到的在区间中的匹配长度可能会很小,导致剩下的部分比 \(fir[i]\) 小,无法对答案产生贡献。但是如果我们在匹配长度较短(它的父亲节点)的线段树上查找的话,反而能找到一个 \(endpos\) ,在 \(L\) 的限制之内匹配的长度比 \(fir[i]\) 大,或者所有的匹配长度干脆都不会超出 \(L\) 。这时候就需要把这些长度加到答案里。
所以需要跳匹配节点的父亲,更新答案,直到匹配长度比 \(L\) 到 \(endpos\) 最大值的长度还小,这时候就不再有新公共子串产生了。
代码:
#include<bits/stdc++.h>
#define pc(x) putchar(x)
#define ls(x) (seg[x].l)
#define rs(x) (seg[x].r)
#define mx(x) (seg[x].mx)
#define ll long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){f=ch=='-'?-1:f;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
void write(ll x)
{
if(x<0){x=-x;putchar('-');}
if(x>9)write(x/10);
putchar(x%10+48);
}
int n,m;
char s[500005],t[500005];
int fir[500005];
struct seg_val_tree
{
int l,r,mx;
}seg[20000005];
int tot,rt[1000005],cur[1000005];
void insert(int &pos,int l,int r,int k)
{
pos=++tot;if(l==r){mx(pos)=l;return;}
int mid=(l+r)>>1;
if(k<=mid)insert(ls(pos),l,mid,k);
else insert(rs(pos),mid+1,r,k);
mx(pos)=max(mx(ls(pos)),mx(rs(pos)));
}
int merge(int x,int y,int l,int r)
{
if(!x||!y)return x+y;
if(l==r){mx(x)=max(mx(x),mx(y));return x;}
int mid=(l+r)>>1,pos=++tot;
ls(pos)=merge(ls(x),ls(y),l,mid);
rs(pos)=merge(rs(x),rs(y),mid+1,r);
mx(pos)=max(mx(ls(pos)),mx(rs(pos)));
return pos;
}
int query(int pos,int l,int r,int L,int R)
{
if(!pos)return 0;
if(L<=l&&r<=R)return mx(pos);
int mid=(l+r)>>1,res=0;
if(L<=mid)res=max(res,query(ls(pos),l,mid,L,R));
if(R>mid)res=max(res,query(rs(pos),mid+1,r,L,R));
return res;
}
int c[1000005],a[1000005];
struct SAM
{
int ch[1000005][26],fa[1000005],len[1000005];
int lst,cnt;
void init(){lst=cnt=1;memset(ch[1],0,sizeof ch[1]);fa[1]=len[1]=0;}
int newnode(){++cnt;memset(ch[cnt],0,sizeof ch[cnt]);fa[cnt]=len[cnt]=0;return cnt;}
void copy(int x,int y)
{
for(int i=0;i<26;++i)ch[x][i]=ch[y][i];
fa[x]=fa[y];len[x]=len[y];
}
void insert(int c)
{
int p=lst,np=lst=newnode();len[np]=len[p]+1;
for(;p&&!ch[p][c];p=fa[p])ch[p][c]=np;
if(!p){fa[np]=1;return;}int q=ch[p][c];
if(len[q]==len[p]+1){fa[np]=q;return;}
int nq=newnode();copy(nq,q);len[nq]=len[p]+1;
fa[np]=fa[q]=nq;
for(;p&&ch[p][c]==q;p=fa[p])ch[p][c]=nq;
}
void pre()//预处理endpos集合,线段树合并
{
for(int i=1;i<=cnt;++i)c[len[i]]++;
for(int i=1;i<=cnt;++i)c[i]+=c[i-1];
for(int i=cnt;i>1;--i)a[c[len[i]]--]=i;
for(int i=cnt;i>1;--i)
{
int pos=a[i],f=fa[pos];
rt[f]=merge(rt[f],rt[pos],1,n);
}
}
ll calc(int L,int R)//统计本质不同公共子串
{
int p=1,pos=0,cnt=0;ll res=0;
for(int i=1;i<=m;++i)
{
while(p&&(!ch[p][t[i]-'a']||(pos=query(rt[ch[p][t[i]-'a']],1,n,L,R))==0))
p=fa[p],cnt=len[p];
if(!p)p=1,cnt=0;
else
{
p=ch[p][t[i]-'a'];++cnt;
int length=min(cnt,pos-L+1);
if(length>fir[i])res+=length-fir[i];
if(cnt>pos-L+1)
{
int length3=max(length,fir[i]),q=fa[p],cnt2=len[q];
while(q)
{
int pos2=query(rt[q],1,n,L,R),length2=min(cnt2,pos2-L+1);
if(length2>length3)res+=length2-length3;
else if(cnt2<=pos2-L+1)break;
q=fa[q];cnt2=len[q];length3=max(length2,fir[i]);
}
}
}
}return res;
}
}tr1,tr2;
int main()
{
scanf("%s",s+1);n=strlen(s+1);tr1.init();
for(int i=1;i<=n;++i)
{
tr1.insert(s[i]-'a');cur[i]=tr1.lst;
insert(rt[cur[i]],1,n,i);
}int q=read();tr1.pre();
while(q-->0)
{
scanf("%s",t+1);m=strlen(t+1);
int L=read(),R=read();ll res=0;tr2.init();
for(int i=1;i<=m;++i)
tr2.insert(t[i]-'a'),fir[i]=tr2.len[tr2.fa[tr2.lst]];
for(int i=1;i<=m;++i)res+=i-fir[i];
write(res-tr1.calc(L,R)),pc('\n');
}return 0;
}
大牛逼题!正解部分不看看题解根本想不到这样去统计答案