JZOJ 6807. 【2020.10.29提高组模拟】tree(树上差分)
JZOJ 6807. 【2020.10.29提高组模拟】tree
题目大意
- 无根树上 N N N个点被染成 M M M中颜色,确定一个根,使得某个子树内包含所有的颜色且子树的根深度最大,求最大的深度。
- N ≤ 1 0 6 N\leq 10^6 N≤106, M ≤ 1 0 5 M\leq10^5 M≤105
题解
- 不难想到先可以钦定一个根,题中“子树”对应的是有根数中某个子树或整棵树除去某个子树的部分,
- 一种很暴力的想法是设 f i , j f_{i,j} fi,j表示以 i i i为根的子树中 j j j出现的次数,枚举儿子转移,
- 如果“子树”对应某个子树,那么需要满足该子树中所有颜色全都出现至少一次,“深度”为子树的根往上走的最长路径;
- 如果“子树”对应整棵树除去某个子树的部分,那么需要满足该子树内不能包含任意一种颜色的所有节点,“深度”为子树的根往下走的最长路径 + 1 +1 +1。
- 复杂度过高无法接受,还可以想到用线段树启发式合并来优化,但常数还是过大。
- 可以两种情况分开考虑:
- 第一种情况,不需要知道每种颜色分别出现了几次,而只需要知道出现的颜色种类数,考虑树上差分,分别把每种颜色的点标记 + 1 +1 +1,按DFS序相邻两点的LCA标记 − 1 -1 −1,这样可以在树上遍历时直接得出以每个点为根的子树中颜色种类数;
- 第二种情况,只要有一种颜色的所有节点被包含在子树中就不可行,分别把每种颜色所有节点的LCA标记上,这些点及以上的点为根的子树中一定包含至少一种颜色的所有节点,向上传递标记,标记过的点都不可行,其他的可行。
- 记得要先预处理每个点出发的最长路径,简单的树上DP。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 1000010
#define M 100010
int n, m, ans = 0;
int a[N], dp[N], dfn[N], mx[N][2], ms[N][2], f[N][20];
int last[N], nxt[N * 2], to[N * 2], len = 0;
int b[N], c[N], d[N], la[M], fi[M];
int cmp(int x, int y) {
return dfn[x] < dfn[y];
}
void add(int x, int y) {
to[++len] = y;
nxt[len] = last[x];
last[x] = len;
}
void dfs(int k, int fa) {
f[k][0] = fa;
for(int i = 1; i < 20; i++) f[k][i] = f[f[k][i - 1]][i - 1];
dp[k] = dp[fa] + 1;
dfn[k] = ++dfn[0];
mx[k][0] = mx[k][1] = -1;
ms[k][0] = ms[k][1] = 0;
for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
int x = to[i];
dfs(x, k);
if(ms[x][0] + 1 >= ms[k][0]) mx[k][1] = mx[k][0], ms[k][1] = ms[k][0], mx[k][0] = x, ms[k][0] = ms[x][0] + 1;
else if(ms[x][0] + 1 > ms[k][1]) mx[k][1] = x, ms[k][1] = ms[x][0] + 1;
}
}
int lca(int x, int y) {
if(dp[x] < dp[y]) swap(x, y);
for(int i = 19; i >= 0; i--) if(dp[x] - (1 << i) >= dp[y]) x = f[x][i];
if(x == y) return x;
for(int i = 19; i >= 0; i--) if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
void solve(int k, int fa, int t) {
int s = 0;
for(int i = last[k]; i; i = nxt[i]) if(to[i] != fa) {
if(mx[k][0] == to[i]) solve(to[i], k, max(t, ms[k][1]) + 1); else solve(to[i], k, max(t, ms[k][0]) + 1);
c[k] += c[to[i]];
d[k] += d[to[i]];
}
if(c[k] == m) ans = max(ans, t + 1);
if(d[k] == 0) ans = max(ans, ms[k][0] + 2);
}
int read() {
int s = 0;
char x = getchar();
while(x < '0' || x > '9') x = getchar();
while(x >= '0' && x <= '9') s = s * 10 + x - 48, x = getchar();
return s;
}
int main() {
int i, x, y;
n = read(), m = read();
for(i = 1; i <= n; i++) a[i] = read();
for(i = 1; i < n; i++) {
x = read(), y = read();
add(x, y), add(y, x);
}
dfs(1, 0);
for(i = 1; i <= n; i++) b[i] = i;
sort(b + 1, b + n + 1, cmp);
for(i = 1; i <= n; i++) {
if(la[a[b[i]]]) c[b[i]]++, c[lca(b[i], la[a[b[i]]])]--; else c[b[i]]++, fi[a[b[i]]] = b[i];
la[a[b[i]]] = b[i];
}
for(i = 1; i <= m; i++) d[lca(fi[i], la[i])]++;
solve(1, 0, 0);
printf("%d\n", ans);
return 0;
}
哈哈哈哈哈哈哈哈哈哈