【CF710F】String Set Queries
题目
题目链接:https://codeforces.com/problemset/problem/710/F
维护一个字符串集合,支持三种操作:
- 加字符串
- 删字符串
- 查询集合中的所有字符串在给出的模板串中出现的次数
操作数 \(m \leq 3 \times 10^5\),输入字符串总长度 \(\sum |s_i| \leq 3\times 10^5\)。强制在线。
思路
假设没有删除操作,我们可以考虑维护若干个 AC 自动机,询问的时候分别查询答案求和,然后定期重构。
当插入一个字符串时,先单独开一个 AC 自动机存放。维护一个栈表示 AC 自动机标号。
然后不断比较现在这个 AC 自动机字符串数量和栈顶 AC 自动机字符串数量,如果相等就将两个 AC 自动机合并。
这样的话,每个时刻最多有 \(\lceil \log n\rceil\) 个 AC 自动机,并且每一个字符串最多被合并 \(\log n\) 次。所以时间复杂度 \(O(n\log n)\)。
注意每次合并之后要将涉及的两个 AC 自动机的 fail 信息完全删除,然后再重新构建 fail 树。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=300010,LG=20;
int Q,n,opt,top[2],rt[2][LG],siz[2][LG];
char s[N];
struct ACA
{
int tot,ch[N*3][26][2],fail[N*3],cnt[N*3],vis[N*3];
ll sum[N*3];
void insert(int p,char *s)
{
int len=strlen(s+1);
for (int i=1;i<=len;i++)
{
if (!ch[p][s[i]-'a'][0]) ch[p][s[i]-'a'][0]=++tot;
p=ch[p][s[i]-'a'][0];
}
cnt[p]++;
}
int merge(int x,int y)
{
if (!x || !y) return x|y;
cnt[x]+=cnt[y];
for (int i=0;i<26;i++)
ch[x][i][0]=merge(ch[x][i][0],ch[y][i][0]);
return x;
}
void build(int p)
{
queue<int> q;
for (int i=0;i<26;i++)
if (ch[p][i][0])
{
q.push(ch[p][i][0]);
ch[p][i][1]=ch[p][i][0];
fail[ch[p][i][0]]=p;
}
else ch[p][i][1]=p;
while (q.size())
{
int u=q.front(); q.pop();
for (int i=0;i<26;i++)
if (ch[u][i][0])
{
fail[ch[u][i][0]]=ch[fail[u]][i][1];
ch[u][i][1]=ch[u][i][0];
q.push(ch[u][i][0]);
}
else ch[u][i][1]=ch[fail[u]][i][1];
sum[u]=cnt[u]+sum[fail[u]];
}
}
ll query(int p,char *s)
{
int len=strlen(s+1);
ll ans=0;
for (int i=1;i<=len;i++)
{
p=ch[p][s[i]-'a'][1];
ans+=sum[p];
}
return ans;
}
}AC;
int main()
{
scanf("%d",&Q);
while (Q--)
{
scanf("%d%s",&opt,s+1);
if (opt<=2)
{
opt--; top[opt]++;
siz[opt][top[opt]]=1; rt[opt][top[opt]]=++AC.tot;
AC.insert(rt[opt][top[opt]],s);
for (;siz[opt][top[opt]]==siz[opt][top[opt]-1];top[opt]--)
{
AC.merge(rt[opt][top[opt]-1],rt[opt][top[opt]]);
siz[opt][top[opt]-1]*=2;
}
AC.build(rt[opt][top[opt]]);
}
else
{
ll ans=0;
for (int i=1;i<=top[0];i++)
ans+=AC.query(rt[0][i],s);
for (int i=1;i<=top[1];i++)
ans-=AC.query(rt[1][i],s);
printf("%lld\n",ans);
fflush(stdout);
}
}
return 0;
}