洛谷P4770 [NOI2018]你的名字 后缀自动机+线段树合并
洛谷P4770 [NOI2018]你的名字
题意
给定一个字符串\(S\),有\(Q\)次询问,每次询问给定一个区间\([l,r]\)和一个字符串\(T\),问\(T\)中有多少本质不同的子串且不是\(S[l;r]\)的子串。
\(|S|\le 5\cdot 10^5,Q\le 10^5,\sum|T| \le 10^6\)。
分析
对\(S\)建后缀自动机,用线段树合并维护\(right\)集合,对\(T\)中的每个前缀\([1;i]\)在\(S[l;r]\)上匹配,设最长长度为\(mx[i]\),对\(T\)建后缀自动机再按拓扑序更新一遍\(mx[fa[i]]=max(mx[fa[i]],mx[i])\),答案就是\(\sum len[i]-max(mx[i],len[fa[i]])\)。怎么在\(s[l;r]\)上匹配呢,假设在\([1;i-1]\)这个前缀上匹配的长度为\(L\),在\(S\)的\(sam\)上的点为\(u\),若\(u\)存在到\(T[i]\)的转移边转移到点\(x\),且\(x\)的\(right\)集合在区间\([l+L,r]\)中有一个元素,说明可以继续匹配下去,更新长度,否则\(L--\),继续尝试转移,如果\(L==len[fa[u]]\),\(u=fa[u]\)。
Code
#include<bits/stdc++.h>
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
#define sz(a) int(a.size())
#define rson mid+1,r,rs[p]
#define pii pair<int,int>
#define lson l,mid,ls[p]
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=2e6+10;
const int M=5e5+10;
const int inf=1e9;
int n,m,q;
char s[N],t[N];
int sum[N],id[N],rt[N];
vector<int>g[N];
struct SegmentTree{
int tr[M*40];
int ls[M*40],rs[M*40],tot;
void up(int x,int l,int r,int &p){
if(!p) p=++tot;
if(l==r){
tr[p]=l;
return;
}
int mid=l+r>>1;
if(x<=mid) up(x,lson);
else up(x,rson);
tr[p]=max(tr[ls[p]],tr[rs[p]]);
}
int merge(int x,int y,int l,int r){
if(!x||!y) return x+y;
int p=++tot,mid=l+r>>1;
if(l==r){
tr[p]=l;
}else{
ls[p]=merge(ls[x],ls[y],l,mid);
rs[p]=merge(rs[x],rs[y],mid+1,r);
tr[p]=max(tr[ls[p]],tr[rs[p]]);
}
return p;
}
int qy(int dl,int dr,int l,int r,int p){
if(!p||dl>dr) return 0;
if(l==dl&&r==dr) return tr[p];
int mid=l+r>>1;
if(dr<=mid) return qy(dl,dr,lson);
else if(dl>mid) return qy(dl,dr,rson);
else return max(qy(dl,mid,lson),qy(mid+1,dr,rson));
}
}seg;
struct SAM{
int last,cnt;int ch[N][27],fa[N],len[N],mx[N];
int newnode(){
++cnt;
mx[cnt]=0;
memset(ch[cnt],0,sizeof ch[cnt]);
return cnt;
}
void insert(int c){
int p=last,np=newnode();last=np;len[np]=len[p]+1;
for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
if(!p) fa[np]=1;
else {
int q=ch[p][c];
if(len[q]==len[p]+1) fa[np]=q;
else {
int nq=newnode();len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof ch[q]);
fa[nq]=fa[q],fa[q]=fa[np]=nq;
for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
}
}
}
void init(){
last=cnt=1;
mx[cnt]=0;
memset(ch[cnt],0,sizeof ch[cnt]);
}
void dfs(int u){
for(int x:g[u]){
dfs(x);
rt[u]=seg.merge(rt[u],rt[x],1,n);
}
}
ll gao(){
for(int i=1;i<=cnt;i++) sum[i]=0;
for(int i=1;i<=cnt;i++) sum[len[i]]++;
for(int i=1;i<=cnt;i++) sum[i]+=sum[i-1];
for(int i=1;i<=cnt;i++) id[sum[len[i]]--]=i;
for(int i=cnt;i>=1;i--) mx[fa[id[i]]]=max(mx[fa[id[i]]],mx[id[i]]);
ll ans=0;
for(int i=2;i<=cnt;i++) ans+=max(0,len[i]-max(mx[i],len[fa[i]]));
return ans;
}
void build(){
for(int i=2;i<=cnt;i++) g[fa[i]].pb(i);
dfs(1);
}
}S,T;
void solve(int cas){
int l,r;
scanf("%s%d%d",t+1,&l,&r);
m=strlen(t+1);
T.init();
int u=1,L=0;
for(int i=1;i<=m;i++){
int c=t[i]-'a';
while(u!=1&&(!(S.ch[u][c]&&seg.qy(l+L,r,1,n,rt[S.ch[u][c]])))){
L--;
if(L==S.len[S.fa[u]]) u=S.fa[u];
}
if(S.ch[u][c]&&seg.qy(l+L,r,1,n,rt[S.ch[u][c]])) u=S.ch[u][c],L++;
T.insert(c);
T.mx[T.last]=L;
}
printf("%lld\n",T.gao());
}
int main(){
scanf("%s",s+1);
n=strlen(s+1);
S.init();
for(int i=1;i<=n;i++){
S.insert(s[i]-'a');
seg.up(i,1,n,rt[S.last]);
}
S.build();
scanf("%d",&q);
for(int i=1;i<=q;i++){
solve(i);
}
return 0;
}