Query on a tree VI [SP16549]
【题目描述】
给你一棵 \(n\) 个点的树,编号 \(1\sim n\)。每个点可以是黑色,可以是白色。初始时所有点都是黑色。下面有两种操作:
-
0 u
:询问有多少个节点 \(v\) 满足路径 \(u\) 到 \(v\) 上所有节点(包括 \(u\))都拥有相同的颜色。 -
1 u
:翻转 \(u\) 的颜色。
【输入/输出格式】
不关心
\(n,m\le 10^5\)
最近不知道为什么一直在敲数据结构。。。感觉要换换题型了
题解
随便找个点当根吧
如果有两个点\(u,v\)满足查询操作那个条件 我们就说\(u,v\)联通
注意到我们只需要维护一个点子树里有多少点和它联通
对于查询操作只需要找到深度最浅的和查询点联通的祖先就可以了
为了方便操作 我们让\(cnt[x][0]\)表示如果\(x\)是黑点 那么子树里有多少点和它联通 \(cnt[x][1]\)表示白点
那么对于修改操作 我们假设是把\(x\)从黑改成白
我们只需要找到那个深度最浅的和\(x\)联通的祖先\(p\) 然后把\(fa[p]\sim fa[x]\)这条链上所有点的\(cnt[i][0]\)减掉\(cnt[x][0]\)
然后更改\(x\)的颜色
再找到此时深度最浅的和\(x\)联通的祖先\(p2\)(注意\(x\)的颜色变了 所以和祖先的联通也已经变了) 把\(fa[p2]\sim fa[x]\)这条链上所有点的\(cnt[i][1]\)加上\(cnt[x][1]\)
因为\(x\)变白点之后子树里的黑点就不和外面联通了 而子树里的白点就会和外面联通
(实际上你会发现\(p\)和\(p2\)中有一个肯定就是\(x\) 因为\(x\)的父亲要么是白点要么是黑点 但是无所谓)
白改黑同理
区间修改树剖就可以了 问题在于如何快速找到深度最浅的和\(x\)联通的祖先?
还是树剖 线段树再维护一下区间内有多少个黑点白点 那么从\(fa[x]\)开始一直往上跳重链 如果整段都和\(x\)颜色一样就继续跳 否则一定可以线段树二分找到第一个和\(x\)颜色不一样的点
但是这里写起来就会比较麻烦。。。这题估计还是LCT简单点
时间复杂度\(O(n\log^2 n)\)
#include <bits/stdc++.h>
#define N 100005
using namespace std;
inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0');
return x * f;
}
int n, m, col[N];
int head[N], pre[N<<1], to[N<<1], sz;
int dfn[N], rnk[N], tme, d[N], siz[N], top[N], son[N], fa[N];
inline void addedge(int u, int v) {
pre[++sz] = head[u]; head[u] = sz; to[sz] = v;
pre[++sz] = head[v]; head[v] = sz; to[sz] = u;
}
void dfs(int x) {
siz[x] = 1;
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa[x]) continue;
d[y] = d[x] + 1; fa[y] = x;
dfs(y);
siz[x] += siz[y];
if (!son[x] || siz[son[x]] < siz[y]) son[x] = y;
}
}
void dfs2(int x, int _top) {
top[x] = _top; dfn[x] = ++tme; rnk[tme] = x;
if (son[x]) dfs2(son[x], _top);
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
struct segtree{
int l, r, cnt[2], tag[2], sum[2]; //0:black 1:white
} tr[N<<2];
#define lson ind<<1
#define rson ind<<1|1
inline void pushup(int ind) {
tr[ind].cnt[0] = tr[lson].cnt[0] + tr[rson].cnt[0];
tr[ind].cnt[1] = tr[lson].cnt[1] + tr[rson].cnt[1];
tr[ind].sum[0] = tr[lson].sum[0] + tr[rson].sum[0];
tr[ind].sum[1] = tr[lson].sum[1] + tr[rson].sum[1];
}
void build(int ind, int l, int r) {
tr[ind].l = l; tr[ind].r = r; tr[ind].tag[0] = tr[ind].tag[1] = 0;
if (l == r) {
tr[ind].cnt[0] = siz[rnk[l]]; tr[ind].cnt[1] = 1;
tr[ind].sum[0] = 1; tr[ind].sum[1] = 0;
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid); build(rson, mid+1, r);
pushup(ind);
}
void pushdown(int ind) {
if (tr[ind].tag[0]) {
int v = tr[ind].tag[0]; tr[ind].tag[0] = 0;
tr[lson].cnt[0] += v; tr[lson].tag[0] += v;
tr[rson].cnt[0] += v; tr[rson].tag[0] += v;
}
if (tr[ind].tag[1]) {
int v = tr[ind].tag[1]; tr[ind].tag[1] = 0;
tr[lson].cnt[1] += v; tr[lson].tag[1] += v;
tr[rson].cnt[1] += v; tr[rson].tag[1] += v;
}
}
void update(int ind, int x, int y, int v, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (x <= l && r <= y) {
tr[ind].cnt[c] += (r - l + 1) * v; tr[ind].tag[c] += v;
return;
}
pushdown(ind);
int mid = (l + r) >> 1;
if (x <= mid) update(lson, x, y, v, c);
if (mid < y) update(rson, x, y, v, c);
pushup(ind);
}
int query(int ind, int pos, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (l == r) return tr[ind].cnt[c];
pushdown(ind);
int mid = (l + r) >> 1;
if (pos <= mid) return query(lson, pos, c);
else return query(rson, pos, c);
}
void change(int ind, int pos, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (l == r) {
tr[ind].sum[c^1] = 0;
tr[ind].sum[c] = 1;
return;
}
int mid = (l + r) >> 1;
if (pos <= mid) change(lson, pos, c);
else change(rson, pos, c);
pushup(ind);
}
int find(int ind, int x, int y, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (l == r) {
if (tr[ind].sum[c]) return l;
else return 0;
}
if (x <= l && r <= y) {
if (!tr[ind].sum[c]) return 0;
}
int mid = (l + r) >> 1;
if (mid >= y) return find(lson, x, y, c);
if (x > mid) return find(rson, x, y, c);
int ret = find(rson, x, y, c);
if (!ret) return find(lson, x, y, c);
else return ret;
}
void Update(int x) {
int c = col[x], tmp[2] = {query(1, dfn[x], 0), query(1, dfn[x], 1)};
col[x] ^= 1;
int xx = fa[x];
while (xx) { //边找边修改
int lst = find(1, dfn[top[xx]], dfn[xx], c^1);
if (lst) {
update(1, lst, dfn[xx], -tmp[c], c);
break;
} else {
update(1, dfn[top[xx]], dfn[xx], -tmp[c], c);
xx = fa[top[xx]];
}
}
xx = fa[x];
while (xx) {
int lst = find(1, dfn[top[xx]], dfn[xx], c);
if (lst) {
update(1, lst, dfn[xx], tmp[c^1], c^1);
break;
} else {
update(1, dfn[top[xx]], dfn[xx], tmp[c^1], c^1);
xx = fa[top[xx]];
}
}
change(1, dfn[x], col[x]);
}
int Query(int x) {
int c = col[x], xx = fa[x], lstson = x;
while (xx) {
int lst = find(1, dfn[top[xx]], dfn[xx], c^1);
if (lst) {
if (lst == dfn[xx]) {
return query(1, dfn[lstson], c);
} else {
return query(1, lst + 1, c);
}
}
lstson = top[xx];
xx = fa[top[xx]];
}
return query(1, dfn[1], c);
}
int main() {
n = read();
for (int i = 1, u, v; i < n; i++) {
u = read(), v = read();
addedge(u, v);
}
dfs(1); dfs2(1, 1);
build(1, 1, n);
m = read();
for (int i = 1, tp, x; i <= m; i++) {
tp = read(), x = read();
if (!tp) {
printf("%d\n", Query(x));
} else {
Update(x);
}
}
return 0;
}