[SDOI2016] 模式字符串 (BZOJ4598 & VIJOS1995)

首先直接点分+hash就可以做,每个点用hash判断是否为S重复若干次后的前缀或后缀,每个子树与之前的结果O(m)暴力合并。在子树大小<m时停止分治,则总复杂度为O(nlog(n/m))。

问题在于n<=1e6。据说有O(n)的DP做法?写点分的话需要一大波常数优化……据说SDOI现场写了这题的全卡常T了……注意BZOJ并没有大数据,如果常数够小的话可以去VIJOS提交。

#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long u64;
typedef u64 ll;
const int N=1e6+5;
struct edge{int v;edge*s;}e[N*2];
edge*l1=e,*h[N];
void ins(int u,int v){
	edge s={v,h[u]};
	*(h[u]=l1++)=s;
}
int q,n,m,s1,s2,s3,d=1,siz[N],f1[N],f2[N],f3[N],f4[N];
bool vis[N];
char t1[N],t2[N];
ll c1[N],c2[N];
u64 ans;
void dfs1(int u,int v){
	int s4=0;
	s3=max(s3,d++);
	siz[u]=1;
	for(edge*i=h[u];i;i=i->s)
		if(i->v!=v&&!vis[i->v]){
			dfs1(i->v,u);
			siz[u]+=siz[i->v];
			s4=max(s4,siz[i->v]);
		}
	if(max(s4,s1-siz[u])*2<=s1)
		s2=u;
	--d;
}
void dfs3(int u,int v,ll w){
	++d,w=w*223+t1[u];
	if(w==c1[d])++f1[d%m];
	if(w==c2[d])++f2[d%m];
	for(edge*i=h[u];i;i=i->s)
		if(i->v!=v&&!vis[i->v])
			dfs3(i->v,u,w);
	--d;
}
void dfs2(int u){
	s3=0,dfs1(u,0),u=s2;
	if(s3*2-1<m)return;
	for(int j=0;j<m;++j)
		f3[j]=f4[j]=0;
	f3[1]=f4[1]=1;
	for(edge*i=h[u];i;i=i->s)
		if(!vis[i->v]){
			for(int j=0;j<m;++j)
				f1[j]=f2[j]=0;
			dfs3(i->v,u,t1[u]);
			for(int j=0;j<m;++j)
				ans+=(u64)f1[j]*f4[(m-j+1)%m]+(u64)f2[j]*f3[(m-j+1)%m];
			for(int j=0;j<m;++j)
				f3[j]+=f1[j],f4[j]+=f2[j];
		}
	vis[u]=1;
	int s5=s1;
	for(edge*i=h[u];i;i=i->s)
		if(!vis[i->v]){
			s1=siz[i->v]<siz[u]?siz[i->v]:s5-siz[u];
			if(s1>=m)dfs2(i->v);
		}
}
struct buf{
	char z[1<<25],*s;
	buf():s(z){
		z[fread(z,1,sizeof z,stdin)]=0;
	}
	void pre(char*v){
		while(*s<48)++s;
		while(*s>32)*v++=*s++;
		*v=0;
	}
	operator int(){
		int x=0;
		while(*s<48)++s;
		while(*s>32)
			x=x*10+*s++-48;
		return x;
	}
}it;
int main(){
	q=it;
	while(q--){
		n=it,m=it,it.pre(t1+1);
		for(int i=1;i<=n;++i)
			h[i]=0,vis[i]=0;
		l1=e;
		for(int i=2;i<=n;++i){
			int u=it,v=it;
			ins(u,v),ins(v,u);
		}
		it.pre(t2+1);
		ll w=1;
		int i=1;
		int j=m;
		for(int k=1;k<=n;++k){
			c1[k]=c1[k-1]+w*t2[i];
			c2[k]=c2[k-1]+w*t2[j];
			w*=223;
			if(++i>m)i=1;
			if(--j<1)j=m;
		}
		ans=0,s1=n,dfs2(1);
		printf("%llu\n",ans);
	}
}
posted @ 2017-03-02 13:15  f321dd  阅读(244)  评论(0编辑  收藏  举报