[Ynoi2017]由乃的OJ
可以发现因为运算只有位运算,因此每一位是互不影响的,因此我们可以将每一位拉出来单独考虑。你会发现对于每一位而言,不论在何时都只有两种取值,并且一旦我们知道了这一位进入这条链的取值,最后出来的答案也是确定的。因此我们可以考虑直接维护出 \(0 / 1\) 进入这段链的取值,这个我们可以直接通过 \(LCT\) 来维护。具体来讲,对于 \(Splay\) 上的每个节点 \(i\) 维护出 \(l_i, r_i\),分别表示以 \(0 / 1\) 进入以 \(i\) 为根的 \(\rm Splay\) 中的点组成的链(从深度浅的点到深度深的点),那么就有合并(令 \(f_{i, 0}, f_{i, 1}\) 分别表示以 \(0 / 1\) 进入 \(i\) 的答案):
\[l_i = (l_{ls} \& r_{rs}) | ((\sim l_{ls}) \& l_{rs})
\]
\[r_i = (r_{ls} \& r_{rs}) | ((\sim r_{ls}) \& l_{rs})
\]
实现的时候因为 \(LCT\) 有 \(Makeroot\) 操作,因此还要记录反向进入的答案。然后最终我们 \(Split\) 取出这条链,从高到低位贪心去选择即可,于此同时记录一个 \(up\) 表示有没有达到上界。这样复杂度就能做到 \(O(nk \log n)\)。
你会发现这个做法的瓶颈在于我们进行了拆位考虑,但你会发现我们只在乎每一位从 \(0 / 1\) 进入出来后的答案,又没有什么办法能让所有位置通过一次运算就知道 \(0 / 1\) 进入的答案呢?不难发现因为所有位是不会互相影响的,于是我们直接拿所有全 \(0\) 的二进制串和全 \(1\) 的二进制串进去跑即可,这样就可以一次跑出所有位置以 \(0 / 1\) 进入的答案了。因为每位之间不会互相影响,所以合并方式还是和按位考虑的时候一致。这样复杂度就是 \(O(n \times \max\{k, \log n\})\) 了。
一些坑点:
- 注意以后写 \(LCT\) 时,\(pushdown\) 一定是交换儿子的信息,而自己的信息在打标记时就要交换。因为 \(Splay\) 时会可能用到儿子的信息。
#include<bits/stdc++.h>
using namespace std;
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define dep(i, l, r) for(int i = r; i >= l; --i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
typedef unsigned long long ull;
const int N = 100000 + 5;
struct edge{
int v, next;
}e[N << 1];
struct node{
ull l, r;
node operator + (const node &b){
node a;
a.l = (~l & b.l) | (l & b.r), a.r = (~r & b.l) | (r & b.r);
return a;
}
}res, f[N][2], val[N][2];
ull x, ans, limit;
int n, m, k, u, v, up, tot, opt, top, h[N], st[N], fa[N], tag[N], ch[N][2];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void add(int u, int v){
e[++tot].v = v, e[tot].next = h[u], h[u] = tot;
e[++tot].v = u, e[tot].next = h[v], h[v] = tot;
}
void dfs(int u, int Fa){
fa[u] = Fa;
Next(i, u) if(e[i].v != Fa) dfs(e[i].v, u);
}
int which(int x){
return (ch[fa[x]][1] == x);
}
int isroot(int x){
return ((ch[fa[x]][0] != x) && (ch[fa[x]][1] != x));
}
void update(int x){
val[x][0].l = 0, val[x][0].r = limit, val[x][1] = val[x][0];
if(ch[x][0]) val[x][0] = val[ch[x][0]][0];
if(ch[x][1]) val[x][1] = val[ch[x][1]][1];
val[x][0] = val[x][0] + f[x][0], val[x][1] = val[x][1] + f[x][1];
if(ch[x][1]) val[x][0] = val[x][0] + val[ch[x][1]][0];
if(ch[x][0]) val[x][1] = val[x][1] + val[ch[x][0]][1];
}
void down(int x){
if(!tag[x]) return;
tag[x] = 0, tag[ch[x][0]] ^= 1, tag[ch[x][1]] ^= 1, swap(ch[x][0], ch[x][1]);
swap(val[ch[x][0]][0], val[ch[x][0]][1]), swap(val[ch[x][1]][0], val[ch[x][1]][1]); // 就是指的这里
}
void rotate(int x){
int y = fa[x], z = fa[y], k = which(x), w = ch[x][k ^ 1];
fa[w] = y, ch[y][k] = w;
fa[x] = z; if(!isroot(y)) ch[z][which(y)] = x;
fa[y] = x, ch[x][k ^ 1] = y;
update(y), update(x);
}
void Splay(int x){
int cur = x; st[++top] = x;
while(!isroot(cur)) st[++top] = fa[cur], cur = fa[cur];
while(top) down(st[top--]);
while(!isroot(x)){
int y = fa[x];
if(!isroot(y)){
if(which(x) == which(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
}
void Access(int x){
for(int y = 0; x; y = x, x = fa[x]) Splay(x), ch[x][1] = y, update(x);
}
void Makeroot(int x){
Access(x), Splay(x), tag[x] ^= 1, swap(val[x][0], val[x][1]); // 以及这里
}
void Split(int x, int y){
Makeroot(x), Access(y), Splay(y);
}
signed main(){
n = read(), m = read(), k = read();
rep(i, 0, k - 1) limit += (1ull << i);
rep(i, 1, n){
opt = read(), scanf("%llu", &x);
if(opt == 1) f[i][0].l = (0ull & x), f[i][0].r = (limit & x), f[i][1] = f[i][0];
if(opt == 2) f[i][0].l = (0ull | x), f[i][0].r = (limit | x), f[i][1] = f[i][0];
if(opt == 3) f[i][0].l = (0ull ^ x), f[i][0].r = (limit ^ x), f[i][1] = f[i][0];
val[i][0] = f[i][0], val[i][1] = f[i][1];
}
rep(i, 1, n - 1) u = read(), v = read(), add(u, v);
dfs(1, 0);
while(m--){
opt = read(), u = read(), v = read(), scanf("%llu", &x);
if(opt == 1){
Split(u, v), res = val[v][0], up = 1, ans = 0;
dep(i, 0, k - 1){
if((1ull << i) <= x){
if(res.l & (1ull << i)) up = (up & (!(x & (1ull << i)))), ans += (1ull << i);
else if(res.r & (1ull << i)){
if(!up || (x & (1ull << i))) ans += (1ull << i);
}
else if(x & (1ull << i)) up = 0;
}
else if(res.l & (1ull << i)) ans += (1ull << i);
}
printf("%llu\n", ans);
}
else{
Splay(u);
if(v == 1) f[u][0].l = (0ull & x), f[u][0].r = (limit & x), f[u][1] = f[u][0];
if(v == 2) f[u][0].l = (0ull | x), f[u][0].r = (limit | x), f[u][1] = f[u][0];
if(v == 3) f[u][0].l = (0ull ^ x), f[u][0].r = (limit ^ x), f[u][1] = f[u][0];
update(u);
}
}
return 0;
}