浅谈二进制分组
不知道为啥叫这个名字
感觉就是一个类似2048的东西
把操作分组
假设原来的分组每块大小为
4 2
考虑加入一个操作,变成
4 2 1
再加一个操作
4 2 1 1
合并
4 2 2
合并
4 4
合并
8
可见,如果有n个操作,最多会有log组,每个操作最多被合并log次
给道题试看看
CF710F
考虑每个组建一个AC自动机,然后查询就是把每个组的AC自动机都跑一遍
加入和删除分别建两颗AC自动机,用加入的减去删除的就是答案
注意细节实现
code:
#include<bits/stdc++.h>
#define N 600050
using namespace std;
struct ACT {
int ch[N][27], chh[N][27], size[N], root[N], cnt = 0, tot = 0, fail[N], s[N], ss[N];
void iinsert(int p, char *st) {
int len = strlen(st + 1);// printf(" %d\n", len);
for(int i = 1; i <= len; i ++) {
int c = st[i] - 'a';
if(!chh[p][c]) chh[p][c] = ++ tot;
p = chh[p][c];
} s[p] ++;
}
void build(int p) {
queue<int> q;
for(int i = 0; i < 26; i ++) {
ch[p][i] = chh[p][i];
if(ch[p][i]) q.push(ch[p][i]), fail[ch[p][i]] = p;
else ch[p][i] = p;
}
ss[p] = 0;
while(q.size()) {
int x = q.front(); q.pop();
// printf(" * %d * ", x);
ss[x] = ss[fail[x]] + s[x];
for(int i = 0; i < 26; i ++) { ch[x][i] = chh[x][i];
if(ch[x][i]) {
fail[ch[x][i]] = ch[fail[x]][i];
q.push(ch[x][i]);
} else ch[x][i] = ch[fail[x]][i];
}
}
// for(int i = 1; i <= tot; i ++) printf(" %d ", s[i]); printf("\n");
// for(int i = 1; i <= tot; i ++) printf(" %d ", ss[i]); printf("\n");
// for(int i = 1; i <= tot; i ++) printf(" %d ", fail[i]); printf("\n");
}
int merge(int x, int y) {
if(!x || !y) return x | y; //printf(" %d %d\n", x, y);
s[x] += s[y];
for(int i = 0; i < 26; i ++) chh[x][i] = merge(chh[x][i], chh[y][i]);
return x;
}
void insert(char *st) {
size[++ cnt] = 1; root[cnt] = ++ tot;
iinsert(root[cnt], st);
while(size[cnt] == size[cnt - 1]) {
root[cnt - 1] = merge(root[cnt - 1], root[cnt]);
size[-- cnt] <<= 1;
} build(root[cnt]);
}
int query(int p, char *st) {
int len = strlen(st + 1), ret = 0;
for(int i = 1; i <= len; i ++) {
p = ch[p][st[i] - 'a'];// printf("*%d*", p);
ret += ss[p];
}
return ret;
}
int calc(char *st) { //printf(" %d\n", cnt); cout << st << endl;
int ret = 0;
for(int i = 1; i <= cnt; i ++) ret += query(root[i], st);
return ret;
}
} s1, s2;
int t;
char st[N];
int main() {
scanf("%d", &t);
while(t --) {
int o;
scanf("%d %s", &o, st + 1);
if(o == 1) s1.insert(st);
if(o == 2) s2.insert(st);
if(o == 3) printf("%d\n", s1.calc(st) - s2.calc(st)), fflush(stdout);
}
return 0;
}