【BZOJ #4231】—回忆树(Kmp+Ac自动机)

传送门

考虑把经过LcaLca提出来暴力KmpKmp统计
总长度只有O(S)O(|S|)
现在只用考虑一段链
差分后就是到根的一条链

考虑从根dfsdfsdfsdfs的同时在AcAc自动机上走,把每个点点权加一
对于每个询问就是failfail树子树和
走出去的的时候减一

对正反串分别建一个AcAc自动机就可以了

#include<bits/stdc++.h>
using namespace std;
const int RLEN=1<<20|1;
inline char gc(){
    static char ibuf[RLEN],*ib,*ob;
    (ob==ib)&&(ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
    return (ob==ib)?EOF:*ib++;
}
#define gc getchar
inline int read(){
    char ch=gc();
    int res=0,f=1;
    while(!isdigit(ch))f^=ch=='-',ch=gc();
    while(isdigit(ch))res=(res+(res<<2)<<1)+(ch^48),ch=gc();
    return f?res:-res;
}
#define ll long long
#define re register
#define pii pair<int,int>
#define pic pair<int,char>
#define fi first
#define se second
#define pb push_back
#define cs const
#define bg begin
#define poly vector<int>
#define chemx(a,b) ((a)<(b)?(a)=(b):0)
#define chemn(a,b) ((a)>(b)?(a)=(b):0)
cs int N=100005,M=300005;
namespace Kmp{
	int n,m,nxt[M];
	inline int calc(char *a,int _n,char *b){
		n=_n,m=strlen(b+1);
		for(int i=0,j=2;j<=m;j++){
			while(i&&b[i+1]!=b[j])i=nxt[i];
			if(b[i+1]==b[j])i++;
			nxt[j]=i;
		}
		int res=0;
		for(int i=0,j=1;j<=n;j++){
			while(i&&b[i+1]!=a[j])i=nxt[i];
			if(b[i+1]==a[j])i++;
			if(i==m)res++,i=nxt[i];
		}
		return res;
	}
}
struct Ac{
	int nxt[M][26],fail[M],tot,ed[M],val[M];
	inline void insert(char *s,int id){
		int p=0;
		for(int i=1,len=strlen(s+1);i<=len;i++){
			int c=s[i]-'a';
			if(!nxt[p][c])nxt[p][c]=++tot;
			p=nxt[p][c];
		}
		ed[id]=p;
	}
	queue<int> q;
	inline void buildfail(){
		for(int i=0;i<26;i++)
		if(nxt[0][i])q.push(nxt[0][i]);
		while(!q.empty()){
			int p=q.front();q.pop();
			for(int c=0;c<26;c++){
				int v=nxt[p][c];
				if(v)fail[v]=nxt[fail[p]][c],q.push(v);
				else nxt[p][c]=nxt[fail[p]][c];
			}
		}
	}
	vector<int> e[M];
	int in[N],out[N],dfn;
	void dfs(int u){
		in[u]=++dfn;
		for(int &v:e[u]){
			dfs(v);
		}
		out[u]=dfn;
	}
	inline void build(){
		buildfail();
		for(int i=1;i<=tot;i++)e[fail[i]].pb(i);
		dfs(0);
	}
	int tr[N];
	#define lb(x) (x&(-x))
	inline void update(int u,int k){
		for(int p=in[u];p<=dfn;p+=lb(p))tr[p]+=k;
	}
	inline int qry(int p,int res=0){
		for(;p;p-=lb(p))res+=tr[p];return res;
	}
	inline int query(int id){
		int u=ed[id];
		return qry(out[u])-qry(in[u]-1);
	}
}ac[2];
int n,m,fa[N][20],dep[N],ans[N];
vector<pic> e[N];
char val[N];
struct ask{
	int id,coef,kd;
	ask(int _i=0,int _c=0,int _k=0):id(_i),coef(_c),kd(_k){}
};
vector<ask> q[N];
char s[M];
inline int Lca(int u,int v){
	if(dep[u]<dep[v])swap(u,v);
	for(int i=19;~i;i--)if(dep[fa[u][i]]>=dep[v])u=fa[u][i];
	if(u==v)return u;
	for(int i=19;~i;i--)
	if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
inline int find(int u,int k){
	if(k<0)return u;
	for(int i=19;~i;i--)if(k&(1<<i))u=fa[u][i];
	return u;
}
void dfs1(int u){
	for(int i=1;i<=19;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
	for(pic &x:e[u]){
		int v=x.fi;
		if(v==fa[u][0])continue;
		val[v]=x.se,fa[v][0]=u;
		dep[v]=dep[u]+1,dfs1(v);
	}
}
char stk1[N],stk2[N];
int top1,top2;
inline int calc(int u,int v,int lca,char *s){
	top1=top2=0;
	while(u!=lca)stk1[++top1]=val[u],u=fa[u][0];
	while(v!=lca)stk2[++top2]=val[v],v=fa[v][0];
	while(top2)stk1[++top1]=stk2[top2--];
	return Kmp::calc(stk1,top1,s);
}
void dfs2(int u,int node0,int node1){
	ac[0].update(node0,1),ac[1].update(node1,1);
	for(ask &x:q[u]){
		ans[x.id]+=x.coef*ac[x.kd].query(x.id);
	}
	for(pic &x:e[u]){
		int v=x.fi;
		if(v==fa[u][0])continue;
		dfs2(v,ac[0].nxt[node0][val[v]-'a'],ac[1].nxt[node1][val[v]-'a']);
	}
	ac[0].update(node0,-1),ac[1].update(node1,-1);
}
int main(){
	n=read(),m=read();
	for(int i=1;i<n;i++){
		int u=read(),v=read();scanf("%s",s+1);
		e[u].pb(pic(v,s[1])),e[v].pb(pic(u,s[1]));
	}
	dep[1]=1,dfs1(1);
	for(int i=1;i<=m;i++){
		int u=read(),v=read(),lca=Lca(u,v);
		scanf("%s",s+1);int len=strlen(s+1);
		int u1=find(u,dep[u]-dep[lca]-len+1),v1=find(v,dep[v]-dep[lca]-len+1);
		ans[i]+=calc(u1,v1,lca,s);
		if(dep[u]-dep[lca]>=len)q[u].pb(ask(i,1,1)),q[u1].pb(ask(i,-1,1));
		if(dep[v]-dep[lca]>=len)q[v].pb(ask(i,1,0)),q[v1].pb(ask(i,-1,0));
		ac[0].insert(s,i),reverse(s+1,s+len+1),ac[1].insert(s,i);
	}
	ac[0].build(),ac[1].build();
	dfs2(1,0,0);
	for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
}
posted @ 2019-09-25 17:16  Stargazer_cykoi  阅读(116)  评论(0编辑  收藏  举报