CF710F-String Set Queries【AC自动机,二进制分组】
正题
题目链接:https://www.luogu.com.cn/problem/CF710F
题目大意
\(T\)次操作
- 往集合中加入一个字符串
- 往集合中删除一个字符串
- 给出一个模式串求出现的集合里面的字符串个数
解题思路
删除的话改成加入一个权值为\(-1\)的字符串就是全都是加入操作了。
然后就可以像[SDOI2014]向量集一样的做法了,维护一个线段树,然后第\(i\)次加入修改第\(i\)个节点,然后回朔的时候,如果一个区间\([l,r]\)加入了\(r-l+1\)个字符串(加满了)的话就直接把下面的\(AC\)自动机合并(或者直接把\(l\sim r\)重新暴力加入到一个\(AC\)自动机),然后匹配的时候把目前合并了的顶层\(AC\)自动机分别求和就好了。
时间复杂度\(O(n\log n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<queue>
#define ll long long
using namespace std;
const int N=3e5+10;
int T,n,cnt,val[N],rt[N<<2],siz[N<<2],last[N<<2];
vector<int> v[N];
char s[N];
struct ACM{
int ch[N<<3][26],fail[N<<3],cnt;
ll w[N<<3];queue<int> q;
void Insert(int &now,vector<int>&s,int val){
if(!now)now=++cnt;int x=now;
for(int i=0;i<s.size();i++){
int c=s[i];
if(!ch[x][c])ch[x][c]=++cnt;
x=ch[x][c];
}
w[x]+=val;
return;
}
void GetFail(int rt){
for(int i=0;i<26;i++)
if(ch[rt][i])q.push(ch[rt][i]);
else ch[rt][i]=rt;
while(!q.empty()){
int x=q.front();q.pop();
if(!fail[x])fail[x]=rt;
w[x]+=w[fail[x]];
for(int i=0;i<26;i++){
if(!ch[x][i])ch[x][i]=ch[fail[x]][i];
else{
fail[ch[x][i]]=ch[fail[x]][i];
q.push(ch[x][i]);
}
}
}
return;
}
ll Find(int x,int n,char *s){
ll ans=0;
for(int i=0;i<n;i++){
x=ch[x][s[i]-'a'];
ans+=w[x];
}
return ans;
}
void Clear(int d){
while(cnt>d){
memset(ch[cnt],0,sizeof(ch[cnt]));
fail[cnt]=w[cnt]=0;cnt--;
}
return;
}
}A;
void Change(int x,int L,int R,int pos){
if(!siz[x])last[x]=A.cnt;
if(L==R){siz[x]++;A.Insert(rt[x],v[L],val[L]);A.GetFail(rt[x]);return;}
int mid=(L+R)>>1;
if(pos<=mid)Change(x*2,L,mid,pos);
else Change(x*2+1,mid+1,R,pos);
siz[x]=siz[x*2]+siz[x*2+1];
if(siz[x]==R-L+1){
A.Clear(last[x]);
for(int i=L;i<=R;i++)
A.Insert(rt[x],v[i],val[i]);
A.GetFail(rt[x]);
}
return;
}
ll Ask(int x,int L,int R,int l,int r){
if(L==l&&R==r&&siz[x]==R-L+1)return A.Find(rt[x],n,s);
int mid=(L+R)>>1;
if(r<=mid)return Ask(x*2,L,mid,l,r);
if(l>mid)return Ask(x*2+1,mid+1,R,l,r);
return Ask(x*2,L,mid,l,mid)+Ask(x*2+1,mid+1,R,mid+1,r);
}
int main()
{
scanf("%d",&T);
for(int p=1;p<=T;p++){
int op;
scanf("%d%s",&op,s);
// op=p%3+1;s[0]='a';
n=strlen(s);
if(op<=2){
v[++cnt].resize(n);val[cnt]=(op==1)?1:-1;
for(int i=0;i<n;i++)v[cnt][i]=s[i]-'a';
Change(1,1,T,cnt);
}
else{
if(!cnt){puts("0");fflush(stdout);continue;}
printf("%lld\n",Ask(1,1,T,1,cnt));
fflush(stdout);
}
}
return 0;
}