[SDOI2017]苹果树

首先,观察题意,可以发现在最长链下再接一个点,结果一定更优。
也就是说,可以免费选一条最长链,之后正常选。
我们枚举选的最长链,然后算出剩下部分的最优解。
有4部分:
1、链上每个点都选一个。
2、链上剩下的部分。
3、链的左面。
4、链的右面。

1可以直接计算。
那么,我们需要先进行树形背包,然后再通过某方式将其余3个合并。
我们知道,在此问题中,合并2个背包是\(O(k)\)的;
但3个及以上则是\(O(k^2)\)的,无法承受。
所以,我们只能在计算中就把其中两个合并,这样就只需合并2个了。
可以发现,3和4是正常的树形背包,而2是一个贪心的问题。
但是,我们没有时间给链上的点排序,再贪心选择。
所以,只能将2转为正常的树形背包问题。
可以这样想:选x则要选fx,选x的第二个则要选x的第一个。
那么,我们可以把大于1的拆点,拆成1和a-1。连上父子关系。
不难发现,这样我们就只需要3,4两个合并了。所以,只要算出3,4部分,就能\(O(k)\)出解了。

先考虑如何进行树形背包:

树形背包有2种实现:dfs合并的和dfs序上dp的。
由于本题每个节点上有多个,所以之前的\(O(nk)\)的分析不适用,复杂度是\(O(nk^2)\),显然超时。
而且,复杂度的瓶颈合并背包至少是\(O(k^2)\)的,这个是max卷积,不能优化。
所以,这种方法不行。

但是,由于我们不需要知道每个节点的子节点的选择信息(如每个点选择了多少子树节点),
所以,可以考虑dfs序上dp的算法。见博客

这个算法的状态数是\(O(nk)\)的,且相当于逐渐添加,没有合并背包。
添加的过程是一个多重背包(由于每个节点上有多个),可以用单调队列优化。这部分可以做到\(O(nk)\)
而且,我们发现:对于问题3,4(即树链的左右),在dfs序上是一段连续区间。这意味着我们可以直接得出3,4的dp值。

对树做先序遍历,可以得到链的右面(后缀)。对树做后遍历,可以得到链的左面(前缀)。

总结下:
首先,拆点。
然后,对树进行先后序遍历,并用多重背包的单调队列优化算出dp值,\(O(nk)\)
最后,枚举一个叶子,在\(O(k)\)时间算出结果,总共\(O(nk)\)

此外,对于子节点,在3,4两部分都会被算到,要注意排除。
注意卡常。

代码:

#include <stdio.h> 
#define inf 999999999
#define setdp(i, j, x) dp[i * (k + 1) + j] = x
#define getdp(i, j) dp[(i) * (k + 1) + j]
#define getod(i, j) ld[(i) * (k + 1) + j] 
int fr[40010],ne[40010],v[40010],bs = 0,sl[40010],sz[40010],n,k;
void addb(int a, int b) {
	v[bs] = b;
	ne[bs] = fr[a];
	fr[a] = bs++;
}
int xl[40010],si[40010],jl[40010],tm = 0,x1[40010],x2[40010];
void dfs1(int u) {
	x1[u] = tm;
	xl[tm++] = u;
	si[u] = 1;
	for (int i = fr[u]; i != -1; i = ne[i]) {
		jl[v[i]] = jl[u] + sz[v[i]];
		dfs1(v[i]);
		si[u] += si[v[i]];
	}
}
void dfs2(int u) {
	for (int i = fr[u]; i != -1; i = ne[i]) dfs2(v[i]);
	xl[++tm] = u;
	x2[u] = tm;
}
int dl[500010],dz[500010],he = 0,ta = 0,dp[60000010],ld[60000010];
void insert(int i, int x) {
	dz[i] = x;
	while (he < ta && dz[dl[ta - 1]] <= x) ta -= 1;
	dl[ta++] = i;
}
void del(int i) {
	if (he < ta && dl[he] == i) he += 1;
}
int getma() {
	if (he < ta) return dz[dl[he]];
	else return - inf;
}
bool ez[20010],kz[20010];
int main() {
	int T;
	scanf("%d", &T);
	while (T--) {
		scanf("%d%d", &n, &k);
		bs = 0;
		for (int i = 1; i <= n + n; i++) fr[i] = -1;
		for (int i = 1; i <= n; i++) ez[i] = kz[i] = false;
		for (int i = 1; i <= n; i++) {
			int a;
			scanf("%d%d%d", &a, &sl[i], &sz[i]);
			if (i > 1) addb(a, i);
			ez[a] = true;
		}
		for (int i = 1; i <= n; i++) {
			if (sl[i] > 1) {
				sl[i + n] = sl[i] - 1;
				sz[i + n] = sz[i];
				addb(i, i + n);
				sl[i] = 1;
				kz[i] = true;
			}
		}
		jl[1] = sz[1];
		tm = 0;
		dfs1(1);
		for (int i = tm - 1; i >= 0; i--) {
			he = ta = 0;
			for (int j = 0; j <= k; j++) {
				int u = xl[i],
				ma = getdp(i + si[u], j);
				del(j - sl[u] - 1);
				if (j > 0) insert(j - 1, getdp(i + 1, j - 1) - sz[u] * (j - 1));
				int t = getma() + sz[u] * j;
				if (t > ma) ma = t;
				setdp(i, j, ma);
			}
		}
		for (int i = 0; i <= tm * (k + 1) + k; i++) {
			ld[i] = dp[i];
			dp[i] = 0;
		}
		tm = 0;
		dfs2(1);
		for (int i = 1; i <= tm; i++) {
			he = ta = 0;
			for (int j = 0; j <= k; j++) {
				int u = xl[i],
				ma = getdp(i - si[u], j);
				del(j - sl[u] - 1);
				if (j > 0) insert(j - 1, getdp(i - 1, j - 1) - sz[u] * (j - 1));
				int t = getma() + sz[u] * j;
				if (t > ma) ma = t;
				setdp(i, j, ma);
			}
		}
		int jg = -inf;
		for (int i = 1; i <= n; i++) {
			if (ez[i]) continue;
			int ma = -inf;
			for (int j = 0; j <= k; j++) {
				int t = getod(x1[i] + 1, j);
				if (!kz[i]) t += getdp(x2[i] - 1, k - j);
				else t += getdp(x2[i + n] - 1, k - j);
				if (t > ma) ma = t;
			}
			ma += jl[i];
			if (ma > jg) jg = ma;
		}
		printf("%d\n", jg);
		for (int i = 0; i <= tm * (k + 1) + k; i++) ld[i] = dp[i] = 0;
	}
	return 0;
}
posted @ 2019-09-13 22:08  lnzwz  阅读(205)  评论(0编辑  收藏  举报