Loading

P10070 [CCO2023] Travelling Trader 题解

非常好题目,使我代码长度起飞。

思路

发现 \(K\) 只有三种取值。

考虑分类讨论。

k=1

容易发现只需要求一个端点是 \(1\) 的最长链。

k=3

考虑这个时候我们将有一个遍历整个树的方案。

考虑递归的处理整个问题。

我们从该节点跳到它一个儿子的儿子。

然后递归处理这个儿子的儿子。

然后再跳到该节点的这个儿子的另一个儿子。

然后递归处理。

将所有儿子的儿子处理完以后,在跳回这个儿子。

然后继续处理其他的儿子的儿子。

这样就可以简单找到遍历整个树的方案。

k=2

考虑 \(k=2\) 怎么做。

我们可以使用树形 dp。

设:

\(f_{x,0}\) 为从 \(x\) 出发往下走,对终止节点无要求的最大贡献。

\(f_{x,1}\) 为从 \(x\) 出发往下走,对终止节点要求为 \(x\) 的某个儿子或 \(x\) 的最大贡献。

\(f_{x,2}\) 为从 \(x\) 的某个儿子出发往下走,对终止节点无要求的最大贡献。

\(f_{x,3}\) 为从 \(x\) 的某个儿子出发往下走,对终止节点要求为 \(x\) 的最大贡献。

考虑转移式。

  1. \[f_{x,0}=a_x+f_{y,2} \]

表示先走到 \(x\),在直接从 \(y\) 往下走。

  1. \[f_{x,0}=a_x+f_{y1,3}+\sum_{y\not=y1,y2} a_y+f_{y2,0} \]

表示先走到 \(x\),在再把 \(y1\) 走一圈后回到 \(y1\),然后走它的兄弟,最后在某个兄弟往下走。

  1. \[f_{x,1}=a_x+f_{y1,3}+\sum_{y\not=y1}a_y \]

表示先走到 \(x\),在再把 \(y1\) 走一圈后回到 \(y1\),然后走它的兄弟。

  1. \[f_{x,2}=f_{x,0} \]

和情况一类似。

  1. \[f_{x,2}=\sum_{y\not=y1,y2} a_y+f_{y1,1}+a_x+f_{y2,2} \]

表示先走到 \(x\) 的一些儿子,然后走到 \(y1\) 这个儿子转一圈,然后回到 \(x\),然后从 \(y2\) 往下走。

  1. \[f_{x,2}=\sum_{y\not=y1,y2,y3} a_y+f_{y1,1}+a_x+f_{y2,3}+f_{y3,0} \]

表示先走到 \(x\) 的一些儿子,然后走到 \(y1\) 这个儿子转一圈,然后回到 \(x\),然后从 \(y2\) 往下走一圈,然后从 \(y3\) 往下走。

  1. \[f_{x,3}=\sum_{y\not=y1} a_y+f_{y1,1}+a_x \]

表示先走到 \(x\) 的一些儿子,然后走到 \(y1\) 这个儿子转一圈,然后回到 \(x\)

注意很重要的一点,在记录方案时,这些顺序时不能随便颠倒的,否则容易方案不合法。

容易发现以上所有 \(dp\) 式都可以线性解决。

时间复杂度:\(O(n)\)

Code

#include <bits/stdc++.h>
using namespace std;

#define x first
#define y second
#define int long long
#define mp(x, y) make_pair(x, y)
#define eb(...) emplace_back(__VA_ARGS__)
#define fro(i, x, y) for(int i = (x); i <= (y); i++)
#define pre(i, x, y) for(int i = (x); i >= (y); i--)
inline void JYFILE19();

typedef int64_t i64;
typedef pair<int, int> PII;

bool ST;
const int N = 2e5 + 10;
const int mod = 998244353;

int n, m, a[N], dp[N], fa[N], pre[N];
vector<int> to[N];

namespace subtask1 {
	inline void dfs(int now, int fa) {
		dp[now] = a[now];
		for(auto i : to[now]) {
			if(i == fa) continue;
			dfs(i, now);
			if(dp[i] > dp[pre[now]])
				pre[now] = i;
		}
		dp[now] += dp[pre[now]];
	}
	inline void Solve() {
		dfs(1, 0);
		vector<int> ans;
		int now = 1;
		while(now) ans.eb(now), now = pre[now];
		cout << dp[1] << "\n";
		cout << ans.size() << "\n";
		for(auto i : ans) cout << i << " ";
		cout << "\n";
	}
}
namespace subtask2 {
struct Node {
	int x, op;
	inline Node(int xx, int opx) {
		x = xx, op = opx;
	}
};
struct node {
	int num, id;
	inline bool operator<(const node &tmp) const {
		return num < tmp.num;
	}
} pr[N], sf[N];
int tp, stk[N], f[N][4], dp[N][2][2][2][4];
vector<Node> g[N][4];
inline void dfs(int now, int fa) {
	int sum = 0;
	vector<int> son;
	for(auto i : to[now])
		if(i != fa) son.eb(i);
	for(auto i : son)
		dfs(i, now), sum += a[i];
	tp = 0;
	for(auto i : son) stk[++tp] = i;
	if(tp == 0) {
		fro(i, 0, 3) {
			f[now][i] = a[now];
			g[now][i].eb(now, 4);
		}
		return;
	}
	{
		int idl = 0;
		for(auto i : son)
			if(f[i][2] > f[idl][2])
				idl = i;
		pr[0] = sf[0] = pr[tp + 1] = sf[tp + 1] = {0, 0};
		fro(i, 1, tp) {
			pr[i] = {f[stk[i]][3] - a[stk[i]], stk[i]};
			sf[i] = {f[stk[i]][3] - a[stk[i]], stk[i]};
		}
		fro(i, 1, tp) pr[i] = max(pr[i], pr[i - 1]);
		pre(i, tp, 1) sf[i] = max(sf[i], sf[i + 1]);
		int id = 0;
		auto get = [&](int x) {
			if(x == 0) return 0ll;
			return max(pr[x - 1], sf[x + 1]).num + f[stk[x]][0] - a[stk[x]];
		};
		fro(i, 1, tp) if(get(id) <= get(i)) id = i;
		f[now][0] = a[now] + get(id) + sum;
		if(f[now][0] < a[now] + f[idl][2]) {
			f[now][0] = a[now] + f[idl][2];
			g[now][0].eb(now, 4);
			g[now][0].eb(idl, 2);
		}
		else {
			int id1 = max(pr[id - 1], sf[id + 1]).id;
			int id2 = stk[id];
			g[now][0].eb(now, 4);
			if(id1) g[now][0].eb(id1, 3);
			for(auto i : son)
				if(i != id1 && i != id2)
					g[now][0].eb(i, 4);
			if(id2) g[now][0].eb(id2, 0);
		}
	}
	{
		int id = 0;
		for(auto i : son)
			if(f[i][3] - a[i] > f[id][3] - a[id])
				id = i;
		f[now][1] = a[now] + f[id][3] - a[id] + sum;
		g[now][1].eb(now, 4);
		if(id) g[now][1].eb(id, 3);
		for(auto i : son) if(i != id)
			g[now][1].eb(i, 4);
	}
	{
		int num1 = f[now][0];
		pr[0] = sf[0] = pr[tp + 1] = sf[tp + 1] = {0, 0};
		fro(i, 1, tp) {
			pr[i] = {f[stk[i]][1] - a[stk[i]], stk[i]};
			sf[i] = {f[stk[i]][1] - a[stk[i]], stk[i]};
		}
		fro(i, 1, tp) pr[i] = max(pr[i], pr[i - 1]);
		pre(i, tp, 1) sf[i] = max(sf[i], sf[i + 1]);
		int id = 0;
		auto get = [&](int x) {
			if(x == 0) return 0ll;
			return max(pr[x - 1], sf[x + 1]).num + f[stk[x]][2] - a[stk[x]];
		};
		fro(i, 1, tp) if(get(id) <= get(i)) id = i;
		int num2 = a[now] + get(id) + sum;
		fro(i, 0, tp) {
			fro(op1, 0, 1) {
				fro(op2, 0, 1) {
					fro(op3, 0, 1) {
						dp[i][op1][op2][op3][0] = -1e18;
						dp[i][op1][op2][op3][1] = 0;
						dp[i][op1][op2][op3][2] = 0;
						dp[i][op1][op2][op3][3] = 0;
					}
				}
			}
		}
		dp[0][0][0][0][0] = 0;
		fro(i, 1, tp) {
			fro(op1, 0, 1) { fro(op2, 0, 1) { fro(op3, 0, 1) {
				fro(k, 0, 3) dp[i][op1][op2][op3][k] = dp[i - 1][op1][op2][op3][k];
			}}}
			fro(op1, 0, 1) {
				fro(op2, 0, 1) {
					fro(op3, 0, 1) {
						int num = dp[i - 1][op1][op2][op3][0];
						int A = dp[i - 1][op1][op2][op3][1];
						int B = dp[i - 1][op1][op2][op3][2];
						int C = dp[i - 1][op1][op2][op3][3];
						if(op1 == 0) {
							if(dp[i][1][op2][op3][0] < num - a[stk[i]] + f[stk[i]][0]) {
								dp[i][1][op2][op3][0] = num - a[stk[i]] + f[stk[i]][0];
								dp[i][1][op2][op3][1] = stk[i];
								dp[i][1][op2][op3][2] = B;
								dp[i][1][op2][op3][3] = C;
							}
						}
						if(op2 == 0) {
							if(dp[i][op1][1][op3][0] < num - a[stk[i]] + f[stk[i]][1]) {
								dp[i][op1][1][op3][0] = num - a[stk[i]] + f[stk[i]][1];
								dp[i][op1][1][op3][1] = A;
								dp[i][op1][1][op3][2] = stk[i];
								dp[i][op1][1][op3][3] = C;
							}
						}
						if(op3 == 0) {
							if(dp[i][op1][op2][1][0] < num - a[stk[i]] + f[stk[i]][3]) {
								dp[i][op1][op2][1][0] = num - a[stk[i]] + f[stk[i]][3];
								dp[i][op1][op2][1][1] = A;
								dp[i][op1][op2][1][2] = B;
								dp[i][op1][op2][1][3] = stk[i];
							}
						}
					}
				}
			}
		}
		int num3 = 0, f1 = 0, f2 = 0, f3 = 0;
		fro(op1, 0, 1) {
			fro(op2, 0, 1) {
				fro(op3, 0, 1) {
					if(num3 < dp[tp][op1][op2][op3][0]) {
						num3 = dp[tp][op1][op2][op3][0];
						f1 = op1, f2 = op2, f3 = op3;
					}
				}
			}
		}
		num3 += sum + a[now];
		f[now][2] = max({num1, num2, num3});
		if(num1 >= num2 && num1 >= num3) {
			g[now][2] = g[now][0];
		}
		else if(num2 >= num1 && num2 >= num3) {
			int id1 = max(pr[id - 1], sf[id + 1]).id;
			int id2 = stk[id];
			for(auto i : son)
				if(i != id1 && i != id2)
					g[now][2].eb(i, 4);
			if(id1) g[now][2].eb(id1, 1);
			g[now][2].eb(now, 4);
			if(id2) g[now][2].eb(id2, 2);
		}
		else {
			int id1 = dp[tp][f1][f2][f3][1];
			int id2 = dp[tp][f1][f2][f3][2];
			int id3 = dp[tp][f1][f2][f3][3];
			for(auto i : son)
				if(i != id1 && i != id2 && i != id3)
					g[now][2].eb(i, 4);
			if(id2) g[now][2].eb(id2, 1);
			g[now][2].eb(now, 4);
			if(id3) g[now][2].eb(id3, 3);
			if(id1) g[now][2].eb(id1, 0);
		}
	}
	{
		int id = 0;
		for(auto i : son)
			if(f[i][1] - a[i] > f[id][1] - a[id])
				id = i;
		f[now][3] = a[now] + f[id][1] - a[id] + sum;
		for(auto i : son) if(i != id)
			g[now][3].eb(i, 4);
		if(id) g[now][3].eb(id, 1);
		g[now][3].eb(now, 4);
	}
}
vector<int> res;
inline void print(int x, int op) {
	for(auto i : g[x][op]) {
		if(i.op == 4) res.eb(i.x);
		else print(i.x, i.op);
	}
}
inline void Solve() {
	dfs(1, 0);
	int num = max({f[1][0], f[1][1]});
	fro(i, 0, 1) if(num == f[1][i]) { print(1, i); break; }
	cout << num <<"\n";
	cout << res.size() << "\n";
	for(auto i : res) cout << i << " ";
	cout << "\n";
}
}
namespace subtask3 {
	vector<int> ans;
	inline void dfs(int now) {
		for(auto i : to[now]) {
			if(i == fa[now]) continue;
			fa[i] = now, dfs(i);
		}
	}
	inline void calc(int now) {
		ans.eb(now);
		for(auto i : to[now]) {
			if(i == fa[now]) continue;
			for(auto j : to[i]) {
				if(j == fa[i]) continue;
				calc(j);
			}
			ans.eb(i);
		}
	}
	inline void Solve() {
		dfs(1), calc(1);
		int num = 0;
		fro(i, 1, n) num += a[i];
		cout << num << "\n";
		cout << ans.size() << "\n";
		for(auto i : ans) cout << i << " ";
		cout << "\n";
	}
}

signed main() {
	JYFILE19();
	cin >> n >> m;
	fro(i, 1, n - 1) {
		int x, y;
		cin >> x >> y;
		to[x].eb(y);
		to[y].eb(x);
	}
	fro(i, 1, n) cin >> a[i];
	if(m == 1) subtask1::Solve();
	if(m == 2) subtask2::Solve();
	if(m == 3) subtask3::Solve();
	return 0;
}

bool ED;
inline void JYFILE19() {
	// freopen("", "r", stdin);
	// freopen("", "w", stdout);
	ios::sync_with_stdio(0), cin.tie(0);
	double MIB = fabs((&ED-&ST)/1048576.), LIM = 1024;
	cerr << "MEMORY: " << MIB << endl, assert(MIB<=LIM);
}
posted @ 2024-02-15 22:09  JiaY19  阅读(64)  评论(0编辑  收藏  举报