BZOJ3881 COCI2015 Divljak
题目大意
给定一个有$n$个字符串的集合$S$,一个初始为空的字符串集合$T$,支持两种操作。
1、向$T$中加入一个新的字符串$K$
2、给定$i$,询问$T$中有多少个字符串包含了$S_i$这个子串。
题解
先考虑暴力怎么做,先对$S$建立$AC$自动机,每插入一个字符串$K$,在$AC$自动机上跑一边,对匹配到的每一个点在和它们$fail$树上的祖先(即到根的路径上的点)中的单词结尾的节点的答案都有$+1$的贡献。
现在想办法优化这个暴力,不难发现可以使用在$fail$树上重链剖分$+$线段树维护树链的并,但这样代码本身的复杂度和时间复杂度似乎都太高了,我们可以用更优美的方法维护树链的并。
考虑将遍历到的点按照$fail$树上的$dfs$序排序,利用容斥原理可得树链的并恰好是排好序后每个点到根的路径被正向计算一次和每两个相邻的点的$lca$被反向计算一次。
可以用差分维护树上到根的路径,若想让$x$到根的路径上每个点权值$+m$则只需要领$x$点权值$+m$,然后求子树和即可。
这个预处理每个点的$dfs$序,用树状数组维护$dfs$序列区间和即可。
复杂度大约是$O(|T|\log|T|+|S|)$,具体实现还需要重点只加入一次之类的卡常,这里不再赘述。
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #include<cmath> #define LL long long #define M 100010 using namespace std; namespace IO{ const int BS=(1<<20)+5; int Top=0; char OT[BS],*OS=OT,*HD,*TL,SS[20]; const char *fin=OT+BS-1; void flush(){fwrite(OT,1,OS-OT,stdout);} void Putchar(char c){*OS++ =c;if(OS==fin)flush(),OS=OT;} void write(int x){ if(!x){Putchar('0');return;} if(x<0) x=-x,Putchar('-'); while(x) SS[++Top]=x%10,x/=10; while(Top) Putchar(SS[Top]+'0'),--Top; } int read(){ int nm=0; char cw=getchar();for(;!isdigit(cw);cw=getchar()); for(;isdigit(cw);cw=getchar()) nm=nm*10+(cw-'0'); return nm; } } using namespace IO; int n,m,pos[M],t[M*20][27],rt,cnt,tmp,sz[M*20],c[M*20],mxs[M*20],dep[M*20],to[M*20]; int q[M*20],tg[M*20],vis[M*20],hd,tl,dfn[M*20],tp[M*20],fa[M*20],fs[M*20],nt[M*20],tot; char s[M*20]; void add(int pos,int x){for(int k=pos;k<=cnt;k+=(k&-k)) c[k]+=x;} int qry(int pos){int tt=0;for(int k=pos;k;k-=(k&-k)) tt+=c[k];return tt;} int ins(int &x,char *k,int rem){ if(!x) x=++cnt,fs[x]=-1; if(!rem) return x; return ins(t[x][*k-'a'],k+1,rem-1); } void link(int x,int y){nt[tmp]=fs[x],fs[x]=tmp,to[tmp++]=y;} void dfs1(int x){ sz[x]=1; for(int i=fs[x];i!=-1;i=nt[i]){ dep[to[i]]=dep[x]+1; dfs1(to[i]),sz[x]+=sz[to[i]]; if(sz[mxs[x]]<sz[to[i]]) mxs[x]=to[i]; } } void dfs2(int x,int dtp){ tp[x]=dtp,dfn[x]=++cnt; if(mxs[x]) dfs2(mxs[x],dtp); else return; for(int i=fs[x];i!=-1;i=nt[i]) if(to[i]!=mxs[x]) dfs2(to[i],to[i]); } int lca(int x,int y){ while(tp[x]!=tp[y]){ if(dep[tp[x]]<dep[tp[y]]) y=fa[tp[y]]; else x=fa[tp[x]]; } return dep[x]>dep[y]?y:x; } void tk(int x){if(vis[x]<tot) vis[x]=tot,q[++m]=x;} void check(int x,char *k,int rem){tk(x);if(rem)check(t[x][*k-'a'],k+1,rem-1);} bool cmp(int x,int y){return dfn[x]<dfn[y];} int main(){ n=read(),fa[rt]=rt,dep[rt]=1; for(int i=1;i<=n;i++) scanf("%s",s),pos[i]=ins(rt,s,strlen(s)); for(int i=0;i<26;i++) if(!t[rt][i]) t[rt][i]=rt;else fa[t[rt][i]]=rt,q[tl++]=t[rt][i]; for(cnt=0;hd<tl;){ int x=q[hd++]; link(fa[x],x); for(int k=0;k<26;k++) if(!t[x][k])t[x][k]=t[fa[x]][k];else fa[t[x][k]]=t[fa[x]][k],q[tl++]=t[x][k]; } dfs1(rt),dfs2(rt,rt),m=0,tot=1; for(int tpe,v,T=read();T;T--,tot++){ tpe=read(); if(tpe==2) v=pos[read()],write(qry(dfn[v]+sz[v]-1)-qry(dfn[v]-1)),Putchar('\n'); else{ scanf("%s",s),m=0,check(rt,s,strlen(s)); sort(q+1,q+m+1,cmp); if(m) add(dfn[q[1]],1); for(int i=2;i<=m;i++) add(dfn[lca(q[i-1],q[i])],-1),add(dfn[q[i]],1); } } flush(); return 0; }
当然由于正解复杂度卡得很近,也给暴力留下了很多的机会,比如只是在暴力的基础,预处理每个节点$fail$树上的最近的单词结尾节点的祖先,在$AC$自动机上匹配的时候,每匹配到一个节点暴力向上跳到所有的单词结尾节点,遇到这一轮之前到达的就不继续跳了,据说这个复杂度最多是带个$|T|\sqrt{n}$,但凭借着优秀的常数和玄学的数据跑得比大多数正解要快。
#include<algorithm> #include<iostream> #include<cstring> #include<cstdio> #define M 100010 using namespace std; namespace IO{ const int BS=(1<<21)+5; int Top=0; char Buffer[BS],OT[BS],*OS=OT,*HD,*TL,SS[20]; const char *fin=OT+BS-1; char Getchar(){if(HD==TL){TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);} return (HD==TL)?EOF:*HD++;} void flush(){fwrite(OT,1,OS-OT,stdout);} void Putchar(char c){*OS++ =c;if(OS==fin)flush(),OS=OT;} void write(int x){ if(!x){Putchar('0');return;} if(x<0) x=-x,Putchar('-'); while(x) SS[++Top]=x%10,x/=10; while(Top) Putchar(SS[Top]+'0'),--Top; } int read(){ int nm=0; char cw=Getchar(); for(;!isdigit(cw);cw=Getchar()); for(;isdigit(cw);cw=Getchar()) nm=nm*10+(cw-'0'); return nm; } int reads(char *k){ char cc=Getchar();int Tot=0; while(!islower(cc)) cc=Getchar(); for(;islower(cc);cc=Getchar()) k[Tot++]=cc; return Tot; } }using namespace IO; int n,m,pos[M],t[M*20][27],tot,rt,cnt,tmp,q[M*20],hd,tl; int ans[M*20],fa[M*20],last[M*20],vis[M*20]; char s[M*20]; bool vs[M*20]; int ins(int x,int rem){ for(int i=0;i<rem;x=t[x][s[i++]-'a']){ if(!t[x][s[i]-'a']) t[x][s[i]-'a']=++cnt; } return x; } void check(int x,int rem){ for(int i=0;i<=rem;x=t[x][s[i++]-'a']){ for(int now=x;now&&vis[now]<tot;now=last[now]) ++ans[now],vis[now]=tot; } } int main(){ n=read(),fa[rt=cnt=1]=rt; for(int i=1;i<=n;i++) vs[pos[i]=ins(rt,reads(s))]=true; for(int i=0;i<26;i++) if(!t[rt][i]) t[rt][i]=rt;else fa[t[rt][i]]=rt,q[tl++]=t[rt][i]; for(cnt=0;hd<tl;){ int x=q[hd++]; last[x]=vs[fa[x]]?fa[x]:last[fa[x]]; for(int k=0;k<26;k++){ if(!t[x][k])t[x][k]=t[fa[x]][k]; else fa[t[x][k]]=t[fa[x]][k],q[tl++]=t[x][k]; } } tot=1; for(int tpe,v,T=read();T;T--,tot++){ tpe=read(); if(tpe&1) check(rt,reads(s)); else v=pos[read()],write(ans[v]),Putchar('\n'); } flush(); return 0; }