bzoj 4598: [Sdoi2016]模式字符串

题目描述

给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m的模式串s,其中每一位仍然是A到z的大写字母。

Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径形成的字符串可以由模式串S重复若干次得到?

这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分。

所谓模式串的重复,是将若干个模式串S依次相接(不能重叠)。例如当S=PLUS的时候,重复两次会得到PLUSPLUS,重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以XYXYXY不能看作是S重复若干次得到的。

输入输出格式

输入格式:

每一个数据有多组测试,

第一行输入一个整数C,表示总的测试个数。

对于每一组测试来说:

第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,

之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(第i个字符对应了第i个结点)。

之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,为模式串S。

输出格式:

给出C行,对应C组测试。

每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模式串的若干次重复.

输入输出样例

输入样例#1: 复制
1
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI
输出样例#1: 复制
5

说明

1<=C<=10,3<=N<=1000000,3<=M<=1000000

题解

  题解大概看懂一点了……就是说用hash+点分治……好讨厌hash……总感觉还是半懂不懂……

  考虑每一个分治点,从他延伸下去能形成长度为多少的前缀和后缀(不包含自己和包含自己),然后两个两两组合起来计算答案

  据说时间复杂度$O(Tnlogn)$,数据就是为了卡点分的,然而因为全世界都只有前三组数据……所以能A……

  1 //minamoto
  2 #include<cstdio>
  3 #include<iostream>
  4 #include<cstring>
  5 #define N 1000003
  6 #define ull unsigned long long
  7 #define ll long long
  8 #define p 2000001001
  9 #define inf 1000000000
 10 using namespace std;
 11 template<class T>inline bool cmax(T&a,const T&b){return a<b?a=b,1:0;}
 12 inline int read(){
 13     #define num ch-'0'
 14     char ch;bool flag=0;int res;
 15     while(!isdigit(ch=getchar()))
 16     (ch=='-')&&(flag=true);
 17     for(res=num;isdigit(ch=getchar());res=res*10+num);
 18     (flag)&&(res=-res);
 19     #undef num
 20     return res;
 21 }
 22 char sr[1<<21],z[30];int C=-1,Z;
 23 inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}
 24 inline void print(ll x){
 25     if(C>1<<20)Ot();if(x<0)sr[++C]=45,x=-x;
 26     while(z[++Z]=x%10+48,x/=10);
 27     while(sr[++C]=z[Z],--Z);sr[++C]='\n';
 28 }
 29 int n,m,T,n1,rt,head[N],Next[N<<1],ver[N<<1],tot,size[N],son[N],sz,sz1;
 30 int st[N],st1[N],len[N],cnt[N],cnt1[N];
 31 ull mi[N],a[N],a1[N],b[N],c[N],val[N],sum[N],sum1[N];
 32 ll ans;
 33 bool vis[N];char s[N];
 34 inline void add(int u,int v){
 35     ver[++tot]=v,Next[tot]=head[u],head[u]=tot;
 36     ver[++tot]=u,Next[tot]=head[v],head[v]=tot;
 37 }
 38 void findrt(int u,int fa){
 39     size[u]=1,son[u]=0;
 40     for(int i=head[u];i;i=Next[i]){
 41         int v=ver[i];
 42         if(vis[v]||v==fa) continue;
 43         findrt(v,u);
 44         cmax(son[u],size[v]);
 45         size[u]+=size[v];
 46     }
 47     cmax(son[u],n1-size[u]);
 48     if(son[u]<son[rt]) rt=u;
 49 }
 50 void getdep(int u,int fa){
 51     if(b[len[u]]==sum[u]&&val[u]==a[1]) st[++sz]=u;
 52     for(int i=head[u];i;i=Next[i]){
 53         int v=ver[i];if(vis[v]||v==fa) continue;
 54         sum[v]=sum[u]*p+val[v];
 55         len[v]=len[u]+1;
 56         getdep(v,u);
 57     }
 58 }
 59 void getdep1(int u,int fa){
 60     if(c[len[u]]==sum1[u]&&val[u]==a1[1]) st1[++sz1]=u;
 61     for(int i=head[u];i;i=Next[i]){
 62         int v=ver[i];
 63         if(vis[v]||v==fa) continue;
 64         sum1[v]=sum1[u]*p+val[v];
 65         getdep1(v,u);
 66     }
 67 }
 68 void calc(int u){
 69     for(int i=0;i<=m;++i) cnt[i]=cnt1[i]=0;
 70     if(a[1]==val[u]) cnt[1]=1;
 71     if(a[m]==val[u]) cnt1[1]=1;
 72     if(m==1) ans+=cnt1[1];
 73     for(int i=head[u];i;i=Next[i]){
 74         int v=ver[i];
 75         if(vis[v]) continue;
 76         sz=0,len[v]=1,sum[v]=val[v];
 77         getdep(v,u);
 78         for(int j=1;j<=sz;++j){
 79             int t=st[j];int pos=m-(len[t]-1)%m-1;
 80             if(pos==0) pos+=m;
 81             ans+=(ll)cnt1[pos];
 82         }
 83         sz1=0,sum1[v]=val[v];
 84         getdep1(v,u);
 85         for(int j=1;j<=sz1;++j){
 86             int t=st1[j];int pos=m-(len[t]-1)%m-1;
 87             if(pos==0) pos+=m;
 88             ans+=(ll)cnt[pos];
 89         }
 90         for(int j=1;j<=sz;++j){
 91             int t=st[j];int pos=(len[t])%m+1;
 92             if(val[u]==a[pos]) ++cnt[pos];
 93         }
 94         for(int j=1;j<=sz1;++j){
 95             int t=st1[j];int pos=(len[t])%m+1;
 96             if(val[u]==a1[pos]) ++cnt1[pos];
 97         }
 98     }
 99 }
100 void solve(int u){
101     calc(u),vis[u]=1;int totsz=size[u];
102     for(int i=head[u];i;i=Next[i]){
103         int v=ver[i];
104         if(vis[v]) continue;
105         rt=0;
106         n1=size[v];
107         if(n1<m) continue;
108         findrt(v,u);
109         solve(rt);
110     }
111 }
112 int main(){
113     T=read(),mi[0]=1;
114     for(int i=1;i<=1000000;++i) mi[i]=mi[i-1]*p;
115     while(T--){
116         n=read(),m=read(),tot=0,ans=0;
117         memset(head,0,sizeof(head));
118         scanf("%s",s+1);
119         for(int i=1;i<=n;++i) val[i]=s[i]-'A'+1;
120         for(int i=1;i<n;++i){
121             int u=read(),v=read();add(u,v);
122         }
123         scanf("%s",s+1);
124         for(int i=1;i<=max(n,m);++i) a[i]=s[(i-1)%m+1]-'A'+1;
125         for(int i=1;i<=max(n,m);++i) b[i]=b[i-1]+a[i]*mi[i-1];
126         for(int i=1;i<=m;++i) a1[m-i+1]=a[i];
127         for(int i=1;i<=max(n,m);++i) a1[i]=a1[(i-1)%m+1];
128         for(int i=1;i<=max(n,m);++i) c[i]=c[i-1]+a1[i]*mi[i-1];
129         memset(vis,0,sizeof(vis));
130         son[0]=inf,rt=0,n1=n;
131         findrt(1,0);
132         solve(rt);
133         print(ans);
134     }
135     Ot();
136     return 0;
137 }

 

posted @ 2018-08-17 12:54  bztMinamoto  阅读(392)  评论(0编辑  收藏  举报
Live2D