4383 [八省联考 2018] 林克卡特树(WQS 二分+DP)
给定一颗 \(n\) 个点的树,每条边有边权 \(v(|v|\le 10^6)\),要求删去其中任意 \(k\) 条边,使得剩余联通块的直径之和最大。求出这个最大值。
\(0\le k<n\le 3\times 10^5,10s,1GB\)。
问题是怎么求直径?!直径不就是最大的链吗?!
发现原问题等价于选择 \(k+1\) 条互不相交的链使得链的总价值最大。
设 \(dp_{i,j,0/1/2}\) 表示到 \(i\) 子树,已经选择了 \(j\) 条链,当前根上的选择情况分别为当前根不和父亲合并、向下连接一条链、向下连接形成两条链的情况时最大收益。
在加入一棵子树的时候合并答案:
在出子树的时候将 \(dp_{x,1},dp_{x,2}\) 都合并到 \(dp_{x,0}\) 表示 \(x\) 不和父亲合并的情况。
那么这样可以 \(\mathcal{O(nk)}\) 转移。
看题解,可以了解到这个状态和 \(k\) 的关系是:随 \(k\) 增大,答案先增大后减小。
马后炮 yy 一下发现其实比较容易感性理解,太小了就没得选大边,太大了就不得不因为链不能重合舍弃大边。
那么在这个以选择链的数量为横坐标,当前最大收益为纵坐标的二维 DP 上,我们要得出以 \(k\) 为横坐标的点的答案。
如果我们直接去掉上面 DP 的 \(j\) 一维,我们能够得到对于所有 \(k\) 的收益最大值,可以通过二分斜率来找到 \(k\) 的值。
二分选择一条链的额外代价,去掉 DP 中“已经选择了几条链”的那一维,直接记录选择链的最大权值之和。同时需要维护一个计数器数组和 DP 一起转移统计已经选择了多少条链。
那么如果选择的链多了就增大选择一条链的额外代价,否则减少,知道刚好等于 \(k\),那么最终权值就是 \(val+k\times e\),其中 \(e\) 是额外代价。
这样最终求出的最大值加上 \(e\times k\) 就是答案了。
其中有几个地方需要注意:
- 二分时,如果答案选择的链的数量 \(\ge k\) 则更新 \(ans\),因为链的数量为 \(n\) 总是能够取到,而数量极少则不一定能取到。
- 横坐标为 \(k\) 的点可能和 \(k-1,k+1\) 构成直线,不一定能够准确二分到 \(k\),所以最终答案要加上\(e\times k\) 而不是当前选择链的条数。
- 二分的值域要到 \([-n\times 10^6,n\times 10^6]\)。
#define Maxn 300005
#define int long long
int n,k,tot,curmuti;
int hea[Maxn],nex[Maxn<<1],ver[Maxn<<1],edg[Maxn<<1];
struct NODE
{
int val,hav;
NODE(int _val=0,int _hav=0):val(_val),hav(_hav){};
inline bool friend operator < (NODE x,NODE y)
{ return (x.val!=y.val)?x.val<y.val:x.hav<y.hav; }
inline NODE friend operator + (NODE x,NODE y)
{ return NODE(x.val+y.val,x.hav+y.hav); }
};
NODE dp[Maxn][3];
inline void add(int x,int y,int d){ ver[++tot]=y,nex[tot]=hea[x],hea[x]=tot,edg[tot]=d; }
void dfs(int x,int fa)
{
dp[x][2]=max(dp[x][2],NODE(-curmuti,1));
for(int i=hea[x];i;i=nex[i]) if(ver[i]!=fa)
{
dfs(ver[i],x);
dp[x][2]=max(
dp[x][2]+dp[ver[i]][0],
dp[x][1]+dp[ver[i]][1]+NODE(edg[i]-curmuti,1));
dp[x][1]=max(
dp[x][1]+dp[ver[i]][0],
dp[x][0]+dp[ver[i]][1]+NODE(edg[i],0));
dp[x][0]=dp[x][0]+dp[ver[i]][0];
}
dp[x][0]=max(dp[x][0],max(dp[x][1]+NODE(-curmuti,1),dp[x][2]));
}
signed main()
{
n=rd(),k=rd()+1;
for(int i=1,x,y,d;i<n;i++) x=rd(),y=rd(),d=rd(),add(x,y,d),add(y,x,d);
int nl=-n*1000000,nr=n*1000000,ret=0;
while(nl<=nr)
{
int mid=(nl+nr)>>1;
memset(dp,0,sizeof(dp)),curmuti=mid,dfs(1,0);
if(dp[1][0].hav>=k) ret=mid,nl=mid+1;
else nr=mid-1;
}
memset(dp,0,sizeof(dp)),curmuti=ret,dfs(1,0);
printf("%lld\n",dp[1][0].val+ret*k);
return 0;
}