P4383 [八省联考2018]林克卡特树lct
题目链接
题意分析
一句话题意就是 : 让你选出\((k+1)\)条不相交的链 使得这些链的边权总和最大 (这些链可以是点)
我们考虑使用树形\(DP\)
\(dp[i][j][0/1/2]\)表示以\(i\)为根的子树选出\(j\)条链 并且\(j\)的度数是\(0/1/2\)的最大总和
那么我们使用树上背包进行转移
\[dp[u][j][0]=dp[u][j-p][0]+dp[v][p][0]
\]
\[dp[u][j][1]=max(dp[u][j-p][1]+dp[v][p][0],dp[u][j-p][0]+dp[v][p][1]+w[now])
\]
\[dp[u][j][2]=max(dp[u][j-p][2]+dp[v][p][0],dp[u][j-p][1]+dp[v][p-1][1]+w[i])
\]
但是这是妥妥的\(O(nk^2)\)
所以考虑优化 我们发现最终答案是\(dp[1][k][0]\)
也不知道为什么发现这是一个上凸函数
也就是\(f''(x)<0\)
所以我们考虑二分\(k\)所在点的斜率
那么该斜率的直线同该函数的且切点就是\((x,f(x))\)
怎么求? ? ?
\[y=mx+b
\]
\[f(x)=mx+b
\]
\[b=mx
\]
\[max(b)=max(f(x)-mx)
\]
我们二分出这个位置就可以了
CODE:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<cstdlib>
#include<string>
#include<queue>
#include<map>
#include<stack>
#include<list>
#include<set>
#include<deque>
#include<vector>
#include<ctime>
#define ll long long
#define inf 0x7fffffff
#define N 6000008
#define IL inline
#define M 1008611
#define D double
#define maxn 110
#define R register
using namespace std;
template<typename T>IL void read(T &_)
{
T __=0,___=1;char ____=getchar();
while(!isdigit(____)) {if(____=='-') ___=0;____=getchar();}
while(isdigit(____)) {__=(__<<1)+(__<<3)+____-'0';____=getchar();}
_=___ ? __:-__;
}
/*-------------OI使我快乐-------------*/
ll n,k,tot;
ll le,ri,ans;
ll to[N],nex[N],head[N],w[N];
struct Node{
ll cnt;ll val;
friend Node operator +(const Node &A,const Node &B)
{return (Node){A.cnt+B.cnt,A.val+B.val};}
friend bool operator <(const Node &A,const Node &B)
{return A.val==B.val ? A.cnt<B.cnt:A.val<B.val;}
}dp[N][3];
IL void add(ll x,ll y,ll z)
{to[++tot]=y;nex[tot]=head[x];head[x]=tot;w[tot]=z;}
IL void dfs(ll now,ll fat,ll mid)
{
dp[now][0]=dp[now][1]=(Node){0,0};dp[now][2]=(Node){1,-mid};
//这里一个点看做一条链
for(R ll i=head[now];i;i=nex[i])
{
ll v=to[i];
if(v==fat) continue;
dfs(v,now,mid);
dp[now][2]=max(dp[now][2]+dp[v][0],dp[now][1]+dp[v][1]+(Node){1,w[i]-mid});
dp[now][1]=max(dp[now][1]+dp[v][0],dp[now][0]+dp[v][1]+(Node){0,w[i]});
dp[now][0]=dp[now][0]+dp[v][0];
}
dp[now][0]=max(dp[now][0],max(dp[now][1]+(Node){1,-mid},dp[now][2]));
}
IL bool check(ll mid)
{
dfs(1,0,mid);
return dp[1][0].cnt>=k;
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(n);read(k);++k;
for(R ll i=1,x,y,z;i<n;++i)
{
read(x);read(y);read(z);
add(x,y,z);add(y,x,z);
}
le=-1e13;ri=1e13;
while(le<=ri)
{
ll mid=(le+ri)>>1;
if(check(mid)) le=mid+1,ans=mid;
else ri=mid-1;
}
check(ans);
printf("%lld\n",dp[1][0].val+ans*k);
// fclose(stdin);
// fclose(stdout);
return 0;
}