#trie,树链剖分#洛谷 6088 [JSOI2015]字符串树
分析
显然树上的问题可以转换成根节点到两点的答案减去2倍根节点到LCA的答案
化边为点,考虑子节点承接父节点的trie,再加入一条新的字符串,
在循环的过程中统计一个位置被多少个字符串经过,
这样在查询的时候直接访问某个trie跳到末尾找到答案
代码
#include <cstdio>
#include <cctype>
#define rr register
using namespace std;
const int N=100101; char s[N][11];
struct node{int y,next;}e[N<<1];
int Len[N],trie[N*10][26],sum[N*10],n,k=1,dep[N];
int dfn[N],son[N],top[N],fat[N],as[N],rt[N],Tot,tot,big[N];
inline signed iut(){
rr signed ans=0; rr char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
inline void print(int ans){
if (ans>9) print(ans/10);
putchar(ans%10+48);
}
inline signed Insert(int Rt,int j){
rr int trt=++Tot;
for (rr int i=1;i<=Len[j];++i){
for (rr int p=0;p<26;++p)
trie[Tot][p]=trie[Rt][p];
sum[Tot]=sum[Rt]+1;
trie[Tot][s[j][i]-97]=Tot+1,
Rt=trie[Rt][s[j][i]-97],++Tot;
}
sum[Tot]=sum[Rt]+1;
for (rr int p=0;p<26;++p)
trie[Tot][p]=trie[Rt][p];
return trt;
}
inline signed query(int Rt){
for (rr int j=1;j<=Len[0];++j)
Rt=trie[Rt][s[0][j]-97];
return sum[Rt];
}
inline void dfs1(int x,int fa){
dep[x]=dep[fa]+1,fat[x]=fa,son[x]=1;
for (rr int i=as[x],mson=-1;i;i=e[i].next)
if (e[i].y!=fa){
rt[e[i].y]=Insert(rt[x],i>>1);
dfs1(e[i].y,x),son[x]+=son[e[i].y];
if (son[e[i].y]>mson) big[x]=e[i].y,mson=son[e[i].y];
}
}
inline void dfs2(int x,int linp){
dfn[x]=++tot,top[x]=linp;
if (!big[x]) return; dfs2(big[x],linp);
for (rr int i=as[x];i;i=e[i].next)
if (e[i].y!=fat[x]&&e[i].y!=big[x])
dfs2(e[i].y,e[i].y);
}
inline signed Lca(int x,int y){
while (top[x]!=top[y]){
if (dep[top[x]]<dep[top[y]]) x^=y,y^=x,x^=y;
x=fat[top[x]];
}
if (dep[x]>dep[y]) x^=y,y^=x,x^=y;
return x;
}
signed main(){
n=iut();
for (rr int i=1;i<n;++i){
rr int x=iut(),y=iut();
e[++k]=(node){y,as[x]},as[x]=k;
e[++k]=(node){x,as[y]},as[y]=k;
rr char c=getchar();
while (!isalpha(c)) c=getchar();
while (isalpha(c)) s[i][++Len[i]]=c,c=getchar();
}
dfs1(1,0),dfs2(1,1);
for (rr int Q=iut();Q;--Q,putchar(10)){
rr int x=iut(),y=iut(),lca=Lca(x,y);
rr char c=getchar(); Len[0]=0;
while (!isalpha(c)) c=getchar();
while (isalpha(c)) s[0][++Len[0]]=c,c=getchar();
print(query(rt[x])+query(rt[y])-2*query(rt[lca]));
}
return 0;
}