Query on a tree VI [SP16549]

【题目描述】

给你一棵 \(n\) 个点的树,编号 \(1\sim n\)。每个点可以是黑色,可以是白色。初始时所有点都是黑色。下面有两种操作:

  • 0 u:询问有多少个节点 \(v\) 满足路径 \(u\)\(v\) 上所有节点(包括 \(u\))都拥有相同的颜色。

  • 1 u:翻转 \(u\) 的颜色。

【输入/输出格式】

不关心

\(n,m\le 10^5\)

最近不知道为什么一直在敲数据结构。。。感觉要换换题型了

题解

随便找个点当根吧

如果有两个点\(u,v\)满足查询操作那个条件 我们就说\(u,v\)联通

注意到我们只需要维护一个点子树里有多少点和它联通

对于查询操作只需要找到深度最浅的和查询点联通的祖先就可以了

为了方便操作 我们让\(cnt[x][0]\)表示如果\(x\)是黑点 那么子树里有多少点和它联通 \(cnt[x][1]\)表示白点

那么对于修改操作 我们假设是把\(x\)从黑改成白

我们只需要找到那个深度最浅的和\(x\)联通的祖先\(p\) 然后把\(fa[p]\sim fa[x]\)这条链上所有点的\(cnt[i][0]\)减掉\(cnt[x][0]\)
然后更改\(x\)的颜色
再找到此时深度最浅的和\(x\)联通的祖先\(p2\)(注意\(x\)的颜色变了 所以和祖先的联通也已经变了) 把\(fa[p2]\sim fa[x]\)这条链上所有点的\(cnt[i][1]\)加上\(cnt[x][1]\)

因为\(x\)变白点之后子树里的黑点就不和外面联通了 而子树里的白点就会和外面联通

(实际上你会发现\(p\)\(p2\)中有一个肯定就是\(x\) 因为\(x\)的父亲要么是白点要么是黑点 但是无所谓)

白改黑同理

区间修改树剖就可以了 问题在于如何快速找到深度最浅的和\(x\)联通的祖先?

还是树剖 线段树再维护一下区间内有多少个黑点白点 那么从\(fa[x]\)开始一直往上跳重链 如果整段都和\(x\)颜色一样就继续跳 否则一定可以线段树二分找到第一个和\(x\)颜色不一样的点

但是这里写起来就会比较麻烦。。。这题估计还是LCT简单点

时间复杂度\(O(n\log^2 n)\)

#include <bits/stdc++.h>
#define N 100005
using namespace std;

inline int read() {
	int x = 0, f = 1; char ch = getchar();
	for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
	for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0');
	return x * f;
}

int n, m, col[N];
int head[N], pre[N<<1], to[N<<1], sz;
int dfn[N], rnk[N], tme, d[N], siz[N], top[N], son[N], fa[N];

inline void addedge(int u, int v) {
	pre[++sz] = head[u]; head[u] = sz; to[sz] = v;
	pre[++sz] = head[v]; head[v] = sz; to[sz] = u;
}

void dfs(int x) {
	siz[x] = 1; 
	for (int i = head[x]; i; i = pre[i]) {
		int y = to[i];
		if (y == fa[x]) continue;
		d[y] = d[x] + 1; fa[y] = x;
		dfs(y);
		siz[x] += siz[y];
		if (!son[x] || siz[son[x]] < siz[y]) son[x] = y;
	}
}

void dfs2(int x, int _top) {
	top[x] = _top; dfn[x] = ++tme; rnk[tme] = x;
	if (son[x]) dfs2(son[x], _top);
	for (int i = head[x]; i; i = pre[i]) {
		int y = to[i];
		if (y == fa[x] || y == son[x]) continue;
		dfs2(y, y);
	}
}

struct segtree{
	int l, r, cnt[2], tag[2], sum[2]; //0:black 1:white 
} tr[N<<2];

#define lson ind<<1
#define rson ind<<1|1

inline void pushup(int ind) {
	tr[ind].cnt[0] = tr[lson].cnt[0] + tr[rson].cnt[0];
	tr[ind].cnt[1] = tr[lson].cnt[1] + tr[rson].cnt[1];
	tr[ind].sum[0] = tr[lson].sum[0] + tr[rson].sum[0];
	tr[ind].sum[1] = tr[lson].sum[1] + tr[rson].sum[1];
}

void build(int ind, int l, int r) {
	tr[ind].l = l; tr[ind].r = r; tr[ind].tag[0] = tr[ind].tag[1] = 0;
	if (l == r) {
		tr[ind].cnt[0] = siz[rnk[l]]; tr[ind].cnt[1] = 1;
		tr[ind].sum[0] = 1; tr[ind].sum[1] = 0;
		return;
	}
	int mid = (l + r) >> 1;
	build(lson, l, mid); build(rson, mid+1, r);
	pushup(ind);
}

void pushdown(int ind) {
	if (tr[ind].tag[0]) {
		int v = tr[ind].tag[0]; tr[ind].tag[0] = 0;
		tr[lson].cnt[0] += v; tr[lson].tag[0] += v;
		tr[rson].cnt[0] += v; tr[rson].tag[0] += v;
	}
	if (tr[ind].tag[1]) {
		int v = tr[ind].tag[1]; tr[ind].tag[1] = 0;
		tr[lson].cnt[1] += v; tr[lson].tag[1] += v;
		tr[rson].cnt[1] += v; tr[rson].tag[1] += v;
	}
}

void update(int ind, int x, int y, int v, int c) {
	int l = tr[ind].l, r = tr[ind].r;
	if (x <= l && r <= y) {
		tr[ind].cnt[c] += (r - l + 1) * v; tr[ind].tag[c] += v;
		return;
	}
	pushdown(ind);
	int mid = (l + r) >> 1;
	if (x <= mid) update(lson, x, y, v, c);
	if (mid < y) update(rson, x, y, v, c);
	pushup(ind);
}

int query(int ind, int pos, int c) {
	int l = tr[ind].l, r = tr[ind].r;
	if (l == r) return tr[ind].cnt[c];
	pushdown(ind);
	int mid = (l + r) >> 1;
	if (pos <= mid) return query(lson, pos, c);
	else return query(rson, pos, c);
}

void change(int ind, int pos, int c) {
	int l = tr[ind].l, r = tr[ind].r;
	if (l == r) {
		tr[ind].sum[c^1] = 0;
		tr[ind].sum[c] = 1;
		return;
	}
	int mid = (l + r) >> 1;
	if (pos <= mid) change(lson, pos, c);
	else change(rson, pos, c);
	pushup(ind);
}

int find(int ind, int x, int y, int c) {
	int l = tr[ind].l, r = tr[ind].r;
	if (l == r) {
		if (tr[ind].sum[c]) return l;
		else return 0;
	}
	if (x <= l && r <= y) {
		if (!tr[ind].sum[c]) return 0;
	}
	int mid = (l + r) >> 1;
	if (mid >= y) return find(lson, x, y, c);
	if (x > mid) return find(rson, x, y, c);
	int ret = find(rson, x, y, c);
	if (!ret) return find(lson, x, y, c);
	else return ret;
}

void Update(int x) {
	int c = col[x], tmp[2] = {query(1, dfn[x], 0), query(1, dfn[x], 1)}; 
	col[x] ^= 1;
	int xx = fa[x];
	while (xx) { //边找边修改
		int lst = find(1, dfn[top[xx]], dfn[xx], c^1);
		if (lst) {
			update(1, lst, dfn[xx], -tmp[c], c);
			break;
		} else {
			update(1, dfn[top[xx]], dfn[xx], -tmp[c], c);
			xx = fa[top[xx]];
		}
	}
	xx = fa[x];
	while (xx) {
		int lst = find(1, dfn[top[xx]], dfn[xx], c);
		if (lst) {
			update(1, lst, dfn[xx], tmp[c^1], c^1);
			break;
		} else {
			update(1, dfn[top[xx]], dfn[xx], tmp[c^1], c^1);
			xx = fa[top[xx]];
		}
	}
	change(1, dfn[x], col[x]);
}

int Query(int x) {
	int c = col[x], xx = fa[x], lstson = x;
	while (xx) {
		int lst = find(1, dfn[top[xx]], dfn[xx], c^1);
		if (lst) {
			if (lst == dfn[xx]) {
				return query(1, dfn[lstson], c);
			} else {
				return query(1, lst + 1, c);
			}
		}
		lstson = top[xx];
		xx = fa[top[xx]];
	}
	return query(1, dfn[1], c); 
}

int main() {
	n = read(); 
	for (int i = 1, u, v; i < n; i++) {
		u = read(), v = read();
		addedge(u, v);
	}
	dfs(1); dfs2(1, 1);
	build(1, 1, n);
	m = read();
	for (int i = 1, tp, x; i <= m; i++) {
		tp = read(), x = read();
		if (!tp) {
			printf("%d\n", Query(x));
		} else {
			Update(x);
		}
	}
	return 0;
} 
posted @ 2020-04-26 17:56  AK_DREAM  阅读(190)  评论(0编辑  收藏  举报