BZOJ2555 - SubString
Description
给出初始字符串\(s_0(|s_0|\leq6\times10^5)\),进行\(m(m\leq10^4)\)次操作,操作有两种:
- 在当前字符串的后面插入一个字符串\(s\)。
- 询问字符串\(s\)作为连续子串在当前字符串中出现了几次。
其中\(\sum|s|\leq3\times10^6\)。
Solution
其实就是要求我们动态维护后缀自动机的\(Right\)集合大小。
状态\(s\)的\(Right\)集合的大小,等于\(parent\)树上以\(s\)为根的子树中有多少个点在SAM的主链上。我们可以用lct来维护整棵\(parent\)树,当我们在主链上加入一个点\(np\)时,将\(np\)到\(rt\)这条路径上的所有点\(+1\)。
时间复杂度\(O(mlog(\sum|s|))\)。
Code
//SubString
#include <algorithm>
#include <cstdio>
using std::swap;
int const N=2e6;
char s[N]; int Q;
void decode(char s[],int mask)
{
int len=0; while(s[len]) len++;
for(int j=0;j<len;j++) mask=(mask*131+j)%len,swap(s[j],s[mask]);
}
namespace lct
{
int fa[N],ch[N][2],sum[N]; int add[N];
int wh(int p) {return p==ch[fa[p]][1];}
bool notRt(int p) {return p==ch[fa[p]][wh(p)];}
void doAdd(int p,int x) {if(p) sum[p]+=x,add[p]+=x;}
void update(int p) {/*并没有什么需要update的*/ }
void pushdw(int p) {if(add[p]) doAdd(ch[p][0],add[p]),doAdd(ch[p][1],add[p]),add[p]=0;}
void rotate(int p)
{
int q=fa[p],r=fa[q],w=wh(p);
fa[p]=r; if(notRt(q)) ch[r][wh(q)]=p;
fa[ch[q][w]=ch[p][w^1]]=q;
fa[ch[p][w^1]=q]=p;
update(q),update(p);
}
void pushdwRt(int p) {if(notRt(p)) pushdwRt(fa[p]); pushdw(p);}
void splay(int p)
{
pushdwRt(p);
for(int q=fa[p];notRt(p);rotate(p),q=fa[p]) if(notRt(q)) rotate(wh(p)^wh(q)?p:q);
}
void access(int p) {for(int q=0;p;q=p,p=fa[p]) splay(p),ch[p][1]=q,update(p);}
void link(int p,int q) {fa[p]=q;}
void cut(int p) {access(p),splay(p); fa[ch[p][0]]=0,ch[p][0]=0; update(p);}
void add_p2rt(int p) {access(p),splay(p),doAdd(p,1);}
}
int rt,ndCnt,last;
int prt[N],ch[N][2],len[N];
using lct::link; using lct::cut;
void ins(int x)
{
int p=last,np=++ndCnt;
len[np]=len[p]+1,last=np;
for(p;p&&!ch[p][x];p=prt[p]) ch[p][x]=np;
if(!p) {prt[np]=rt,link(np,rt); return;}
int q=ch[p][x];
if(len[q]==len[p]+1) {prt[np]=q,link(np,q); return;}
int nq=++ndCnt; len[nq]=len[p]+1;
for(int i=0;i<2;i++) ch[nq][i]=ch[q][i];
cut(q); lct::sum[nq]=lct::sum[q];
prt[nq]=prt[q],link(nq,prt[q]);
prt[q]=nq,link(q,nq);
prt[np]=nq,link(np,nq);
for(p;ch[p][x]==q;p=prt[p]) ch[p][x]=nq;
}
void insStr(char s[]) {for(int i=1;s[i];i++) ins(s[i]-'A'),lct::add_p2rt(last);}
int query(char s[])
{
int p=rt;
for(int i=1;s[i];i++) if(ch[p][s[i]-'A']) p=ch[p][s[i]-'A']; else return 0;
lct::access(p),lct::splay(p); return lct::sum[p];
}
int main()
{
int mask=0; scanf("%d",&Q);
last=rt=++ndCnt;
scanf("%s",s+1); insStr(s);
while(Q--)
{
char opt[10];
scanf("%s%s",opt,s+1); decode(s+1,mask);
if(opt[0]=='A') insStr(s);
else {int res=query(s); mask^=res,printf("%d\n",res);}
}
return 0;
}
P.S.
断更了两天...这题我前天做的