Luogu P2414 [NOI2011]阿狸的打字机

Link
问题等价于问\(y\)点的字符串有多少个\(pre\)\(x\)的字符串作为\(suf\),也就是问Trie 树上根节点到\(y\)路径上的点中有多少点在\(fail\)树上\(x\)的子树里,树状数组维护即可。

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+7;
int read(){int x;scanf("%d",&x);return x;}
int lowbit(int x){return x&-x;}
struct node{int num,opt;};
struct data{int ver,num,Next;}edge[N];
queue<node>q;
vector<int>G[N];
int cnt=1,word,fail[N],t[N][27],flag[N],dis[N],fa[N],tot,head[N],ans[N],dfn[N],size[N],bit[N<<2],vis[N],Time;
char s[N];
int insert(int p,char c)
{
    if(t[p][c-'a'+1]) return t[p][c-'a'+1];
    return fa[++cnt]=p,t[p][c-'a'+1]=cnt;
}
void end(int p){flag[p]=++word,dis[word]=p;}
void build()
{
    int i,p;node x;
    for(i=1;i<=26;++i) if(t[1][i]) fail[t[1][i]]=1,q.push(node{t[1][i],i});
    while(!q.empty())
    {
	x=q.front(),q.pop();
	if(fail[x.num]^1)
	{
	    for(p=fail[fa[x.num]];;p=fail[p])
	    {
		if(t[p][x.opt]){p=t[p][x.opt];break;}
		if(p==1)break;
	    }
	    fail[x.num]=p;
	}
	for(i=1;i<=26;++i) if(t[x.num][i]) q.push(node{t[x.num][i],i});
    }
}
void update(int p,int x){for(;p<=cnt;p+=lowbit(p))bit[p]+=x;}
int query(int x){int sum=0;for(;x;x-=lowbit(x))sum+=bit[x];return sum;}
void dfs1(int u)
{
    dfn[u]=++Time,size[u]=1,vis[u]=1;
    for(int i=G[u].size()-1,v;~i;--i) if(!vis[v=G[u][i]]) dfs1(v),size[u]+=size[v];
}
void dfs2(int u)
{
    update(dfn[u],1);data v;
    if(flag[u]) for(int i=head[flag[u]],x;i;i=edge[i].Next) v=edge[i],x=dis[v.ver],ans[v.num]=query(dfn[x]+size[x]-1)-query(dfn[x]-1);
    for(int i=1;i<=26;++i) if(t[u][i]) dfs2(t[u][i]);
    update(dfn[u],-1);
}
int main()
{
    int n,l,i,p,u,v;
    scanf("%s",s+1),l=strlen(s+1);
    for(i=1;i<=l;++i) if(s[i]^'B'&&s[i]^'P') break;
    p=insert(1,s[i]);
    for(++i;i<=l;++i) if(s[i]=='B') p=fa[p]; else if(s[i]=='P') end(p); else p=insert(p,s[i]);
    n=read();
    for(i=1;i<=n;++i)
    {
	v=read(),u=read();
	edge[++tot]=data{v,i,head[u]},head[u]=tot;
    }
    build();
    for(i=2;i<=cnt;++i) G[i].push_back(fail[i]),G[fail[i]].push_back(i);
    dfs1(1);
    dfs2(1);
    for(i=1;i<=n;++i) printf("%d\n",ans[i]);
    return 0;
}
posted @ 2020-01-21 16:15  Shiina_Mashiro  阅读(127)  评论(0编辑  收藏  举报