[BZOJ4817][Sdoi2017]树点涂色

[BZOJ4817][Sdoi2017]树点涂色

试题描述

Bob有一棵n个点的有根树,其中1号点是根节点。Bob在每个点上涂了颜色,并且每个点上的颜色不同。定义一条路
径的权值是:这条路径上的点(包括起点和终点)共有多少种不同的颜色。Bob可能会进行这几种操作:
1 x:
把点x到根节点的路径上所有的点染上一种没有用过的新颜色。
2 x y:
求x到y的路径的权值。
3 x:
在以x为根的子树中选择一个点,使得这个点到根节点的路径权值最大,求最大权值。
Bob一共会进行m次操作

输入

第一行两个数n,m。
接下来n-1行,每行两个数a,b,表示a与b之间有一条边。
接下来m行,表示操作,格式见题目描述
1<=n,m<=100000

输出

每当出现2,3操作,输出一行。
如果是2操作,输出一个数表示路径的权值
如果是3操作,输出一个数表示权值的最大值

输入示例

5 6
1 2
2 3
3 4
3 5
2 4 5
3 3
1 4
2 4 5
1 5
2 4 5

输出示例

3
4
2
2

数据规模及约定

见“输入

题解

操作 1 即为 LCT 里面的 access 操作;对于询问,我们只需要知道每个点到根的路径上有多少条虚边就好了(令节点 x 到根的路径上虚边条数为 tot[x]):操作 2,查询路径 (u, v) 的话就是 tot[u] + tot[v] - 2tot[lca(u,v)](lca(u, v) 即节点 u 和 v 的最近公共祖先);操作 3,就是查询一个子树内最大的 tot。

所以可以用线段树维护 dfs 序,access 时删除或添加一条虚边对应子树集体 -1 或子树集体 +1(注意找 dfs 序中的区间时要找到 splay 中最靠左的节点,即深度最小的节点)。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
	return x * f;
}

#define maxn 100010
#define maxm 200010
#define maxlog 17

namespace TREE {
	int n, m, head[maxn], nxt[maxm], to[maxm];
	
	void AddEdge(int a, int b) {
		to[++m] = b; nxt[m] = head[a]; head[a] = m;
		swap(a, b);
		to[++m] = b; nxt[m] = head[a]; head[a] = m;
		return ;
	}
	
	int fa[maxn][maxlog], dep[maxn], dl[maxn], dr[maxn], uid[maxn], clo;
	void build(int u) {
		uid[dl[u] = ++clo] = u;
		for(int i = 1; i < maxlog; i++) fa[u][i] = fa[fa[u][i-1]][i-1];
		for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa[u][0]) {
			fa[to[e]][0] = u;
			dep[to[e]] = dep[u] + 1;
			build(to[e]);
		}
		dr[u] = clo;
		return ;
	}
	int lca(int a, int b) {
		if(dep[a] < dep[b]) swap(a, b);
		for(int i = maxlog - 1; i >= 0; i--) if(dep[a] - (1 << i) >= dep[b]) a = fa[a][i];
		for(int i = maxlog - 1; i >= 0; i--) if(fa[a][i] != fa[b][i]) a = fa[a][i], b = fa[b][i];
		return a == b ? a : fa[b][0];
	}
}
using namespace TREE;

struct SEG {
	int maxv[maxn<<2], addv[maxn<<2];
	
	SEG() { memset(addv, 0, sizeof(addv)); }
	
	void build(int o, int l, int r) {
		if(l == r) maxv[o] = dep[uid[l]]; //, printf("%d -> %d: %d\n", l, uid[l], maxv[o]);
		else {
			int mid = l + r >> 1, lc = o << 1, rc = lc | 1;
			build(lc, l, mid); build(rc, mid + 1, r);
			maxv[o] = max(maxv[lc], maxv[rc]);
		}
		return ;
	}
	
	void pushdown(int o, int l, int r) {
		if(!addv[o]) return ;
		if(l == r){ addv[o] = 0; return ; }
		int lc = o << 1, rc = lc | 1;
		addv[lc] += addv[o]; addv[rc] += addv[o];
		maxv[lc] += addv[o]; maxv[rc] += addv[o];
		addv[o] = 0;
		return ;
	}
	void update(int o, int l, int r, int ql, int qr, int v) {
		if(ql > qr || !ql || !qr) return ;
		pushdown(o, l, r);
		if(ql <= l && r <= qr) {
			addv[o] += v; maxv[o] += v;
			return ;
		}
		int mid = l + r >> 1, lc = o << 1, rc = lc | 1;
		if(ql <= mid) update(lc, l, mid, ql, qr, v);
		if(qr > mid) update(rc, mid + 1, r, ql, qr, v);
		maxv[o] = max(maxv[lc], maxv[rc]);
		return ;
	}
	
	int query(int o, int l, int r, int ql, int qr) {
		if(!ql || !qr) return 0;
		pushdown(o, l, r);
		if(ql <= l && r <= qr) return maxv[o];
		int mid = l + r >> 1, lc = o << 1, rc = lc | 1, ans = 0;
		if(ql <= mid) ans = max(ans, query(lc, l, mid, ql, qr));
		if(qr > mid) ans = max(ans, query(rc, mid + 1, r, ql, qr));
		return ans;
	}
} seg;

struct LCT {
	int fa[maxn], ch[maxn][2], rt[maxn];
	
	void init() {
		for(int i = 1; i <= n; i++) rt[i] = i, fa[i] = TREE::fa[i][0];
		rt[0] = 0;
//		for(int i = 1; i <= n; i++) printf("%d%c", fa[i], i < n ? ' ' : '\n');
		return ;
	}
	
	bool isrt(int u) { return !fa[u] || (ch[fa[u]][0] != u && ch[fa[u]][1] != u); }
	void maintain(int o) {
		if(!o) return ;
		if(ch[o][0]) rt[o] = rt[ch[o][0]];
		else rt[o] = o;
		return ;
	}
	void rotate(int u) {
		int y = fa[u], z = fa[y], l = 0, r = 1;
		if(!isrt(y)) ch[z][ch[z][1]==y] = u;
		if(ch[y][1] == u) swap(l, r);
		fa[u] = z; fa[y] = u; fa[ch[u][r]] = y;
		ch[y][l] = ch[u][r]; ch[u][r] = y;
		maintain(y); maintain(u);
		return ;
	}
	void splay(int u) {
		while(!isrt(u)) {
			int y = fa[u], z = fa[y];
			if(!isrt(y)) {
				if(ch[y][0] == u ^ ch[z][0] == y) rotate(u);
				else rotate(y);
			}
			rotate(u);
		}
		return ;
	}
	void access(int u) {
		splay(u); seg.update(1, 1, n, dl[rt[ch[u][1]]], dr[rt[ch[u][1]]], 1); /*printf("+1s: %d\n", rt[ch[u][1]]);*/ ch[u][1] = 0; maintain(u);
		while(fa[u]) {
			splay(fa[u]);
			seg.update(1, 1, n, dl[rt[ch[fa[u]][1]]], dr[rt[ch[fa[u]][1]]], 1); // printf("+1s: %d\n", rt[ch[fa[u]][1]]);
			seg.update(1, 1, n, dl[rt[u]], dr[rt[u]], -1); // printf("-1s: %d\n", rt[u]);
			ch[fa[u]][1] = u;
			maintain(fa[u]);
			splay(u);
		}
		return ;
	}
} lct;

int main() {
	n = read(); int q = read();
	for(int i = 1; i < n; i++) {
		int a = read(), b = read();
		AddEdge(a, b);
	}
	dep[1] = 0; build(1);
	
	seg.build(1, 1, n);
	lct.init();
	while(q--) {
		int tp = read(), u = read(), v;
		if(tp == 1) lct.access(u);
		if(tp == 2) {
			v = read(); int c = lca(u, v);
//			printf("lca(%d, %d) = %d\n", u, v, c);
			printf("%d\n", seg.query(1, 1, n, dl[u], dl[u]) + seg.query(1, 1, n, dl[v], dl[v]) - (seg.query(1, 1, n, dl[c], dl[c]) << 1) + 1);
		}
		if(tp == 3) printf("%d\n", seg.query(1, 1, n, dl[u], dr[u]) + 1);
	}
	
	return 0;
}

 

posted @ 2017-04-19 19:47  xjr01  阅读(214)  评论(0编辑  收藏  举报