[BZOJ4987] Tree

题目

从前有棵树。

找出K个点A1,A2,…,Ak
使得∑dis(Ai,Ai+1),(1<=i<=K-1)最小。

题解

先考虑几个显而易见的性质:

1.选出的点一定是相邻的(不然距离会更大)

2.对于选出的点,如果从ak再走回a1,那么就相当于每条边经过了两次

由于题目没有包含dis(ak,a1),因此就相当于选出的点中的一条链可以只经过一次,其余的需要经过两次。

(如果包含的话就是个很水的树形背包)

就是说,在选点的同时还要选出只计算一次的那条链。

涉及到路径的一般考虑路径拼接的dp方式

即,设一个状态用来记录当前已经决定了多少个路径的端点,即0,1,2三种取值

又因为要取k个,需要树形背包

结合起来就是:

设$f[i][j][k]$表示以i为根的子树中选择点i,共选出j条边,且包含的链端点数目为k的最小代价。

接下来讨论转移方程

对于每一个节点id,枚举每一个儿子t,在dfs完后合并

cost[i]表示id到t的距离

0的只能从0的来

$dp[id][j+k][0]=min(dp[id][j][0]+dp[t][k][0]+cost[i]*2)$

1的可以分类讨论端点是在已经枚举的子树内还是新加的子树内

$dp[id][j+k][1]=min(dp[id][j][1]+dp[t][k][0]+cost[i]*2,dp[id][j][0]+dp[t][k][1]+cost[i])$

2的有(0,2),(1,1),(2,0)三种情况

$dp[id][j+k][2]=min($

$dp[id][j][2]+dp[t][k][0]+cost[i]*2,$

$dp[id][j][1]+dp[t][k][1]+cost[i],$

$dp[id][j][0]+dp[t][k][2]+cost[i]*2)$

 

注意要倒序枚举,因为每个物品只能选一个

时间复杂度

咋一看像是$n^3$的,其实是$n^2$

对于节点u,设其每个儿子子树大小为$a_k$,总大小为A

复杂度就是$a_1+a_2*a_1+a_3*(a_1+a_2)+a_4*(a_1+a_2+a_3)$

这样不好分析,我们将上式×2

$a_1+a1*(a_2+a_3+a_4)+a_2*(a_1+a_3+a_4)+a_3*(a_1+a_2+a_4)+a_4*(a_1+a_2+a_3)$

$a_1+\sum_{i=1}^{k} a_i*A-a_i^2$

 

忽略第一项,剩下的可以跟父亲抵消掉

最后只剩下root的A

也就是$n^2$

代码

#include<iostream> 
#include<cstdio>
#include<cstring>
using namespace std;
#define N 10000
int head[N],cost[N],to[N],nxt[N],cnt,dp[3010][3010][3],n,K,ind,sz[N];
void connect(int a,int b,int c)
{
	to[++cnt]=b,cost[cnt]=c,nxt[cnt]=head[a],head[a]=cnt;
	to[++cnt]=a,cost[cnt]=c,nxt[cnt]=head[b],head[b]=cnt;
}
void dfs(int id,int fa)
{
	sz[id]=1;
	dp[id][1][1]=dp[id][1][0]=0;
	//cout<<id<<" "<<fa<<endl;
	for(int i=head[id];i;i=nxt[i])
	{
		int t=to[i];
		if(t==fa)  continue;
		dfs(t,id);
		for(int j=sz[id];j>=0;j--)
		{
			for(int k=1;k<=sz[t];k++)
			{
				dp[id][j+k][0]=min(dp[id][j+k][0],dp[id][j][0]+dp[t][k][0]+cost[i]*2);
				dp[id][j+k][1]=min(dp[id][j+k][1],min(dp[id][j][1]+dp[t][k][0]+cost[i]*2,dp[id][j][0]+dp[t][k][1]+cost[i]));
				int minn=min(min(dp[id][j][2]+dp[t][k][0]+cost[i]*2,dp[id][j][1]+dp[t][k][1]+cost[i]),dp[id][j][0]+dp[t][k][2]+cost[i]*2);
				dp[id][j+k][2]=min(dp[id][j+k][2],minn);
			}
		}
		sz[id]+=sz[t];
	}
	//cout<<id<<":\n";
//	for(int i=1;i<=sz[id];i++) printf("%d %d %d\n",dp[id][i][0],dp[id][i][1],dp[id][i][2]);
}
int main()
{
	cin>>n>>K;
	for(int i=1;i<n;i++)
	{
		int a,b,c;
		scanf("%d%d%d",&a,&b,&c);
		connect(a,b,c);
	}
	memset(dp,0x3f,sizeof(dp));
	dfs(1,0);
	int ans=0x7fffffff;
	for(int i=1;i<=n;i++) ans=min(ans,dp[i][K][2]);
	cout<<ans;
}

  

 

posted @ 2020-09-16 12:55  linzhuohang  阅读(178)  评论(0编辑  收藏  举报