luogu P2414 [NOI2011] 阿狸的打字机
题面传送门
AC自动机蛮好的一道题,需要有一定对AC自动机的理解。
首先把所有字符串扔到AC自动机里面去。然后求出fail指针。
对于每个询问暴力匹配肯定是不行的,考虑怎么转化。
可以发现如果把fail指针抽出来,那么一个节点的子树内所有节点都能与这个节点匹配。
那么就可以把询问离线然后dfs搞。
但是问题是我们不知道每个节点对应哪个字符串。
那么开个vector表示所有这个节点属于的区间,然后树状数组差分维护即可。
代码实现:
#include<cstdio>
#include<queue>
#include<cstring>
#include<vector>
#define beg(x) int cur=s.h[x]
#define end cur
#define go cur=tmp.z
using namespace std;
int n,m,k,x,y,z,flag[100039],st[100039],sh,cnt,now,tot[100039],c[100039],id[100039],head=1;
char a[100039];
inline void get(int x,int y){while(x<=head)c[x]+=y,x+=x&-x;}
inline int find(int x){int ans=0;while(x) ans+=c[x],x-=x&-x;return ans;}
struct AC{int son[26],fail;}f[100039];
queue<int> q;
inline void bfs(){
register int i;
for(i=0;i<=25;i++) if(f[0].son[i]) q.push(f[0].son[i]);
while(!q.empty()){
now=q.front();q.pop();
for(i=0;i<=25;i++) {
if(f[now].son[i]) f[f[now].son[i]].fail=f[f[now].fail].son[i],q.push(f[now].son[i]);
else f[now].son[i]=f[f[now].fail].son[i];
}
}
}
struct ques{int to,id;};
vector<ques> g[100039];
vector<int> pl[100039];
struct yyy{int to,z;};
struct ljb{
int head,h[100039];yyy f[200039];
inline void add(int x,int y){f[++head]=(yyy){y,h[x]};h[x]=head;}
}s;
inline void dfs(int x){
yyy tmp;ques tmps;int i,now1,now2;
for(i=0;i<g[x].size();i++){
tmps=g[x][i];
tot[tmps.id]-=find(tmps.to);
}
for(i=0;i<pl[x].size();i+=2)now1=pl[x][i],now2=pl[x][i+1],get(now1,1),get(now2+1,-1);
for(beg(x);end;go)tmp=s.f[cur],dfs(tmp.to);
for(i=0;i<g[x].size();i++){
tmps=g[x][i];
tot[tmps.id]+=find(tmps.to);
}
}
int main(){
freopen("1.in","r",stdin);
register int i;
scanf("%s",a+1);m=strlen(a+1);
for(i=1;i<=m;i++){
if(a[i]=='B') pl[now].push_back(head),now=st[--sh];
else if(a[i]=='P') id[++head]=now;
else now=(f[now].son[a[i]-'a']?f[now].son[a[i]-'a']:(f[now].son[a[i]-'a']=++cnt)),st[++sh]=now,pl[now].push_back(head+1);
}bfs();
for(i=0;i<=cnt;i++) if((pl[i].size())&1) pl[i].push_back(head);
for(i=1;i<=cnt;i++)s.add(f[i].fail,i);
scanf("%d",&n);
for(i=1;i<=n;i++) scanf("%d%d",&x,&y),g[id[x+1]].push_back((ques){y+1,i});
dfs(0);
for(i=1;i<=n;i++) printf("%d\n",tot[i]);
}