Description
给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m
的模式串s,其中每一位仍然是A到z的大写字母。Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径
形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分.
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,
重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以X
YXYXY不能看作是S重复若干次得到的。
Input
每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(
第i个字符对应了第i个结点).
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,
为模式串S。
1<=C<=10,3<=N<=10000003<=M<=1000000
Output
给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模
式串的若干次重复.
点分治统计答案,hash判定当前点到中心的路径是否是由模式串重复得到的串的前/后缀,为保证复杂度,递归至子树大小不足m时结束,时间复杂度O(Tnlog(n/m))
#include<bits/stdc++.h> int _(){ int x=0,c=getchar(); while(c<48)c=getchar(); while(c>47)x=x*10+c-48,c=getchar(); return x; } typedef unsigned long long u64; const int N=1e6+7,p=2939; int T,n,m; char s1[N],s2[N]; int es[N*2],enx[N*2],e0[N],ep=2,sz[N],SZ,CG; u64 h1[N],h2[N],ans; bool ed[N]; void f2(int w,int pa){ sz[w]=1; for(int i=e0[w],u;i;i=enx[i]){ u=es[i]; if(u==pa||ed[u])continue; f2(u,w); sz[w]+=sz[u]; } } int md,dw; void f3(int w,int pa){ if(++dw>md)md=dw; bool is=sz[w]*2>=SZ; for(int i=e0[w],u;i;i=enx[i]){ u=es[i]; if(u==pa||ed[u])continue; f3(u,w); if(sz[u]*2>SZ)is=0; } if(is)CG=w; --dw; } int mod[N]; int t1[N],t2[N],a1[N],a2[N]; void f4(int w,int pa,u64 h,int dep){ h=h*p+s1[w]; if(h==h1[dep])++a1[dep[mod]]; if(h==h2[dep])++a2[dep[mod]]; for(int i=e0[w],u;i;i=enx[i]){ u=es[i]; if(u==pa||ed[u])continue; f4(u,w,h,dep+1); } } void f1(int w){ f2(w,0); SZ=sz[w]; md=dw=0; f3(w,0); w=CG; ed[w]=1; if(md*2<m||SZ<m)return; memset(t1,0,sizeof(int)*(m+1)); memset(t2,0,sizeof(int)*(m+1)); for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(ed[u])continue; memset(a1,0,sizeof(int)*(m+1)); memset(a2,0,sizeof(int)*(m+1)); f4(u,w,s1[w],2); a1[m]=a1[0]; a2[m]=a2[0]; a1[m+1]=a1[1]; a2[m+1]=a2[1]; ans+=a1[0]+a2[0]; for(int x=0;x<m;++x){ ans+=u64(a1[m+1-x])*t2[x]; ans+=u64(a2[m+1-x])*t1[x]; } for(int x=0;x<m;++x)t1[x]+=a1[x],t2[x]+=a2[x]; } for(int i=e0[w];i;i=enx[i]){ int u=es[i]; if(!ed[u])f1(u); } } int main(){ for(T=_();T;--T){ ans=0; n=_();m=_(); memset(e0,0,sizeof(int)*(n+1)); memset(ed,0,n+1); scanf("%s",s1+1); for(int i=1,a,b;i<n;++i){ a=_();b=_(); es[ep]=b;enx[ep]=e0[a];e0[a]=ep++; es[ep]=a;enx[ep]=e0[b];e0[b]=ep++; } scanf("%s",s2+1); u64 pp=1; for(int i=1,j=1,k=m;i<=n;++i,pp*=p){ mod[i]=i%m; h1[i]=s2[j]*pp+h1[i-1]; h2[i]=s2[k]*pp+h2[i-1]; if(m<++j)j=1; if(!--k)k=m; } if(m>1)f1(1); else for(int i=1;i<=n;++i)if(s1[i]==s2[1])++ans; printf("%llu\n",ans); } return 0; }