[HAOI2015][bzoj 4033]树上染色(树dp+复杂度分析)
【题目描述】
有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。问收益最大值是多少。
【输入格式】
第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。输入保证所有点之间是联通的。
【输出格式】
输出一个正整数,表示收益的最大值。
【输入样例1】
3 1
1 2 1
1 3 2
【输出样例1】
3
【输入样例2】
5 2
1 2 3
1 5 1
2 3 1
2 4 2
【输出样例2】
17
【样例解释】
在第二个样例中,将点1,2染黑就能获得最大收益。
【数据范围】
对于30%的数据,N<=20
对于50%的数据,N<=100
对于100%的数据,N<=2000,0<=K<=N。
题解
很显然是个树dp,然后问题来了,怎么设计状态,根据经验我们可以设计出$dp[i][j]$表示以$i$为根的子树中染了$j$个黑点的贡献,但是这里所说的贡献是对整体还是只考虑局部呢,答案是对整体的贡献,因为很显然这个题目中满足局部最优并不一定就满足整体最优,这样状态定义就结束了,我们在来考虑转移,我们设当前枚举的根节点为$x$,我们再来枚举他的子节点$y$,发现无法直接通过子节点转移到父节点,我们再考虑,是什么连接子节点和父节点,是边,那我们就可以通过边的贡献来转移,那么每条边的贡献就是边权×子树中黑点数×子树外黑点数+边权×子树中白点数×子树外白点数。这样就可以愉快的转移了
$dp[x][j+m]=\max{(dp[y][m]+dp[x][j]+val)}$
其中j枚举x子树大小,m枚举y子树大小,val就是边的贡献。
然后就是要倒着枚举(其实正着枚举也行,就是麻烦),因为如果单纯的正着枚举,会导致用更新过的来更新,而不是正常的转移,即你要用之前的子树大小和当前子树来更新父节点,但如果不做任何处理的正着枚举就会导致用已经更新到size的值来更新,就会不对。
还有就是转移要放到dfs循环里边,其实和上面一样,就是你只是用前面的size,而不是总的size,具体看代码里的注释吧
还有要注意的一点就是因为这题你把所有黑点和白点互换不会对答案产生影响,所以$k=\min{(k,n-k)}$。
一遍dfs即可。
这样做的复杂度时$O(n^2)$的,简单点说就是因为每个点最多只会和其他点乘一次。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include<cstdio> 2 using namespace std; 3 const int N=2005,M=4005,L=1<<15|1; 4 //#define int long long 5 #define rg register int 6 int first[N],nex[M],to[M],edge[M],v[N],size[N],tot,n,k;long long dp[N][N]; 7 char buf[L],*S,*T; 8 #define getchar() ((S==T&&(T=(S=buf)+fread(buf,1,L,stdin),S==T))?EOF:*S++) 9 inline int read(){ 10 rg ss=0;register char bb=getchar(); 11 while(bb<48||bb>57)bb=getchar(); 12 while(bb>=48&&bb<=57)ss=(ss<<1)+(ss<<3)+(bb^48),bb=getchar(); 13 return ss; 14 } 15 inline void add(rg a,rg b,rg c){ 16 to[++tot]=b,edge[tot]=c,nex[tot]=first[a],first[a]=tot; 17 } 18 inline int max(rg a,rg b){return a>b?a:b;} 19 inline int min(rg a,rg b){return a<b?a:b;} 20 inline long long Max(long long a,long long b){return a<b?b:a;} 21 void dfs(int x){ 22 int y; 23 size[x]=1;v[x]=1; 24 for(int i=first[x];i;i=nex[i]){ 25 if(v[y=to[i]]) continue; 26 dfs(y); 27 rg z=edge[i]; 28 rg qq=min(size[x],k),pp=min(size[y],k); 29 //倒序枚举 30 for(rg j=qq;j>=0;--j)//这块要放到里面,原因见blog 31 for(rg l=pp;l>=0;--l){ 32 long long del=1ll*z*1ll*l*1ll*(k-l)+1ll*z*1ll*(size[y]-l)*1ll*(n-k-(size[y]-l)); 33 dp[x][j+l]=Max(dp[x][j+l],dp[y][l]+dp[x][j]+del); 34 } 35 size[x]+=size[y]; 36 } 37 } 38 signed main(){ 39 n=read(),k=read(); 40 k=min(k,n-k); 41 for(rg i=1;i<n;++i){ 42 rg x=read(),y=read(),z=read(); 43 add(x,y,z); 44 add(y,x,z); 45 } 46 dfs(1); 47 printf("%lld",dp[1][k]); 48 }