[BZOJ4987] Tree
题目
从前有棵树。
题解
先考虑几个显而易见的性质:
1.选出的点一定是相邻的(不然距离会更大)
2.对于选出的点,如果从ak再走回a1,那么就相当于每条边经过了两次
由于题目没有包含dis(ak,a1),因此就相当于选出的点中的一条链可以只经过一次,其余的需要经过两次。
(如果包含的话就是个很水的树形背包)
就是说,在选点的同时还要选出只计算一次的那条链。
涉及到路径的一般考虑路径拼接的dp方式,
即,设一个状态用来记录当前已经决定了多少个路径的端点,即0,1,2三种取值
又因为要取k个,需要树形背包
结合起来就是:
设$f[i][j][k]$表示以i为根的子树中选择点i,共选出j条边,且包含的链端点数目为k的最小代价。
接下来讨论转移方程
对于每一个节点id,枚举每一个儿子t,在dfs完后合并
cost[i]表示id到t的距离
0的只能从0的来
$dp[id][j+k][0]=min(dp[id][j][0]+dp[t][k][0]+cost[i]*2)$
1的可以分类讨论端点是在已经枚举的子树内还是新加的子树内
$dp[id][j+k][1]=min(dp[id][j][1]+dp[t][k][0]+cost[i]*2,dp[id][j][0]+dp[t][k][1]+cost[i])$
2的有(0,2),(1,1),(2,0)三种情况
$dp[id][j+k][2]=min($
$dp[id][j][2]+dp[t][k][0]+cost[i]*2,$
$dp[id][j][1]+dp[t][k][1]+cost[i],$
$dp[id][j][0]+dp[t][k][2]+cost[i]*2)$
注意要倒序枚举,因为每个物品只能选一个
时间复杂度
咋一看像是$n^3$的,其实是$n^2$
对于节点u,设其每个儿子子树大小为$a_k$,总大小为A
复杂度就是$a_1+a_2*a_1+a_3*(a_1+a_2)+a_4*(a_1+a_2+a_3)$
这样不好分析,我们将上式×2
$a_1+a1*(a_2+a_3+a_4)+a_2*(a_1+a_3+a_4)+a_3*(a_1+a_2+a_4)+a_4*(a_1+a_2+a_3)$
$a_1+\sum_{i=1}^{k} a_i*A-a_i^2$
忽略第一项,剩下的可以跟父亲抵消掉
最后只剩下root的A
也就是$n^2$
代码
#include<iostream> #include<cstdio> #include<cstring> using namespace std; #define N 10000 int head[N],cost[N],to[N],nxt[N],cnt,dp[3010][3010][3],n,K,ind,sz[N]; void connect(int a,int b,int c) { to[++cnt]=b,cost[cnt]=c,nxt[cnt]=head[a],head[a]=cnt; to[++cnt]=a,cost[cnt]=c,nxt[cnt]=head[b],head[b]=cnt; } void dfs(int id,int fa) { sz[id]=1; dp[id][1][1]=dp[id][1][0]=0; //cout<<id<<" "<<fa<<endl; for(int i=head[id];i;i=nxt[i]) { int t=to[i]; if(t==fa) continue; dfs(t,id); for(int j=sz[id];j>=0;j--) { for(int k=1;k<=sz[t];k++) { dp[id][j+k][0]=min(dp[id][j+k][0],dp[id][j][0]+dp[t][k][0]+cost[i]*2); dp[id][j+k][1]=min(dp[id][j+k][1],min(dp[id][j][1]+dp[t][k][0]+cost[i]*2,dp[id][j][0]+dp[t][k][1]+cost[i])); int minn=min(min(dp[id][j][2]+dp[t][k][0]+cost[i]*2,dp[id][j][1]+dp[t][k][1]+cost[i]),dp[id][j][0]+dp[t][k][2]+cost[i]*2); dp[id][j+k][2]=min(dp[id][j+k][2],minn); } } sz[id]+=sz[t]; } //cout<<id<<":\n"; // for(int i=1;i<=sz[id];i++) printf("%d %d %d\n",dp[id][i][0],dp[id][i][1],dp[id][i][2]); } int main() { cin>>n>>K; for(int i=1;i<n;i++) { int a,b,c; scanf("%d%d%d",&a,&b,&c); connect(a,b,c); } memset(dp,0x3f,sizeof(dp)); dfs(1,0); int ans=0x7fffffff; for(int i=1;i<=n;i++) ans=min(ans,dp[i][K][2]); cout<<ans; }