Luogu4770 [NOI2018]你的名字
Luogu4770 [NOI2018]你的名字
\(SAM+LCT\)
\(update2020.11.16:\)更新了一只\(\log\)解法。
好歹自己切了一道字符串黑题,这几天字符串没白颓。
观察原问题,如果模式串不是区间形式的话,很容易想到一个做法,就是对于输入串每一个\(r\)位置,除去\(r\)所在的后缀中不满足条件的后缀,显然不满足条件的后缀一定是一段区间,所以从\(r\)后缀在\(SAM\)的\(parent\)树上的匹配的最深位置\(x\)到\(root\)路径中所有的后缀都需要砍掉。
然而模式串是一段区间,可以发现,不满足条件的串仍然是\(x \rightarrow root\)路径上一个位置\(y\)到\(root\)的路径,因为新的不满足条件的后缀一定是原问题的子集。
那么我们考虑找到那个节点,也就是说,该节点代表的集合中存在串\(s\),使得它在\(r\)之前最后出现的位置的左端点\(L \ge l\)。
参考Luogu6292 区间本质不同子串个数的做法,将\(r\)指针一位一位扫过去,同时用\(LCT\)更新最右的右端点。
同时,在\(parent\)树上,节点的\(l\)端点在\(x \rightarrow root\)上一定单调递增,因为对于最右\(r\)端点,显然祖先一定不小于子孙,对于长度,祖先又比子孙小。
既然具有了单调性,我们就可以考虑倍增了,向\(L<l\)的最浅节点跳。在细节上,注意我们最终跳到的节点可能会存在一部分子串满足题意,这些必须统计。
还有一个问题需要处理,就是对于输入的串,即使除去了不满足题意的串,其自身的串依然会存在重复情况,对此,我们对输入串仍需要建立\(SAM\)。
那么我们在\(parent\)树上打标记,我们已经找出了一个后缀长度在\([1,t]\)范围内是不合法的,这在新的\(SAM\)上仍然对应一个节点到根的路径,再次倍增。
最后一次\(DFS\)统计答案即可。
时间复杂度:\(O(\sum \lvert T_i \rvert \log^2 n)\)。
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector>
#define pr pair<int,int>
#define mp make_pair
#define ll long long
#define IT vector< pr > :: iterator
#define N 500005
#define qN 100005
using namespace std;
int l,r;
int lth[qN],g[N << 1];
ll ans[qN];
char s[N];
string T[qN];
vector< pr >e[N];
int n,q;
struct edge
{
int nxt,v;
edge (int Nxt=0,int V=0)
{
nxt=Nxt,v=V;
}
}E[N << 1];
int tot,fr[N << 1];
void add(int x,int y)
{
++tot;
E[tot]=edge(fr[x],y),fr[x]=tot;
}
struct SAM
{
int lst=1,cnt=1,tr[N << 1][26],pre[N << 1],len[N << 1];
int f[N << 1][22];
void ins(int c)
{
int p=lst,q,np;
lst=np=++cnt;
len[np]=len[p]+1;
for (;p && !tr[p][c];p=pre[p])
tr[p][c]=np;
if (!p)
pre[np]=1; else
{
q=tr[p][c];
if (len[p]+1==len[q])
pre[np]=q; else
{
int g=++cnt;
memcpy(tr[g],tr[q],sizeof(tr[q]));
len[g]=len[p]+1,pre[g]=pre[q];
for (;p && tr[p][c]==q;p=pre[p])
tr[p][c]=g;
pre[np]=pre[q]=g;
}
}
}
void Do_st()
{
for (int i=1;i<=cnt;++i)
f[i][0]=pre[i];
for (int j=1;j<=20;++j)
for (int i=1;i<=cnt;++i)
f[i][j]=f[f[i][j-1]][j-1];
}
void build()
{
for (int i=2;i<=cnt;++i)
add(pre[i],i);
}
void Clear()
{
memset(tr,0,26*(cnt+1)*sizeof(int));
memset(g,0,(cnt+1)*sizeof(int));
memset(fr,0,(cnt+1)*sizeof(int));
tot=0,cnt=lst=1;
}
}S1,S2;
#define ls(x) a[x].ch[0]
#define rs(x) a[x].ch[1]
#define fa(x) a[x].f
#define tag(x) a[x].cltg
#define col(x) a[x].cl
struct LCT
{
int ch[2],f,cltg,cl;
}a[N << 1];
int Q[N << 1];
int id(int x)
{
return ls(fa(x))==x?0:1;
}
bool isrt(int x)
{
return ls(fa(x))!=x && rs(fa(x))!=x;
}
void connect(int x,int F,int son)
{
fa(x)=F;
a[F].ch[son]=x;
}
void rot(int x)
{
int y=fa(x),r=fa(y);
int yson=id(x),rson=id(y);
if (isrt(y))
fa(x)=r; else
connect(x,r,rson);
connect(a[x].ch[yson^1],y,yson);
connect(y,x,yson^1);
}
void push_tag(int x,int z)
{
if (!x)
return;
tag(x)=col(x)=z;
}
void push_down(int x)
{
if (tag(x))
{
push_tag(ls(x),tag(x));
push_tag(rs(x),tag(x));
tag(x)=0;
}
}
void splay(int x)
{
int g=x,k=0;
Q[++k]=x;
while (!isrt(g))
g=fa(g),Q[++k]=g;
while (k)
push_down(Q[k--]);
while (!isrt(x))
{
int y=fa(x);
if (isrt(y))
rot(x); else
if (id(x)==id(y))
rot(y),rot(x); else
rot(x),rot(x);
}
}
void access(int x,int r)
{
int y;
for (y=0;x;y=x,x=fa(x))
{
splay(x);
rs(x)=y;
}
push_tag(y,r);
}
int Col(int x)
{
splay(x);
return col(x);
}
void Dfs(int u,int w)
{
for (int i=fr[u];i;i=E[i].nxt)
{
int v=E[i].v;
Dfs(v,w);
g[u]=max(g[u],g[v]);
}
ans[w]+=S2.len[u]-max(S2.len[S2.pre[u]],min(g[u],S2.len[u]));
}
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
for (int i=1;i<=n;++i)
S1.ins(s[i]-'a');
S1.Do_st();
for (int i=2;i<=S1.cnt;++i)
fa(i)=S1.pre[i];
scanf("%d",&q);
for (int i=1;i<=q;++i)
{
cin >> T[i];
lth[i]=T[i].length();
scanf("%d%d",&l,&r);
e[r].push_back(mp(l,i));
}
int st=1;
for (int i=1;i<=n;++i)
{
st=S1.tr[st][s[i]-'a'];
access(st,i);
for (IT it=e[i].begin();it!=e[i].end();++it)
{
S2.Clear();
int l=it->first,w=it->second,s0=1,st0=1,nlen=0;
for (int j=0;j<lth[w];++j)
S2.ins(T[w][j]-'a');
S2.Do_st();
for (int j=0;j<lth[w];++j)
{
int c=T[w][j]-'a';
st0=S2.tr[st0][c];
if (!S1.tr[1][c])
s0=1,nlen=0; else
{
while (!S1.tr[s0][c])
s0=S1.pre[s0],nlen=S1.len[s0];
s0=S1.tr[s0][c];
++nlen;
int tans;
if (Col(s0)-nlen+1>=l)
tans=nlen; else
{
int F=s0;
for (int j=20;j>=0;--j)
if (S1.f[F][j] && Col(S1.f[F][j])-S1.len[S1.f[F][j]]+1<l)
F=S1.f[F][j];
tans=max(Col(F)-l+1,S1.len[S1.f[F][0]]);
}
int k=st0;
for (int j=20;j>=0;--j)
if (S2.f[k][j] && S2.len[S2.f[k][j]]>=tans)
k=S2.f[k][j];
g[k]=max(g[k],tans);
}
}
S2.build();
Dfs(1,w);
}
}
for (int i=1;i<=q;++i)
printf("%lld\n",ans[i]);
return 0;
}
根据线段树合并解法的启发,发现倍增根本不需要,我们在需要跳祖先的地方跳祖先即可,这样可以省去一只\(\log\)。
时间复杂度:\(O(\sum \lvert T_i \rvert \log n)\)。
\(Code:\)
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<vector>
#define pr pair<int,int>
#define mp make_pair
#define ll long long
#define IT vector< pr > :: iterator
#define N 500005
#define qN 100005
using namespace std;
int l,r;
int lth[qN],g[N];
ll ans[qN];
char s[N];
string T[qN];
vector< pr >e[N];
int n,q;
struct SAM
{
int lst=1,cnt=1,tr[N << 1][26],pre[N << 1],len[N << 1],lc[N << 1];
void ins(int c)
{
int p=lst,q,np;
lst=np=++cnt;
len[np]=len[p]+1,lc[np]=len[p];
for (;p && !tr[p][c];p=pre[p])
tr[p][c]=np;
if (!p)
pre[np]=1; else
{
q=tr[p][c];
if (len[p]+1==len[q])
pre[np]=q; else
{
int g=++cnt;
memcpy(tr[g],tr[q],sizeof(tr[q]));
len[g]=len[p]+1,pre[g]=pre[q],lc[g]=lc[q];
for (;p && tr[p][c]==q;p=pre[p])
tr[p][c]=g;
pre[np]=pre[q]=g;
}
}
}
void Clear()
{
memset(tr,0,26*(cnt+1)*sizeof(int));
cnt=lst=1;
}
}S1,S2;
#define ls(x) a[x].ch[0]
#define rs(x) a[x].ch[1]
#define fa(x) a[x].f
#define tag(x) a[x].cltg
#define col(x) a[x].cl
struct LCT
{
int ch[2],f,cltg,cl;
}a[N << 1];
int Q[N << 1];
int id(int x)
{
return ls(fa(x))==x?0:1;
}
bool isrt(int x)
{
return ls(fa(x))!=x && rs(fa(x))!=x;
}
void connect(int x,int F,int son)
{
fa(x)=F;
a[F].ch[son]=x;
}
void rot(int x)
{
int y=fa(x),r=fa(y);
int yson=id(x),rson=id(y);
if (isrt(y))
fa(x)=r; else
connect(x,r,rson);
connect(a[x].ch[yson^1],y,yson);
connect(y,x,yson^1);
}
void push_tag(int x,int z)
{
if (!x)
return;
tag(x)=col(x)=z;
}
void push_down(int x)
{
if (tag(x))
{
push_tag(ls(x),tag(x));
push_tag(rs(x),tag(x));
tag(x)=0;
}
}
void splay(int x)
{
int g=x,k=0;
Q[++k]=x;
while (!isrt(g))
g=fa(g),Q[++k]=g;
while (k)
push_down(Q[k--]);
while (!isrt(x))
{
int y=fa(x);
if (isrt(y))
rot(x); else
if (id(x)==id(y))
rot(y),rot(x); else
rot(x),rot(x);
}
}
void access(int x,int r)
{
int y;
for (y=0;x;y=x,x=fa(x))
{
splay(x);
rs(x)=y;
}
push_tag(y,r);
}
int Col(int x)
{
splay(x);
return col(x);
}
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
for (int i=1;i<=n;++i)
S1.ins(s[i]-'a');
for (int i=2;i<=S1.cnt;++i)
fa(i)=S1.pre[i];
scanf("%d",&q);
for (int i=1;i<=q;++i)
{
cin >> T[i];
lth[i]=T[i].length();
scanf("%d%d",&l,&r);
e[r].push_back(mp(l,i));
}
int st=1;
for (int i=1;i<=n;++i)
{
st=S1.tr[st][s[i]-'a'];
access(st,i);
for (IT it=e[i].begin();it!=e[i].end();++it)
{
S2.Clear();
int l=it->first,w=it->second,s0=1,st0=1,nlen=0;
for (int j=0;j<lth[w];++j)
S2.ins(T[w][j]-'a');
for (int j=0;j<lth[w];++j)
{
int c=T[w][j]-'a';
st0=S2.tr[st0][c];
if (!S1.tr[1][c] || Col(S1.tr[1][c])<l)
s0=1,nlen=0; else
{
while (!S1.tr[s0][c])
s0=S1.pre[s0],nlen=S1.len[s0];
s0=S1.tr[s0][c];
++nlen;
if (Col(s0)-nlen+1<l)
{
while (Col(s0)-S1.len[S1.pre[s0]]<l)
s0=S1.pre[s0];
nlen=min(S1.len[s0],Col(s0)-l+1);
}
}
g[j]=nlen;
}
for (int j=2;j<=S2.cnt;++j)
ans[w]+=max(S2.len[j]-max(g[S2.lc[j]],S2.len[S2.pre[j]]),0);
}
}
for (int i=1;i<=q;++i)
printf("%lld\n",ans[i]);
return 0;
}