JZOJ 6807. 【2020.10.29提高组模拟】tree(树上差分)

JZOJ 6807. 【2020.10.29提高组模拟】tree

题目大意

  • 无根树上 N N N个点被染成 M M M中颜色,确定一个根,使得某个子树内包含所有的颜色且子树的根深度最大,求最大的深度。
  • N ≤ 1 0 6 N\leq 10^6 N106 M ≤ 1 0 5 M\leq10^5 M105

题解

  • 不难想到先可以钦定一个根,题中“子树”对应的是有根数中某个子树或整棵树除去某个子树的部分,
  • 一种很暴力的想法是设 f i , j f_{i,j} fi,j表示以 i i i为根的子树中 j j j出现的次数,枚举儿子转移,
  • 如果“子树”对应某个子树,那么需要满足该子树中所有颜色全都出现至少一次,“深度”为子树的根往上走的最长路径;
  • 如果“子树”对应整棵树除去某个子树的部分,那么需要满足该子树内不能包含任意一种颜色的所有节点,“深度”为子树的根往下走的最长路径 + 1 +1 +1
  • 复杂度过高无法接受,还可以想到用线段树启发式合并来优化,但常数还是过大。
  • 可以两种情况分开考虑:
  • 第一种情况,不需要知道每种颜色分别出现了几次,而只需要知道出现的颜色种类数,考虑树上差分,分别把每种颜色的点标记 + 1 +1 +1,按DFS序相邻两点的LCA标记 − 1 -1 1,这样可以在树上遍历时直接得出以每个点为根的子树中颜色种类数;
  • 第二种情况,只要有一种颜色的所有节点被包含在子树中就不可行,分别把每种颜色所有节点的LCA标记上,这些点及以上的点为根的子树中一定包含至少一种颜色的所有节点,向上传递标记,标记过的点都不可行,其他的可行。
  • 记得要先预处理每个点出发的最长路径,简单的树上DP。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 1000010
#define M 100010
int n, m, ans = 0;
int a[N], dp[N], dfn[N], mx[N][2], ms[N][2], f[N][20];
int last[N], nxt[N * 2], to[N * 2], len = 0;
int b[N], c[N], d[N], la[M], fi[M];
int cmp(int x, int y) {
	return dfn[x] < dfn[y];
}
void add(int x, int y) {
	to[++len] = y;
	nxt[len] = last[x];
	last[x] = len;
}
void dfs(int k, int fa) {
	f[k][0] = fa;
	for(int i = 1; i < 20; i++) f[k][i] = f[f[k][i - 1]][i - 1];
	dp[k] = dp[fa] + 1;
	dfn[k] = ++dfn[0];
	mx[k][0] = mx[k][1] = -1;
	ms[k][0] = ms[k][1] = 0;
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
		int x = to[i];
		dfs(x, k);
		if(ms[x][0] + 1 >= ms[k][0]) mx[k][1] = mx[k][0], ms[k][1] = ms[k][0], mx[k][0] = x, ms[k][0] = ms[x][0] + 1;
		else if(ms[x][0] + 1 > ms[k][1]) mx[k][1] = x, ms[k][1] = ms[x][0] + 1;
	}
}
int lca(int x, int y) {
	if(dp[x] < dp[y]) swap(x, y);
	for(int i = 19; i >= 0; i--) if(dp[x] - (1 << i) >= dp[y]) x = f[x][i];
	if(x == y) return x;
	for(int i = 19; i >= 0; i--) if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}
void solve(int k, int fa, int t) {
	int s = 0;
	for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
		if(mx[k][0] == to[i]) solve(to[i], k, max(t, ms[k][1]) + 1); else solve(to[i], k, max(t, ms[k][0]) + 1);
		c[k] += c[to[i]];
		d[k] += d[to[i]];
	}
	if(c[k] == m) ans = max(ans, t + 1);
	if(d[k] == 0) ans = max(ans, ms[k][0] + 2);
}
int read() {
	int s = 0;
	char x = getchar();
	while(x < '0' || x > '9') x = getchar();
	while(x >= '0' && x <= '9') s = s * 10 + x - 48, x = getchar();
	return s;
}
int main() {
	int i, x, y;
	n = read(), m = read();
	for(i = 1; i <= n; i++) a[i] = read();
	for(i = 1; i < n; i++) {
		x = read(), y = read();
		add(x, y), add(y, x);
	}
	dfs(1, 0);
	for(i = 1; i <= n; i++) b[i] = i;
	sort(b + 1, b + n + 1, cmp);
	for(i = 1; i <= n; i++) {
		if(la[a[b[i]]]) c[b[i]]++, c[lca(b[i], la[a[b[i]]])]--; else c[b[i]]++, fi[a[b[i]]] = b[i];
		la[a[b[i]]] = b[i];
	}
	for(i = 1; i <= m; i++) d[lca(fi[i], la[i])]++;
	solve(1, 0, 0);
	printf("%d\n", ans);
	return 0;
}
posted @ 2020-11-01 16:29  AnAn_119  阅读(135)  评论(0编辑  收藏  举报