Loading

P9352 [JOI 2023 Final] Cat Exercise (树形 dp+trick+并查集)

P9352 [JOI 2023 Final] Cat Exercise

upd 2024.11.15

树形 dp+trick+并查集

第一眼是:这是个无根树,猫移动的位置不会遵循以某固定根的树向子树跳,所以觉得这题有点复杂。

考虑到这题不是很贪心,那从 dp 思考。怎么 dp 没有后效性?但是又发现跳的高度是递减的。按照这个 dp 就有无后效性了。

\(f_u\) 表示从 \(u\) 开始跳最多跳跃次数。写出转移需要知道儿子连通块的最大值,并查集可以维护。

若我们以当前猫在的位置 \(u\) 为根,那么猫的下一步移动就会走到其中一个子树中。因为猫只有在我们把障碍放到当前的位置时才会移动,所以一定无法回到 \(u\) 点。要指定进入某个子树,只需要把其他子树都堵住即可。

考虑树形 dp。设 \(f_u\) 表示\(u\) 为根(猫所在位置)的最大移动次数。由于猫的终点无法确定,而起点一定在最高点,所以我们考虑从终点转移到起点。那么就考虑猫咪下一步会在哪些位置。设 \(pos_u\) 表示 \(u\) 子树中的最高点,那么有:

\(f_u=\max(f_u,f_{pos_v}+dist(u,pos_v))(a_u>a_{pos_v})\)

这样做最后答案就是最高点的 \(f_u\)

我们直接钦定一个根是无法满足无后效性的,\(u\) 的转移是对于所有子树(包括其父亲所构成的子树),类似于换根 dp,但是不考虑它。考虑怎样满足无后效性。我们发现我们需要找到一个新的偏序,看到转移的时候一定有 \(a_u>a_{pos_v}\),所以我们按照 \(a\) 从小到大转移就可以。

每个点权都不同,所以考虑一个 trick,连边 \((a_u,a_v)\),这样子构成的新树与原树结构上相同,本质上只是编号变为点权,\(dist\) 数据结构维护。考虑如何找到 \(pos_v\),可以用并查集维护当前枚举过的点构成的连通块,枚举小根合并在大根上即可。

复杂度 \(O(n\log n)\)

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define pb push_back

typedef long long i64;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 200010;
int n;
i64 dep[N], anc[N][20], dp[N], fa[N], a[N];
std::vector<int> e[N];
void dfs(int u, int fa) {
	dep[u] = dep[fa] + 1;
	anc[u][0] = fa;
	for(int i = 1; i <= 19; i++) anc[u][i] = anc[anc[u][i - 1]][i - 1];
	for(auto v : e[u]) {
		if(v == fa) continue;
		dfs(v, u);
	}
}
int lca(int u, int v) {
	if(dep[u] < dep[v]) std::swap(u, v);
	for(int i = 19; i >= 0; i--) if(dep[anc[u][i]] >= dep[v]) u = anc[u][i];
	if(u == v) return u;
	for(int i = 19; i >= 0; i--) if(anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
	return anc[u][0];
}
int dist(int u, int v) {
	int rt = lca(u, v);
	return dep[u] + dep[v] - 2 * dep[rt];
}
int find(int x) {
	if(x != fa[x]) fa[x] = find(fa[x]);
	return fa[x];
}
void Solve() {
	std::cin >> n;
	for(int i = 1; i <= n; i++) {
		fa[i] = i;
		std::cin >> a[i];
	} 
	for(int i = 1; i < n; i++) {
		int u, v;
		std::cin >> u >> v;
		e[a[u]].pb(a[v]), e[a[v]].pb(a[u]);
	}
	dfs(1, 0);
	for(int u = 1; u <= n; u++) {
		for(auto v : e[u]) {
			int fv = find(v);
			if(fv < u) fa[fv] = u, dp[u] = std::max(dp[u], dp[fv] + dist(u, fv));
		}
	}
	std::cout << dp[n] << "\n";
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
	Solve();

	return 0;
}
posted @ 2024-04-11 21:13  Fire_Raku  阅读(14)  评论(0编辑  收藏  举报