[BZOJ 2434][Noi2011]阿狸的打字机(AC自动机+树状数组+dfs序)
Description
打字机上只有28个按键,分别印有26个小写英文字母和'B'、'P'两个字母。经阿狸研究发现,这个打字机是这样工作的:
·输入小写字母,打字机的一个凹槽中会加入这个字母(这个字母加在凹槽的最后)。
·按一下印有'B'的按键,打字机凹槽中最后一个字母会消失。
·按一下印有'P'的按键,打字机会在纸上打印出凹槽中现有的所有字母并换行,但凹槽中的字母不会消失。
例如,阿狸输入aPaPBbP,纸上被打印的字符如下:
a aa ab 我们把纸上打印出来的字符串从1开始顺序编号,一直到n。打字机有一个非常有趣的功能,在打字机中暗藏一个带数字的小键盘,在小键盘上输入两个数(x,y)(其中1≤x,y≤n),打字机会显示第x个打印的字符串在第y个打印的字符串中出现了多少次。
阿狸发现了这个功能以后很兴奋,他想写个程序完成同样的功能,你能帮助他么?
Solution
居然1A【然而看了题解】
查询x在y中出现了多少次,即y的节点有多少可以通过fail跳到x的尾节点
但是这个太暴力了,于是我们建一下fail树(fail的反向边,由fail指针指向原节点)
转化为查询x的尾节点可以从fail树到达的y的节点有多少个
然后就可以通过dfs序和树状数组做了
离线,对于相同的y一次性处理,标记出所有y的节点,查询每一个x的子树中y的节点的出现次数
#include<iostream> #include<cstdio> #include<cstring> #include<cstdlib> #include<queue> #define MAXN 100005 using namespace std; int m,sz,root,pos[MAXN],in[MAXN],out[MAXN],dfn_clock=0,c[MAXN],ans[MAXN],dfn[MAXN]; int head1[MAXN],head2[MAXN],cnt1=0,cnt2=0; struct Node2 { int next,to; }e1[MAXN],e2[MAXN]; void addedge1(int u,int v) { e1[++cnt1].next=head1[u]; head1[u]=cnt1; e1[cnt1].to=v; } void addedge2(int u,int v) { e2[++cnt2].next=head2[u]; head2[u]=cnt2; e2[cnt2].to=v; } char s[MAXN]; struct Node1 { int next[26],par,fail; }trie[MAXN]; int newnode(int f) { trie[++sz].fail=0; trie[sz].par=f; memset(trie[sz].next,0,sizeof(trie[sz].next)); return sz; } void insert() { sz=0,root=newnode(0); int i=0,p=root,_clock=0; while(s[i]) { if(s[i]=='B') p=trie[p].par; else if(s[i]=='P') pos[++_clock]=p; else { int idx=s[i]-'a'; if(!trie[p].next[idx])trie[p].next[idx]=newnode(p); p=trie[p].next[idx]; } i++; } } queue<int>q; void build() { q.push(root); while(!q.empty()) { int p=q.front();q.pop(); for(int i=0;i<26;i++) { int t=trie[p].fail; while(t&&!trie[t].next[i])t=trie[t].fail; if(trie[p].next[i]) { trie[trie[p].next[i]].fail=t?trie[t].next[i]:root; q.push(trie[p].next[i]); } else trie[p].next[i]=t?trie[t].next[i]:root; } } } void dfs(int u) { if(dfn[u])return; ++dfn_clock; in[u]=dfn[u]=dfn_clock; for(int i=head1[u];~i;i=e1[i].next) dfs(e1[i].to); out[u]=dfn_clock; } int lowbit(int x){return x&(-x);} void add(int x,int p) { while(p<=sz+1) { c[p]+=x; p+=lowbit(p); } } int query(int p) { int res=0; while(p>0) { res+=c[p]; p-=lowbit(p); } return res; } void solve() { int i=0,p=root,_clock=0; while(s[i]) { if(s[i]=='B') add(-1,dfn[p]),p=trie[p].par; else if(s[i]=='P') { ++_clock; for(int j=head2[_clock];~j;j=e2[j].next) { int t=e2[j].to; ans[j]=query(out[pos[t]])-query(in[pos[t]]-1); } } else { int idx=s[i]-'a'; p=trie[p].next[idx]; add(1,dfn[p]); } i++; } } int main() { memset(head1,-1,sizeof(head1)); memset(head2,-1,sizeof(head2)); scanf("%s",s); insert(),build(); for(int i=1;i<=sz;i++) addedge1(trie[i].fail,i); scanf("%d",&m); for(int i=1;i<=m;i++) { int x,y; scanf("%d%d",&x,&y); addedge2(y,x); } dfs(0); solve(); for(int i=1;i<=m;i++) printf("%d\n",ans[i]); return 0; }