UOJ#395. 【NOI2018】你的名字 字符串,SAM,线段树合并
原文链接https://www.cnblogs.com/zhouzhendong/p/UOJ395.html
题解
记得同步赛的时候这题我爆0了,最暴力的暴力都没调出来。
首先我们看看 68 分怎么做
——求两个串的本质不同的公共子串个数。
它是一个模板题,然而我当时并不会,甚至连SAM都忘了怎么写QAQ。
再简化一下:如何求一个串的本质不同的子串个数。
给串建一个SAM,把所有节点代表的字符串个数(也就是 Max(x) - Max(fa(x)) 加起来就好了。
回到上一个问题。
假设这两个串分别是 S,T 。对 T 建个SAM。
对于T的SAM,考虑对于它的任何一个节点 x ,算出 x 的 Right 集合代表的所有前缀与 S 的所有前缀的 LCS 的最大值(也就是这个节点代表的状态能在 S 上匹配的最长长度),设为 val(x)。然后对于所有 x 把 $(1,val(x)] \cap (Max(fa(x)),Max(x)]$ 的长度加起来就好了。
那么如何求那个最长的匹配长度?对 S 建一个 SAM,然后用 T 在 S 的 SAM 上走一遍,找到 T 的每一个前缀的 最长的是 S 的子串的后缀 然后 T 的 SAM 上的一个节点的 val 就是他在 parent 树上的所有后代节点的 Max 。
由于 S 的 SAM 可以预先建好,所以询问一个 T 串的复杂度是 $O(|T|)$ 的。
那么 S 有 [L,R] 的限制呢?
线段树合并预处理一下 S 的 SAM 的每一个节点的 Right 集合。
修改一下求最长的匹配长度的过程,保证走转移边的时候在 [L,R] 中有匹配。
注意这里有一个易错点:我们匹配失败跳 father 的时候,不能直接 len' = Max(father) ,只能不断减一。原因是在 len 不断减一的过程中可能会找到匹配,而直接跳 father 会漏过这个匹配。然而出题人数据出的很水,没注意到这个东西还是有96分!
至此,我们得到了一个 $O((|S|+\sum |T|)\log |S|)$ 的做法。
但是,由于在 SAM 上遍历节点暴力跳祖先的复杂度是 $O(n\sqrt n)$ 的,然后加个线段树合并多个 $\log$ ,总复杂度 $O(n\sqrt n \log n)$ 的可以通过原题数据……wft??(UOJ Hack数据过不去的)
代码
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) using namespace std; typedef long long LL; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=500005*4; int n,m,q; char s[N]; struct Node{ int Next[26],fa,Max,pos; }; namespace Seg{ const int S=N*35; int ls[S],rs[S],cnt=0; void Ins(int &rt,int L,int R,int x){ if (!rt) rt=++cnt; if (L==R) return; int mid=(L+R)>>1; if (x<=mid) Ins(ls[rt],L,mid,x); else Ins(rs[rt],mid+1,R,x); } int Merge(int a,int b,int L,int R){ if (!a||!b) return a+b; int rt=++cnt; if (L<R){ int mid=(L+R)>>1; ls[rt]=Merge(ls[a],ls[b],L,mid); rs[rt]=Merge(rs[a],rs[b],mid+1,R); } return rt; } int Query(int rt,int L,int R,int xL,int xR){ if (!rt||R<xL||L>xR||xL>xR) return 0; if (xL<=L&&R<=xR) return 1; int mid=(L+R)>>1; return Query(ls[rt],L,mid,xL,xR) |Query(rs[rt],mid+1,R,xL,xR); } } namespace SAM{ Node t[N]; int root,last,size; int rt[N],id[N]; void Init(){ while (size){ clr(t[size].Next); t[size].fa=t[size].Max=t[size].pos=rt[size]=0; size--; } root=last=size=1; } void extend(int c,int ps){ int p=last,np=++size,q,nq; t[np].Max=t[p].Max+1,t[np].pos=ps; Seg::Ins(rt[np],1,n,ps); for (;p&&!t[p].Next[c];p=t[p].fa) t[p].Next[c]=np; if (!p) t[np].fa=1; else { q=t[p].Next[c]; if (t[p].Max+1==t[q].Max) t[np].fa=q; else { nq=++size; t[nq]=t[q],t[nq].Max=t[p].Max+1,t[nq].pos=ps; t[np].fa=t[q].fa=nq; for (;p&&t[p].Next[c]==q;p=t[p].fa) t[p].Next[c]=nq; } } last=np; } void Sort(){ static int tax[N]; for (int i=0;i<=size;i++) tax[i]=0; for (int i=1;i<=size;i++) tax[t[i].Max]++; for (int i=1;i<=size;i++) tax[i]+=tax[i-1]; for (int i=1;i<=size;i++) id[tax[t[i].Max]--]=i; } void build(){ Sort(); for (int i=size;i>1;i--){ int x=id[i],f=t[x].fa; rt[f]=Seg::Merge(rt[f],rt[x],1,n); } } } namespace sam{ Node t[N]; int root,last,size; int id[N],val[N]; void Init(){ while (size){ clr(t[size].Next); t[size].fa=t[size].Max=t[size].pos=val[size]=0; size--; } root=last=size=1; } void extend(int c,int ps){ int p=last,np=++size,q,nq; t[np].Max=t[p].Max+1,t[np].pos=ps; for (;p&&!t[p].Next[c];p=t[p].fa) t[p].Next[c]=np; if (!p) t[np].fa=1; else { q=t[p].Next[c]; if (t[p].Max+1==t[q].Max) t[np].fa=q; else { nq=++size; t[nq]=t[q],t[nq].Max=t[p].Max+1,t[nq].pos=ps; t[np].fa=t[q].fa=nq; for (;p&&t[p].Next[c]==q;p=t[p].fa) t[p].Next[c]=nq; } } last=np; } void Sort(){ static int tax[N]; for (int i=0;i<=size;i++) tax[i]=0; for (int i=1;i<=size;i++) tax[t[i].Max]++; for (int i=1;i<=size;i++) tax[i]+=tax[i-1]; for (int i=1;i<=size;i++) id[tax[t[i].Max]--]=i; } LL solve(){ Sort(); LL ans=0; for (int i=size;i>1;i--){ int x=id[i],f=t[x].fa; val[f]=max(val[x],val[f]); ans+=max(0,t[x].Max-max(t[f].Max,val[x])); } return ans; } } int main(){ scanf("%s",s+1); n=strlen(s+1),q=read(); SAM::Init(); for (int i=1;i<=n;i++) SAM::extend(s[i]-'a',i); SAM::build(); SAM::t[0].Max=-1; while (q--){ scanf("%s",s+1); m=strlen(s+1); sam::Init(); int L=read(),R=read(); int x=1,len=0; for (int i=1;i<=m;i++){ int c=s[i]-'a',nowx=sam::size+1; sam::extend(c,i); while (x){ int nx=SAM::t[x].Next[c]; if (nx&&Seg::Query(SAM::rt[nx],1,n,L+len,R)) break; if ((--len)==SAM::t[SAM::t[x].fa].Max) x=SAM::t[x].fa; } if (!x) x=1,len=0; else { x=SAM::t[x].Next[c]; sam::val[nowx]=++len; } } printf("%lld\n",sam::solve()); } return 0; }