[二分答案][树形dp] HDU 6769 In Search of Gold
2020 Multi-University Training Contest 2 (1007)
题目大意
给定一棵 \(n\) 个点的树 \((2\leq n\leq 20000)\),树的每一条边上有 \(a,b\) 两种权值。对于每一条边,请你合理地选择权值 \(a\) 或 \(b\),使得树的直径最小,要求权值 \(a\) 必须选择 \(k\) 次,权值 \(b\) 必须选择 \(n-k-1\) 次 \((k\leq 20)\)。
题解
考虑二分答案。每次二分到一个树的直径 \(x\),我们要去判定是否存在一种分配,使得树的直径不超过 \(x\)。
设 \(dp[u][j]\) 表示在以 \(u\) 为根的子树内,选择 \(j\) 条边使用 \(a\) 类边权,其余边使用 \(b\) 类边权,在子树内不产生长度大于 \(x\) 的树上路径的条件下,\(u\) 到其子树内最远点的距离的最小值。
那么最终若 \(dp[1][k]\leq x\),则满足要求。
dp的过程其实和树形dp求树的直径比较类似,稍微改一下就行了,详见代码。
Code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string>
#include <cstdio>
#include <vector>
using namespace std;
#define RG register int
#define LL long long
template<typename elemType>
inline void Read(elemType &T){
elemType X=0,w=0; char ch=0;
while(!isdigit(ch)) {w|=ch=='-';ch=getchar();}
while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
T=(w?-X:X);
}
struct Graph{
struct edge{int Next,to;LL a,b;};
edge G[40010];
int head[20010];
int cnt;
Graph():cnt(2){}
void clear(int node_num=0){
cnt=2;
if(node_num==0) memset(head,0,sizeof(head));
else fill(head,head+node_num+5,0);
}
void add_edge(int u,int v,LL a,LL b){
G[cnt].a=a;
G[cnt].b=b;
G[cnt].to=v;
G[cnt].Next=head[u];
head[u]=cnt++;
}
};
Graph G;
int Size[20010];
LL dp[20010][21],temp[100];
int T,N,K;
void DFS(int u,int fa,LL Len){
Size[u]=dp[u][0]=0;
for(int i=G.head[u];i;i=G.G[i].Next){
int v=G.G[i].to;
if(v==fa)continue;
DFS(v,u,Len);
int Num=min(Size[u]+Size[v]+1,K);
for(RG j=0;j<=Num;++j)
temp[j]=Len+1;
for(int j=0;j<=Size[u];++j){
for(int k=0;k<=Size[v] && j+k<=K;++k){
if(dp[u][j]+dp[v][k]+G.G[i].a<=Len)
temp[j+k+1]=min(temp[j+k+1],max(dp[u][j],dp[v][k]+G.G[i].a));
if(dp[u][j]+dp[v][k]+G.G[i].b<=Len)
temp[j+k]=min(temp[j+k],max(dp[u][j],dp[v][k]+G.G[i].b));
}
}
Size[u]=Num;
for(RG j=0;j<=Num;++j)
dp[u][j]=temp[j];
}
}
bool Judge(LL Len){
DFS(1,0,Len);
if(dp[1][K]<=Len) return true;
return false;
}
LL Solve(LL L,LL R){
LL Res=0;
while(L<=R){
LL mid=(L+R)>>1;
if(Judge(mid)){Res=mid;R=mid-1;}
else L=mid+1;
}
return Res;
}
int main(){
Read(T);
while(T--){
Read(N);Read(K);
G.clear(N);
LL Sum=0;
for(RG i=1;i<=N-1;++i){
int u,v;LL a,b;
Read(u);Read(v);
Read(a);Read(b);
G.add_edge(u,v,a,b);
G.add_edge(v,u,a,b);
Sum+=max(a,b);
}
printf("%lld\n",Solve(1,Sum));
}
return 0;
}