2019陕西省赛-E Tree
Tree
树链剖分 + 线段树
除了链上的区间修改较难处理之外，其余部分都是树链剖分的基础操作。
需要维护路径上的异或和，而修改操作是对一整条链做按位与（AND）/按位或（OR）。若把一个数作为整体来维护，很难支持这类修改，因此把每个数拆成 32 个二进制位，每一位单独建一棵线段树（支持区间赋值为 0 或 1、统计区间内 1 的个数，1 的个数的奇偶性即为该位的异或结果），对 32 位逐一处理。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 1e5 + 10;
// Number of bit planes: every value is split into 32 bits, each with its own segment tree.
const int maxb = 32;
// tr[b][node]: number of 1s in bit b over the node's segment (can reach n, so int).
// lazy[b][node]: pending assignment for bit b -- 0 none, 1 set all to 1, 2 clear all to 0.
// lazy only ever stores 0/1/2, so one byte is enough; with 32 planes instead of 35
// the two tables shrink from roughly 112 MB to roughly 64 MB.
int tr[maxb][maxn << 2];
unsigned char lazy[maxb][maxn << 2];
// num[v]: initial value of node v (consumed by build()).
int num[maxn];
// HLD pass-1 data: depth, subtree size, heavy child (-1 for a leaf), parent.
int dep[maxn], siz[maxn], hson[maxn], fa[maxn];
// HLD pass-2 data: DFS index of a node, node at a DFS index, head of the node's heavy chain.
int dfn[maxn], rnk[maxn], top[maxn];
// Undirected adjacency lists of the tree.
vector<int>gra[maxn];
// First heavy-light pass over the subtree of `now`, whose parent is `pre` and
// whose depth is `d`. Fills fa[], dep[], siz[] (subtree size) and hson[]
// (the child with the largest subtree, the first one in adjacency order on
// ties; -1 for a leaf).
// Implemented with an explicit stack instead of recursion so that a
// path-shaped tree with ~1e5 nodes cannot overflow the call stack.
void dfs1(int now, int pre, int d)
{
    vector<int> order;          // nodes in DFS preorder: parents before children
    vector<int> stk(1, now);
    fa[now] = pre;
    dep[now] = d;
    while(!stk.empty())
    {
        int u = stk.back();
        stk.pop_back();
        order.push_back(u);
        for(auto v : gra[u])
        {
            if(v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }
    // Every child appears after its parent in `order`, so a reverse sweep has
    // all children's siz[] final before the parent is processed. Children are
    // scanned in adjacency order, reproducing the recursive tie-break.
    for(auto it = order.rbegin(); it != order.rend(); ++it)
    {
        int u = *it;
        siz[u] = 1;
        hson[u] = -1;
        for(auto v : gra[u])
        {
            if(v == fa[u]) continue;
            siz[u] += siz[v];
            if(hson[u] == -1 || siz[hson[u]] < siz[v])
                hson[u] = v;
        }
    }
}
// Running DFS counter used by dfs2 to assign dfn[]; init() resets it to 0.
int tp = 0;
// Second heavy-light pass: numbers the subtree of `now` (whose chain head is
// `t`) so that each heavy chain occupies a contiguous dfn range. The heavy
// child is numbered right after its parent, then light children (each heading
// a new chain) in adjacency order. Fills top[], dfn[] and rnk[] (rnk is the
// inverse of dfn).
// Uses an explicit stack instead of recursion so deep trees (~1e5 levels)
// cannot overflow the call stack; the numbering equals the recursive one.
void dfs2(int now, int t)
{
    vector<int> stk(1, now);
    top[now] = t;
    while(!stk.empty())
    {
        int u = stk.back();
        stk.pop_back();
        dfn[u] = ++tp;
        rnk[tp] = u;
        if(hson[u] == -1) continue;     // leaf
        // Push light children in reverse so they pop in adjacency order...
        for(auto it = gra[u].rbegin(); it != gra[u].rend(); ++it)
        {
            int v = *it;
            if(v == hson[u] || v == fa[u]) continue;
            top[v] = v;
            stk.push_back(v);
        }
        // ...and the heavy child last, so its whole subtree is numbered first.
        top[hson[u]] = top[u];
        stk.push_back(hson[u]);
    }
}
void build(int now, int l, int r)
{
if(l == r)
{
int nex = rnk[l];
for(int i=0; i<32 && num[nex]; i++)
{
tr[i][now] = num[nex] & 1;
num[nex] >>= 1;
}
return;
}
int mid = l + r >> 1;
build(now << 1, l, mid);
build(now << 1 | 1, mid + 1, r);
for(int i=0; i<32; i++) tr[i][now] = tr[i][now << 1] + tr[i][now << 1 | 1];
}
// Bit plane that push_down / update / query currently operate on.
int bit = 0;
// Moves the pending assignment tag of node `now` (covering [l, r]) in the
// current bit plane down to its two children, then clears it.
// Tag 1 fills both children with 1s, tag 2 clears them, any other tag does nothing.
inline void push_down(int now, int l, int r)
{
    const int lc = now << 1;
    const int rc = now << 1 | 1;
    switch(lazy[bit][now])
    {
    case 1:
    {
        const int mid = (l + r) >> 1;
        tr[bit][lc] = mid - l + 1;
        tr[bit][rc] = r - mid;
        lazy[bit][lc] = lazy[bit][rc] = 1;
        break;
    }
    case 2:
        tr[bit][lc] = tr[bit][rc] = 0;
        lazy[bit][lc] = lazy[bit][rc] = 2;
        break;
    default:
        break;
    }
    lazy[bit][now] = 0;
}
// Assigns every position of [L, R] in the current bit plane, descending from
// node `now` covering [l, r]. t == 2 clears the bits to 0; any other t sets
// them to 1. The tag stored on fully covered nodes is t itself.
void update(int now, int l, int r, int L, int R, int t)
{
    const bool covered = (L <= l) && (r <= R);
    if(covered)
    {
        tr[bit][now] = (t == 2) ? 0 : (r - l + 1);
        lazy[bit][now] = t;
        return;
    }
    push_down(now, l, r);
    const int mid = (l + r) >> 1;
    const int lc = now << 1;
    const int rc = lc | 1;
    if(L <= mid) update(lc, l, mid, L, R, t);
    if(mid < R) update(rc, mid + 1, r, L, R, t);
    tr[bit][now] = tr[bit][lc] + tr[bit][rc];
}
// Counts the 1s of the current bit plane at positions [L, R], descending from
// node `now` covering [l, r]. Pending tags are pushed down on the way.
int query(int now, int l, int r, int L, int R)
{
    if(l >= L && R >= r) return tr[bit][now];
    push_down(now, l, r);
    const int mid = (l + r) >> 1;
    int total = 0;
    if(L <= mid) total += query(now << 1, l, mid, L, R);
    if(mid < R) total += query(now << 1 | 1, mid + 1, r, L, R);
    return total;
}
// Runs the heavy-light decomposition of the n-node tree rooted at rt, then
// builds the per-bit segment trees over the resulting DFS order.
void init(int n, int rt = 1)
{
    // Pass 1: parents, depths, subtree sizes, heavy children (root is its own parent).
    dfs1(rt, rt, 1);
    // Pass 2: restart the DFS counter and number nodes chain by chain.
    tp = 0;
    dfs2(rt, rt);
    // Segment trees over positions 1..n of that numbering.
    build(1, 1, n);
}
bool solve(int n, int a, int t)
{
bit = 0;
for(int i=0; i<32; i++, bit++)
{
int now = a, b = t & 1;
t >>= 1;
while(top[now] != 1)
{
b += query(1, 1, n, dfn[top[now]], dfn[now]);
now = fa[top[now]];
}
b += query(1, 1, n, 1, dfn[now]);
if(b & 1) return true;
}
return false;
}
void update_a(int n, int s, int t, int op)
{
int way = 1;
if(op == 2) way = 0;
bit = 0;
for(; bit < 32; t >>= 1, bit++)
{
if(way != (t & 1)) continue;
int a = s;
while(top[a] != 1)
{
update(1, 1, n, dfn[top[a]], dfn[a], op);
a = fa[top[a]];
}
update(1, 1, n, 1, dfn[a], op);
}
}
// Reads n node values and n-1 undirected edges, builds the decomposition rooted
// at node 1, then answers m operations "op s t": op 3 prints YES/NO from
// solve(), any other op is forwarded to update_a().
int main()
{
    int n, m;
    scanf("%d%d", &n, &m);
    for(int v = 1; v <= n; ++v) scanf("%d", &num[v]);
    for(int e = 1; e < n; ++e)
    {
        int u, w;
        scanf("%d%d", &u, &w);
        gra[w].push_back(u);
        gra[u].push_back(w);
    }
    init(n);
    while(m--)
    {
        int op, s, t;
        scanf("%d%d%d", &op, &s, &t);
        if(op != 3)
        {
            update_a(n, s, t, op);
            continue;
        }
        fputs(solve(n, s, t) ? "YES\n" : "NO\n", stdout);
    }
    return 0;
}