[BZOJ4817][Sdoi2017]树点涂色
[BZOJ4817][Sdoi2017]树点涂色
试题描述
Bob有一棵n个点的有根树,其中1号点是根节点。Bob在每个点上涂了颜色,并且每个点上的颜色不同。定义一条路
径的权值是:这条路径上的点(包括起点和终点)共有多少种不同的颜色。Bob可能会进行这几种操作:
1 x:
把点x到根节点的路径上所有的点染上一种没有用过的新颜色。
2 x y:
求x到y的路径的权值。
3 x:
在以x为根的子树中选择一个点,使得这个点到根节点的路径权值最大,求最大权值。
Bob一共会进行m次操作
输入
第一行两个数n,m。
接下来n-1行,每行两个数a,b,表示a与b之间有一条边。
接下来m行,表示操作,格式见题目描述
1<=n,m<=100000
输出
每当出现2,3操作,输出一行。
如果是2操作,输出一个数表示路径的权值
如果是3操作,输出一个数表示权值的最大值
输入示例
5 6 1 2 2 3 3 4 3 5 2 4 5 3 3 1 4 2 4 5 1 5 2 4 5
输出示例
3 4 2 2
数据规模及约定
见“输入”
题解
操作 1 即为 LCT 里面的 access 操作;对于询问,我们只需要知道每个点到根的路径上有多少条虚边就好了(令节点 x 到根的路径上虚边条数为 tot[x]):操作 2,查询路径 (u, v) 的话就是 tot[u] + tot[v] - 2tot[lca(u,v)](lca(u, v) 即节点 u 和 v 的最近公共祖先);操作 3,就是查询一个子树内最大的 tot。
所以可以用线段树维护 dfs 序,access 时删除或添加一条虚边对应子树集体 -1 或子树集体 +1(注意找 dfs 序中的区间时要找到 splay 中最靠左的节点,即深度最小的节点)。
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cctype> #include <algorithm> using namespace std; int read() { int x = 0, f = 1; char c = getchar(); while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); } while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); } return x * f; } #define maxn 100010 #define maxm 200010 #define maxlog 17 namespace TREE { int n, m, head[maxn], nxt[maxm], to[maxm]; void AddEdge(int a, int b) { to[++m] = b; nxt[m] = head[a]; head[a] = m; swap(a, b); to[++m] = b; nxt[m] = head[a]; head[a] = m; return ; } int fa[maxn][maxlog], dep[maxn], dl[maxn], dr[maxn], uid[maxn], clo; void build(int u) { uid[dl[u] = ++clo] = u; for(int i = 1; i < maxlog; i++) fa[u][i] = fa[fa[u][i-1]][i-1]; for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa[u][0]) { fa[to[e]][0] = u; dep[to[e]] = dep[u] + 1; build(to[e]); } dr[u] = clo; return ; } int lca(int a, int b) { if(dep[a] < dep[b]) swap(a, b); for(int i = maxlog - 1; i >= 0; i--) if(dep[a] - (1 << i) >= dep[b]) a = fa[a][i]; for(int i = maxlog - 1; i >= 0; i--) if(fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i]; return a == b ? a : fa[b][0]; } } using namespace TREE; struct SEG { int maxv[maxn<<2], addv[maxn<<2]; SEG() { memset(addv, 0, sizeof(addv)); } void build(int o, int l, int r) { if(l == r) maxv[o] = dep[uid[l]]; //, printf("%d -> %d: %d\n", l, uid[l], maxv[o]); else { int mid = l + r >> 1, lc = o << 1, rc = lc | 1; build(lc, l, mid); build(rc, mid + 1, r); maxv[o] = max(maxv[lc], maxv[rc]); } return ; } void pushdown(int o, int l, int r) { if(!addv[o]) return ; if(l == r){ addv[o] = 0; return ; } int lc = o << 1, rc = lc | 1; addv[lc] += addv[o]; addv[rc] += addv[o]; maxv[lc] += addv[o]; maxv[rc] += addv[o]; addv[o] = 0; return ; } void update(int o, int l, int r, int ql, int qr, int v) { if(ql > qr || !ql || !qr) return ; pushdown(o, l, r); if(ql <= l && r <= qr) { addv[o] += v; maxv[o] += v; return ; } int mid = l + r >> 1, lc = o << 1, rc = lc | 1; if(ql <= mid) update(lc, l, mid, ql, qr, v); if(qr > mid) update(rc, mid + 1, r, ql, qr, v); maxv[o] = max(maxv[lc], maxv[rc]); return ; } int query(int o, int l, int r, int ql, int qr) { if(!ql || !qr) return 0; pushdown(o, l, r); if(ql <= l && r <= qr) return maxv[o]; int mid = l + r >> 1, lc = o << 1, rc = lc | 1, ans = 0; if(ql <= mid) ans = max(ans, query(lc, l, mid, ql, qr)); if(qr > mid) ans = max(ans, query(rc, mid + 1, r, ql, qr)); return ans; } } seg; struct LCT { int fa[maxn], ch[maxn][2], rt[maxn]; void init() { for(int i = 1; i <= n; i++) rt[i] = i, fa[i] = TREE::fa[i][0]; rt[0] = 0; // for(int i = 1; i <= n; i++) printf("%d%c", fa[i], i < n ? ' ' : '\n'); return ; } bool isrt(int u) { return !fa[u] || (ch[fa[u]][0] != u && ch[fa[u]][1] != u); } void maintain(int o) { if(!o) return ; if(ch[o][0]) rt[o] = rt[ch[o][0]]; else rt[o] = o; return ; } void rotate(int u) { int y = fa[u], z = fa[y], l = 0, r = 1; if(!isrt(y)) ch[z][ch[z][1]==y] = u; if(ch[y][1] == u) swap(l, r); fa[u] = z; fa[y] = u; fa[ch[u][r]] = y; ch[y][l] = ch[u][r]; ch[u][r] = y; maintain(y); maintain(u); return ; } void splay(int u) { while(!isrt(u)) { int y = fa[u], z = fa[y]; if(!isrt(y)) { if(ch[y][0] == u ^ ch[z][0] == y) rotate(u); else rotate(y); } rotate(u); } return ; } void access(int u) { splay(u); seg.update(1, 1, n, dl[rt[ch[u][1]]], dr[rt[ch[u][1]]], 1); /*printf("+1s: %d\n", rt[ch[u][1]]);*/ ch[u][1] = 0; maintain(u); while(fa[u]) { splay(fa[u]); seg.update(1, 1, n, dl[rt[ch[fa[u]][1]]], dr[rt[ch[fa[u]][1]]], 1); // printf("+1s: %d\n", rt[ch[fa[u]][1]]); seg.update(1, 1, n, dl[rt[u]], dr[rt[u]], -1); // printf("-1s: %d\n", rt[u]); ch[fa[u]][1] = u; maintain(fa[u]); splay(u); } return ; } } lct; int main() { n = read(); int q = read(); for(int i = 1; i < n; i++) { int a = read(), b = read(); AddEdge(a, b); } dep[1] = 0; build(1); seg.build(1, 1, n); lct.init(); while(q--) { int tp = read(), u = read(), v; if(tp == 1) lct.access(u); if(tp == 2) { v = read(); int c = lca(u, v); // printf("lca(%d, %d) = %d\n", u, v, c); printf("%d\n", seg.query(1, 1, n, dl[u], dl[u]) + seg.query(1, 1, n, dl[v], dl[v]) - (seg.query(1, 1, n, dl[c], dl[c]) << 1) + 1); } if(tp == 3) printf("%d\n", seg.query(1, 1, n, dl[u], dr[u]) + 1); } return 0; }