【BZOJ4297】[PA2015]Rozstaw szyn 树形DP

【BZOJ4297】[PA2015]Rozstaw szyn

Description

给定一棵有n个点,m个叶子节点的树,其中m个叶子节点分别为1到m号点,每个叶子节点有一个权值r[i]。你需要给剩下n-m个点各指定一个权值,使得树上相邻两个点的权值差的绝对值之和最小。

Input

第一行包含两个正整数n,m(2<=n<=500000,1<=m<=n),分别表示点数和叶子数。
接下来n-1行,每行两个正整数u,v(1<=u,v<=n),表示u与v之间有一条边。
接下来m行,每行一个正整数,依次为r[1],r[2],...,r[m](1<=r[i]<=500000),表示每个叶子的权值。

Output

输出一个整数,即树上相邻两个点的权值差的绝对值之和的最小值。

Sample Input

6 4
1 5
2 5
3 6
4 6
5 6
5
10
20
40

Sample Output

35

题解:思路同BZOJ1304,咱们先来证几个结论:

1.我们从下往上逐层贪心,每次选择一个点的取值范围时,只保证它与它的儿子之间差的绝对值之和最小,而不考虑它的父亲。这样为什么是对的呢?假如x的最优值为v,我们为了使它的父亲更优,将x的取值改为v+d,那么x与x父亲之间的差会减小d,但 x的所有值<=v的儿子 与x之间的差都增加了d。具体地,如果x有a个儿子,那么增加量至少是d。显然是没有一开始优的。

2.以哪个非叶子节点为根进行DP,最后得到的答案都是一样的。假如当前根为x,x的儿子是y。那么如果x的最优取值区间被y包含,相当于x和y之间的差可以为0,那么如果把y当成根,则y的取值区间显然也会被x包含(不要问为什么显然~)。否则我们不考虑x-y这条边,x的取值范围是[l,r],那么在考虑y的贡献后x的取值范围只可能是[...,l]或[r,...],即其他点对x的影响可视为不变,那么只需要最后加上x-y的贡献即可。把y当根也是同理,所以将那个点当成根答案都是一样的。

所以具体做法:随便找一个点当根进行DP,然后用每个点的儿子的最优取值区间来得到当前点的最优取值区间。具体地,我们将x的所有儿子的最优取值区间的左右端点放到一起排序,然后取中间的那段即可。

 

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn=500010;
typedef long long ll;
int n,m,cnt;
ll ans;
int to[maxn<<1],next[maxn<<1],head[maxn],l[maxn],r[maxn],p[maxn<<1];
inline void add(int a,int b)
{
	to[cnt]=b,next[cnt]=head[a],head[a]=cnt++;
}
void dfs(int x,int fa)
{
	if(x<=m)	return ;
	int i,tot=0;
	for(i=head[x];i!=-1;i=next[i])	if(to[i]!=fa)	dfs(to[i],x);
	for(i=head[x];i!=-1;i=next[i])	if(to[i]!=fa)	p[++tot]=l[to[i]],p[++tot]=r[to[i]];
	sort(p+1,p+tot+1);
	l[x]=p[tot>>1],r[x]=p[(tot>>1)+1];
	for(i=head[x];i!=-1;i=next[i])	if(to[i]!=fa&&(r[to[i]]<l[x]||l[to[i]]>l[x]))
		ans+=min(abs(l[to[i]]-l[x]),abs(r[to[i]]-l[x]));
}
inline int rd()
{
	int ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')	f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret*f;
}
int main()
{
	//freopen("bz4297.in","r",stdin);
	n=rd(),m=rd();
	int i,j,a,b;
	memset(head,-1,sizeof(head));
	for(i=1;i<n;i++)	a=rd(),b=rd(),add(a,b),add(b,a);
	for(i=1;i<=m;i++)	l[i]=r[i]=rd();
	if(n==m)
	{
		for(i=1;i<=n;i++)	for(j=head[i];j!=-1;j=next[j])	ans+=abs(l[to[j]]-l[i]);
		printf("%lld",ans>>1);
		return 0;
	}
	dfs(n,0);
	printf("%lld",ans);
	return 0;
}

 

posted @ 2017-10-28 17:01  CQzhangyu  阅读(510)  评论(0编辑  收藏  举报