luogu4383 [八省联考2018]林克卡特树(带权二分+dp)
题目大意:给定你 n 个点的一棵树 (边有边权,边权有正负)
你需要移除 k 条边,并连接 k 条权值为 0 的边,使得连接之后树的直径最大
题解:
根据 [POI2015]MOD 那道题,显然我们应该找 k+1 条树上不相交的链,求这些链的长度之和最大值
k = 0 部分分:直接求树的直径
k = 1 部分分:把 MOD 那道题 或者点头网那个第一题粘过来就行了
k <= 100 部分分:我们也考虑树形dp
由于我们只需要把长度累加,我们考虑设 f[i][j][0/1/2] 代表以 i 为根的子树中出现了 j 条链, i 的度数为 0/1/2 的子树链的长度和的最大值 xjb 背包下就行, 复杂度 O (n * k * k)
k == 10000
我们设 \(f(x) = f[1][x][0]\), 即钦定链的数量为 x 时的 dp 结果,最后我们要求的答案就是\(f(k + 1)\)
根据莫名其妙的原因,\(f(x)\) 是单峰函数 并且有最大值
显然我们无法在 \(O(n)\) 的时间复杂度内求出这个值。
我们可以考虑给定一个斜率 \(k\),假设求这个斜率的直线与 \(f(x)\) 图像的切点的坐标很容易,我们就可以发现
假设切点是 \((x_0, f(x_0))\), 那么 \(k\) 和 \(x_0\) 之间是一个单调的关系。
根据斜率是单调的这一性质,我们可以考虑二分斜率,每次求出斜率为 mid 时对应的切点位置,如果切点在需要求的 \(k\) 的左边,我们考虑减小斜率,否则考虑增大斜率。
至于怎么求出切点坐标,我们发现 用斜率为 \(k\) 的直线去切曲线 \(f(x)\) 就相当于 用斜率为 0 的直线去切曲线 \(f(x) - kx\),也就是求 \(f(x) - kx\) 的极值点
我们考虑在 dp 时候解除第二维直径数量的限制,反而在 dp 的值中去维护取得最优解时的链的数量, 并且在每加入一条链时在答案贡献中 -= 二分的斜率 k, 最后得到的答案对应的链的数量即为 x, 答案 + k * x 即为 y
#include <cstdio>
#include <vector>
using namespace std;
int n, k;
vector<pair<int, int>> out[300010];
long long mid, ans, pos;
struct dat { int k; long long val; } f[300010][3];
bool operator<(const dat &a, const dat &b) { return a.val != b.val ? (a.val < b.val) : (a.k > b.k); }
dat operator+(const dat &a, const dat &b) { return (dat){a.k + b.k, a.val + b.val}; }
void dfs(int x, int fa)
{
f[x][0] = (dat){0, 0}, f[x][1] = (dat){0, 0}, f[x][2] = (dat){1, -mid};
for (pair<int, int> i : out[x]) if (i.first != fa)
{
int v = i.first; dfs(v, x);
f[x][2] = max(f[x][2] + f[v][0], f[x][1] + f[v][1] + (dat){1, i.second - mid});
f[x][1] = max(f[x][1] + f[v][0], f[x][0] + f[v][1] + (dat){0, i.second});
f[x][0] = f[x][0] + f[v][0];
}
f[x][0] = max(f[x][0], max(f[x][2], f[x][1] + (dat){1, -mid}));
}
long long check()
{
dfs(1, 0), ans = f[1][0].val;
return pos = f[1][0].k;
}
int main()
{
scanf("%d%d", &n, &k), k++;
for (int x, y, z, i = 1; i < n; i++)
{
scanf("%d%d%d", &x, &y, &z);
out[x].push_back(make_pair(y, z));
out[y].push_back(make_pair(x, z));
}
long long l = -1e12, r = -l;
while (l < r)
{
mid = (l + r) >> 1;
if (check() >= k) l = mid + 1;
else r = mid;
}
mid = l, check();
printf("%lld\n", ans + l * k);
return 0;
}