C

内存限制:256 MiB时间限制:1000 ms

题目描述

给定一棵有n个节点的无根树,树上的每个点有一个非负整数点权。定义一条路径的价值为路径上的点权和-路径上的点权最大值。 给定参数P,我!=们想知道,有多少不同的树上简单路径,满足它的价值恰好是P的倍数。 注意:单点算作一条路径;u!=v时,(u,v)和(v,u)只算一次。

输入格式

第一行包含两个整数n,p,意义见题面描述。 接下来n-1行,每行两个整数u,v,表示一条树边。 接下来一行n个整数,第i个整数vali表示点i的权值。

输出格式

输出包含一行一个整数,表示答案。

样例

样例输入

5 2
1 2
1 3
2 4
3 5
1 3 3 1 2

样例输出

9

数据范围与提示

满足条件的路径有(1,1),(2,2),(3,3),(4,4),(5,5),(1,4),(2,3),(2,5),(3,5)
评分方式
本题共有25个测试点,每个测试点满分为4分,不设置部分分;
对每个测试点,如果你的程序答案正确,得满分,否则得0分。
对所有测试点,我们有:

性质A:保证树上不存在度数大于2的点。
性质B:树的形态通过以某个点为根,其他点依次“随机父亲”生成。
数据有一定梯度。

 

 

 

 

 

题解

上一次写点分治是在5个月前。。。。学会了就丢掉了

慢慢地该把点分治的东西捡回来了。。。。

 

考场上只想到了序列分治+单调队列,然后就没有看那个P<=10^7的条件,于是写了个map。。。

好好的一个nlogn就变成了nlognlogn。。。

其实用一个桶就可以了

 

先来讲一下序列分治的做法吧。

一个序列,按中点分治,两边分别做一次单调队列,然后就像归并排序一样,看哪边小就加哪边的答案,依次递增的加入点值

统计答案的时候就是计算另外一边的满足条件的数的个数,这个用两个桶就可以了:

56分代码:

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
using namespace std;
inline int gi()
{
	char c;int num=0,flg=1;
	while((c=getchar())<'0'||c>'9')if(c=='-')flg=-1;
	while(c>='0'&&c<='9'){num=num*10+c-48;c=getchar();}
	return num*flg;
}

#define N 100005
#define INF 0x3f3f3f3f
int n,mod;
int a[N],b[N],ct,d[N];
int s1[N],s2[N],top1,top2;
int mp1[10000005],mp2[10000005];
int fir[N],to[2*N],nxt[2*N],cnt;
long long ans;

void adde(int a,int b)
{
	to[++cnt]=b;nxt[cnt]=fir[a];fir[a]=cnt;
	to[++cnt]=a;nxt[cnt]=fir[b];fir[b]=cnt;
	d[a]++;d[b]++;
}

void force(int u,int ff,int sum,int mx)
{
	if(((sum-mx)%mod+mod)%mod==0)
		ans++;
	for(int v,p=fir[u];p;p=nxt[p]){
		v=to[p];
		if(v!=ff)
			force(v,u,(sum+a[v])%mod,max(mx,a[v]));
	}
}


void dfs(int u,int ff)
{
	b[++ct]=a[u];
	for(int p=fir[u];p;p=nxt[p])
		if(to[p]!=ff)
			dfs(to[p],u);
}
void FZ(int l,int r)
{
	if(l==r){ans++;return;}
	int mid=(l+r)>>1;
	FZ(l,mid);FZ(mid+1,r);
	
	top1=top2=0;
	s1[++top1]=INF;s2[++top2]=INF;
	for(int i=l;i<=mid;i++){
		while(top1&&s1[top1]<=a[i])top1--;
		s1[++top1]=a[i];
	}
	for(int i=r;i>mid;i--){
		while(top2&&s2[top2]<=a[i])top2--;
		s2[++top2]=a[i];
	}
	
	int sum1=0,sum2=0,L=mid,R=mid+1,mx;
	while(top1&&top2){
		if(s1[top1]>=s2[top2]){
			mx=s2[top2];top2--;
			for(;R<=r&&a[R]<s2[top2];R++){
				sum2=(sum2+a[R])%mod;
				mp2[sum2]++;
				int pos=((mx-sum2)%mod+mod)%mod;
				ans+=1ll*mp1[pos];
			}
		}
		else{
			mx=s1[top1];top1--;
			for(;L>=l&&a[L]<s1[top1];L--){
				sum1=(sum1+a[L])%mod;
				mp1[sum1]++;
				int pos=((mx-sum1)%mod+mod)%mod;
				ans+=1ll*mp2[pos];
			}
		}
	}
	
	for(int i=l;i<=mid;i++)
		mp1[sum1]--,sum1=((sum1-a[i])%mod+mod)%mod;
	for(int i=r;i>mid;i--)
		mp2[sum2]--,sum2=((sum2-a[i])%mod+mod)%mod;
}
int main()
{
	
	//freopen("1.in","r",stdin);
	freopen("c.in","r",stdin);
	freopen("c.out","w",stdout);
	
	int i,u,v,mxd=0;
	n=gi();mod=gi();
	for(i=1;i<n;i++){u=gi();v=gi();adde(u,v);}
	for(i=1;i<=n;i++){a[i]=gi();mxd=max(mxd,d[i]);}
	if(n<=2000){
		for(i=1;i<=n;i++)
			force(i,0,a[i]%mod,a[i]);
		printf("%lld",(ans-n)/2+n);
	}
	else if(mxd==2){
		for(i=1;i<=n;i++)
			if(d[i]==1){dfs(i,0);break;}
		memcpy(a,b,sizeof(b));
		FZ(1,n);
		printf("%lld",ans);
	}
	else
		printf("%d",n);
}

 

 

 

由于对每个儿子都开一个单调队列来多路归并排序比较麻烦,所以我们就直接把所有点都统计进来,按照mx值排一个序(多一个logn也是可以过的),从小到大加入桶中,边加边统计即可解决一个点的问题,然后再套一个点分治就A了!!!

等等,还有一个问题就是去重

其实就是点分治的经典去重方法,对每一个分治中心的邻接点都做一次同样的dfs,计算它们的答案

100分代码:O(nlognlogn)

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
inline int gi()
{
	char c;int num=0,flg=1;
	while((c=getchar())<'0'||c>'9')if(c=='-')flg=-1;
	while(c>='0'&&c<='9'){num=num*10+c-48;c=getchar();}
	return num*flg;
}
#define N 100005
long long ans;
int a[N],siz[N],mi,rt;
int tmp[N],mx[N],sum[N],con[10000005],tot,mod;
int fir[N],to[2*N],nxt[2*N],cnt;
bool vis[N];
void adde(int a,int b)
{
	to[++cnt]=b;nxt[cnt]=fir[a];fir[a]=cnt;
	to[++cnt]=a;nxt[cnt]=fir[b];fir[b]=cnt;
}
void dfs1(int u,int ff,int all)
{
	int tmx=0;siz[u]=1;
	for(int v,p=fir[u];p;p=nxt[p]){
		if((v=to[p])!=ff&&!vis[v]){
			dfs1(v,u,all);
			tmx=max(tmx,siz[v]);
			siz[u]+=siz[v];
		}
	}
	tmx=max(tmx,all-siz[u]);
	if(tmx<mi){mi=tmx;rt=u;}
}
int getrt(int u,int all){mi=0x3f3f3f3f;dfs1(u,0,all);return rt;}

bool cmp(const int &x,const int &y){return mx[x]<mx[y];}
void dfs2(int u,int ff,int su,int tmx)
{
	su=(su+a[u])%mod;tmx=max(tmx,a[u]);
	tmp[++tot]=tot;mx[tot]=tmx;sum[tot]=su;
	siz[u]=1;
	for(int v,p=fir[u];p;p=nxt[p]){
		if((v=to[p])!=ff&&!vis[v]){
			dfs2(v,u,su,tmx);
			siz[u]+=siz[v];
		}
	}
}
void solve(int u)
{
	tot=0;dfs2(u,0,0,0);sort(tmp+1,tmp+tot+1,cmp);
	for(int i=1;i<=tot;i++){
		ans+=con[(mod-(sum[tmp[i]]-a[u]-mx[tmp[i]])%mod)%mod];
		con[(sum[tmp[i]]%mod+mod)%mod]++;
	}
	for(int i=1;i<=tot;i++)
		con[(sum[tmp[i]]%mod+mod)%mod]=0;
	vis[u]=1;
	for(int v,p=fir[u];p;p=nxt[p]){
		if(!vis[v=to[p]]){
			tot=0;dfs2(v,u,0,a[u]);sort(tmp+1,tmp+tot+1,cmp);
			for(int i=1;i<=tot;i++){
				ans-=con[(mod-(sum[tmp[i]]+a[u]-mx[tmp[i]])%mod)%mod];
				con[(sum[tmp[i]]%mod+mod)%mod]++;
			}
			for(int i=1;i<=tot;i++)
				con[(sum[tmp[i]]%mod+mod)%mod]=0;
		}
	}
	for(int v,p=fir[u];p;p=nxt[p])
		if(!vis[v=to[p]]) solve(getrt(v,siz[v]));
}
int main()
{
	//freopen("c1.in","r",stdin);
	int n,i,u,v;
	n=gi();mod=gi();
	for(i=1;i<n;i++){u=gi();v=gi();adde(u,v);}
	for(i=1;i<=n;i++)a[i]=gi();
	solve(getrt(1,n));
	printf("%lld\n",ans+n);
}

 

 

 

这应该是一个点分治裸题啊啊啊啊啊啊啊啊啊啊啊!!!!!

我太菜了,点分治忘完了。。。