BZOJ4477[Jsoi2015]字符串树——可持久化trie树
题目描述
萌萌买了一颗字符串树的种子,春天种下去以后夏天就能长出一棵很大的字
符串树。字符串树很奇特,树枝上都密密麻麻写满了字符串,看上去很复杂的样
子。
【问题描述】
字符串树本质上还是一棵树,即N个节点N-1条边的连通无向无环图,节点
从1到N编号。与普通的树不同的是,树上的每条边都对应了一个字符串。萌萌
和JYY在树下玩的时候,萌萌决定考一考JYY。每次萌萌都写出一个字符串S和
两个节点U,V,需要JYY立即回答U和V之间的最短路径(即,之间边数最少的
路径。由于给定的是一棵树,这样的路径是唯一的)上有多少个字符串以为前
缀。
JYY虽然精通编程,但对字符串处理却不在行。所以他请你帮他解决萌萌的难题。
输入
输入第一行包含一个整数N,代表字符串树的节点数量。
接下来N-1行,每行先是两个数U,V,然后是一个字符串S,表示节点和U节
点V之间有一条直接相连的边,这条边上的字符串是S。输入数据保证给出的是一
棵合法的树。
接下来一行包含一个整数Q,表示萌萌的问题数。
接来下Q行,每行先是两个数U,V,然后是一个字符串S,表示萌萌的一个问
题是节点U和节点V之间的最短路径上有多少字符串以S为前缀。
输入中所有字符串只包含a-z的小写字母。
1<=N,Q<=100,000,且输入所有字符串长度不超过10。
输出
输出Q行,每行对应萌萌的一个问题的答案。
样例输入
4
1 2 ab
2 4 ac
1 3 bc
3
1 4 a
3 4 b
3 2 ab
1 2 ab
2 4 ac
1 3 bc
3
1 4 a
3 4 b
3 2 ab
样例输出
2
1
1
1
1
在原树上每个点建一个版本的可持久化$trie$树,每个点的可持久化$trie$树继承父节点的$trie$树并将它与父节点边上的字符串加入到这个点的$trie$树中。查询一个串是多少个串的前缀直接在$trie$树上跑到对应点取这个点子树大小即可。由于每个点的$trie$树保存了这个点到根路径上字符串的信息,所以每次询问的答案就是$query(u)+query(v)-2*query(lca(u,v))$。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> using namespace std; struct miku { int son[28],sum; }s[1000010]; int n,m; int x,y; int cnt; int root[100010]; int head[100010]; int next[200010]; char ch[100010][12]; int to[200010]; int tot; int f[100010][19]; int dep[100010]; char rd[12]; int id[200010]; void add(int x,int y,int i) { next[++tot]=head[x]; head[x]=tot; to[tot]=y; id[tot]=i; } int build(int pre,int dep,char *ch,int lim) { int rt=++cnt; s[rt]=s[pre]; s[rt].sum++; if(dep==lim) { return rt; } s[rt].son[ch[dep]-'a']=build(s[pre].son[ch[dep]-'a'],dep+1,ch,lim); return rt; } void dfs(int x) { for(int i=1;i<=17;i++) { f[x][i]=f[f[x][i-1]][i-1]; } for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x][0]) { f[to[i]][0]=x; dep[to[i]]=dep[x]+1; root[to[i]]=build(root[x],0,ch[id[i]],strlen(ch[id[i]])); dfs(to[i]); } } } int lca(int x,int y) { if(dep[x]<dep[y]) { swap(x,y); } int d=dep[x]-dep[y]; for(int i=0;i<=17;i++) { if(d&(1<<i)) { x=f[x][i]; } } if(x==y) { return x; } for(int i=17;i>=0;i--) { if(f[x][i]!=f[y][i]) { x=f[x][i]; y=f[y][i]; } } return f[x][0]; } int query(int rt,int lim) { for(int i=0;i<lim;i++) { rt=s[rt].son[rd[i]-'a']; } return s[rt].sum; } int main() { scanf("%d",&n); for(int i=1;i<n;i++) { scanf("%d%d%s",&x,&y,ch[i]); add(x,y,i); add(y,x,i); } dfs(1); scanf("%d",&m); while(m--) { scanf("%d%d%s",&x,&y,rd); int anc=lca(x,y); int len=strlen(rd); printf("%d\n",query(root[x],len)+query(root[y],len)-2*query(root[anc],len)); } }