[BZOJ 4817] [SDOI 2017] 树点涂色
Description
Bob有一棵 \(n\) 个点的有根树,其中 \(1\) 号点是根节点。Bob在每个点上涂了颜色,并且每个点上的颜色不同。
定义一条路径的权值是:这条路径上的点(包括起点和终点)共有多少种不同的颜色。
Bob可能会进行这几种操作:
1 x
:把点 \(x\) 到根节点的路径上所有的点染上一种没有用过的新颜色。2 x y
:求 \(x\) 到 \(y\) 的路径的权值。3 x
:在以 \(x\) 为根的子树中选择一个点,使得这个点到根节点的路径权值最大,求最大权值。
Bob一共会进行 \(m\) 次操作。
Input
第一行两个数 \(n,m\)。
接下来 \(n-1\) 行,每行两个数 \(a,b\),表示 \(a\) 与 \(b\) 之间有一条边。
接下来 \(m\) 行,表示操作,格式见题目描述。
Output
每当出现 \(2,3\) 操作,输出一行。
如果是 \(2\) 操作,输出一个数表示路径的权值。
如果是 \(3\) 操作,输出一个数表示权值的最大值。
Sample Input
5 6
1 2
2 3
3 4
3 5
2 4 5
3 3
1 4
2 4 5
1 5
2 4 5
Sample Output
3
4
2
2
HINT
共10个测试点
测试点1,\(1\leq n,m\leq1000\)。
测试点2、3,没有 \(2\) 操作。
测试点4、5,没有 \(3\) 操作。
测试点6,树的生成方式是,对于 \(i~(2\leq i \leq n)\),在 \(1\) 到 \(i-1\) 中随机选一个点作为 \(i\) 的父节点。
测试点7,\(1\leq n,m\leq 50000\)。
测试点8,\(1\leq n \leq 50000\)。
测试点9、10,无特殊限制。
对所有数据,\(1\leq n \leq 10^5\)。
Solution
LCT,实边连接的是相同的颜色,轻边连接不同的颜色,那么点 \(x\) 到根的路径的权值即为经过的轻边的数量 \(+1\),对于操作 \(1\),直接 \(access(x)\),实边 \((fa[x],x)\) 变为轻边时,将子树 \(x\) 的权值整体 \(+1\),轻边 \((fa[x],x)\) 变为实边时,将子树 \(x\) 的权值整体 \(-1\)。
注意我们需要将其子树更改的节点应为当前 \(splay\) 中深度最小的那个,而不是 \(splay\) 的根。
对于操作 \(2\),答案为 \(val[x]+val[y]-2\times val[lca(x,y)]+1\),其中 \(val[x]\) 为 \(x\) 到根的路径的权值,用线段树维护。
操作 \(3\) 就是在线段树上查询区间最大值。
注意LCT的 \(fa\) 和树链剖分的 \(fa\) 不要设成同一个数组。
Code
#include <cstdio>
#include <algorithm>
const int N = 100001;
struct Edge { int v, nxt; } e[N << 1];
int n, m, cnt, head[N], siz[N], dfn[N], fa[N], ff[N], dep[N], top[N], val[N], tot, son[N], mx[N << 2], tag[N << 2], ch[N][2];
int read() {
int x = 0; char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = (x << 3) + (x << 1) + (c ^ 48), c = getchar();
return x;
}
void adde(int u, int v) {
e[++tot].nxt = head[u], head[u] = tot, e[tot].v = v;
}
void dfs1(int u, int f) {
siz[u] = 1, fa[u] = ff[u] = f, dep[u] = dep[f] + 1;
for (int i = head[u]; i; i = e[i].nxt) if (e[i].v != f) {
dfs1(e[i].v, u), siz[u] += siz[e[i].v];
if (siz[e[i].v] > siz[son[u]]) son[u] = e[i].v;
}
}
void dfs2(int u, int t) {
dfn[u] = ++cnt, val[dfn[u]] = dep[u], top[u] = t;
if (son[u]) dfs2(son[u], t);
for (int i = head[u]; i; i = e[i].nxt)
if (e[i].v != fa[u] && e[i].v != son[u]) dfs2(e[i].v, e[i].v);
}
int lca(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
else y = fa[top[y]];
}
return dep[x] < dep[y] ? x : y;
}
void pushdown(int cur) {
mx[cur << 1] += tag[cur], tag[cur << 1] += tag[cur];
mx[cur << 1 | 1] += tag[cur], tag[cur << 1 | 1] += tag[cur];
tag[cur] = 0;
}
void build(int cur, int l, int r) {
if (l == r) { mx[cur] = val[l]; return; }
int mid = (l + r) >> 1;
build(cur << 1, l, mid), build(cur << 1 | 1, mid + 1, r);
mx[cur] = std::max(mx[cur << 1], mx[cur << 1 | 1]);
}
void update(int cur, int l, int r, int L, int R, int x) {
if (L <= l && r <= R) { mx[cur] += x, tag[cur] += x; return; }
int mid = (l + r) >> 1;
if (tag[cur]) pushdown(cur);
if (L <= mid) update(cur << 1, l, mid, L, R, x);
if (mid < R) update(cur << 1 | 1, mid + 1, r, L, R, x);
mx[cur] = std::max(mx[cur << 1], mx[cur << 1 | 1]);
}
int query1(int cur, int l, int r, int p) {
if (l == r) return mx[cur];
int mid = (l + r) >> 1;
if (tag[cur]) pushdown(cur);
if (p <= mid) return query1(cur << 1, l, mid, p);
return query1(cur << 1 | 1, mid + 1, r, p);
}
int query2(int cur, int l, int r, int L, int R) {
if (L <= l && r <= R) return mx[cur];
int mid = (l + r) >> 1, res = 0;
if (tag[cur]) pushdown(cur);
if (L <= mid) res = query2(cur << 1, l, mid, L, R);
if (mid < R) res = std::max(res, query2(cur << 1 | 1, mid + 1, r, L, R));
return res;
}
int get(int x) {
return ch[ff[x]][1] == x;
}
bool isroot(int x) {
return ch[ff[x]][0] != x && ch[ff[x]][1] != x;
}
void rotate(int x) {
int y = ff[x], z = ff[y], k = get(x);
if (!isroot(y)) ch[z][ch[z][1] == y] = x;
ch[y][k] = ch[x][k ^ 1], ch[x][k ^ 1] = y;
ff[ch[y][k]] = y, ff[y] = x, ff[x] = z;
}
void splay(int x) {
for (int y; !isroot(x); rotate(x))
if (!isroot(y = ff[x])) rotate(get(x) ^ get(y) ? x : y);
}
int find(int x) {
while (ch[x][0]) x = ch[x][0];
return x;
}
void access(int x) {
for (int y = 0, z; x; y = x, x = ff[x]) {
splay(x), z = find(ch[x][1]);
if (z) update(1, 1, n, dfn[z], dfn[z] + siz[z] - 1, 1);
ch[x][1] = y, z = find(y);
if (z) update(1, 1, n, dfn[z], dfn[z] + siz[z] - 1, -1);
}
}
int main() {
n = read(), m = read();
for (int i = 1, u, v; i < n; ++i) u = read(), v = read(), adde(u, v), adde(v, u);
dfs1(1, 0), dfs2(1, 1), build(1, 1, n);
while (m--) {
int opt = read(), x = read(), y;
if (opt == 1) access(x);
else if (opt == 2) y = read(), printf("%d\n", query1(1, 1, n, dfn[x]) + query1(1, 1, n, dfn[y]) - (query1(1, 1, n, dfn[lca(x, y)]) << 1) + 1);
else printf("%d\n", query2(1, 1, n, dfn[x], dfn[x] + siz[x] - 1));
}
return 0;
}