luogu4383 [八省联考2018]林克卡特树(带权二分+dp)

link

题目大意:给定你 n 个点的一棵树 (边有边权,边权有正负)

你需要移除 k 条边,并连接 k 条权值为 0 的边,使得连接之后树的直径最大

题解:

根据 [POI2015]MOD 那道题,显然我们应该找 k+1 条树上不相交的链,求这些链的长度之和最大值

k = 0 部分分:直接求树的直径

k = 1 部分分:把 MOD 那道题 或者点头网那个第一题粘过来就行了

k <= 100 部分分:我们也考虑树形dp

由于我们只需要把长度累加,我们考虑设 f[i][j][0/1/2] 代表以 i 为根的子树中出现了 j 条链, i 的度数为 0/1/2 的子树链的长度和的最大值 xjb 背包下就行, 复杂度 O (n * k * k)

k == 10000

我们设 \(f(x) = f[1][x][0]\), 即钦定链的数量为 x 时的 dp 结果,最后我们要求的答案就是\(f(k + 1)\)

根据莫名其妙的原因,\(f(x)\) 是单峰函数 并且有最大值

显然我们无法在 \(O(n)\) 的时间复杂度内求出这个值。

我们可以考虑给定一个斜率 \(k\),假设求这个斜率的直线与 \(f(x)\) 图像的切点的坐标很容易,我们就可以发现

假设切点是 \((x_0, f(x_0))\), 那么 \(k\)\(x_0\) 之间是一个单调的关系。

根据斜率是单调的这一性质,我们可以考虑二分斜率,每次求出斜率为 mid 时对应的切点位置,如果切点在需要求的 \(k\) 的左边,我们考虑减小斜率,否则考虑增大斜率。

至于怎么求出切点坐标,我们发现 用斜率为 \(k\) 的直线去切曲线 \(f(x)\) 就相当于 用斜率为 0 的直线去切曲线 \(f(x) - kx\),也就是求 \(f(x) - kx\) 的极值点

我们考虑在 dp 时候解除第二维直径数量的限制,反而在 dp 的值中去维护取得最优解时的链的数量, 并且在每加入一条链时在答案贡献中 -= 二分的斜率 k, 最后得到的答案对应的链的数量即为 x, 答案 + k * x 即为 y

#include <cstdio>
#include <vector>
using namespace std;

int n, k;
vector<pair<int, int>> out[300010];
long long mid, ans, pos;
struct dat { int k; long long val; } f[300010][3];

bool operator<(const dat &a, const dat &b) { return a.val != b.val ? (a.val < b.val) : (a.k > b.k); }
dat operator+(const dat &a, const dat &b) { return (dat){a.k + b.k, a.val + b.val}; }

void dfs(int x, int fa)
{
	f[x][0] = (dat){0, 0}, f[x][1] = (dat){0, 0}, f[x][2] = (dat){1, -mid};
	for (pair<int, int> i : out[x]) if (i.first != fa)
	{
		int v = i.first; dfs(v, x);
		f[x][2] = max(f[x][2] + f[v][0], f[x][1] + f[v][1] + (dat){1, i.second - mid});
		f[x][1] = max(f[x][1] + f[v][0], f[x][0] + f[v][1] + (dat){0, i.second});
		f[x][0] = f[x][0] + f[v][0];
	}
	f[x][0] = max(f[x][0], max(f[x][2], f[x][1] + (dat){1, -mid}));
}

long long check()
{
	dfs(1, 0), ans = f[1][0].val;
	return pos = f[1][0].k;
}

int main()
{
	scanf("%d%d", &n, &k), k++;
	for (int x, y, z, i = 1; i < n; i++)
	{
		scanf("%d%d%d", &x, &y, &z);
		out[x].push_back(make_pair(y, z));
		out[y].push_back(make_pair(x, z));
	}
	long long l = -1e12, r = -l;
	while (l < r)
	{
		mid = (l + r) >> 1;
		if (check() >= k) l = mid + 1;
		else r = mid;
	}
	mid = l, check();
	printf("%lld\n", ans + l * k);
	return 0;
}
posted @ 2019-03-22 20:42  ghj1222  阅读(127)  评论(0编辑  收藏  举报