LOJ 3303 「联合省选 2020 A」树
一棵有根树,求每个子树的 val,val表示(子树内点到子树根节点距离+点权)的异或和。
显然如果 u->v 是一条边,\(s_v\) 表示 v 子树的所有值, \(s_u\) 的所有值是 \(s_v\) 每个值 +1 和 \(a_u\) 的并集。考虑如何维护 \(s_u\)。
考虑用权值线段树,但是全部 +1 这个操作在权值线段树上难以完成。转换为 trie 树。
考虑逆建 01trie 树,那么对于 trie 上所有数 +1,只需要把 01 互换,处理进位的时候往互换前 1 的方向走即可,整个与权值线段树上走一条链的复杂度一样,做动态线段树合并即可。
/*================================================================
*
* 创 建 者: badcw
* 创建日期: 2020/7/10 17:16
*
================================================================*/
#include <bits/stdc++.h>
#define VI vector<int>
#define ll long long
using namespace std;
const int maxn = 1e6+50;
const int mod = 1e9+7;
ll qp(ll a, ll n, ll mod = ::mod) {
ll res = 1;
while (n > 0) {
if (n & 1) res = res * a % mod;
a = a * a % mod;
n >>= 1;
}
return res;
}
template<class T> void _R(T &x) { cin >> x; }
void _R(int &x) { scanf("%d", &x); }
void _R(int64_t &x) { scanf("%lld", &x); }
void _R(double &x) { scanf("%lf", &x); }
void _R(char &x) { x = getchar(); }
void _R(char *x) { scanf("%s", x); }
void R() {}
template<class T, class... U> void R(T &head, U &... tail) { _R(head); R(tail...); }
template<class T> void _W(const T &x) { cout << x; }
void _W(const int &x) { printf("%d", x); }
void _W(const int64_t &x) { printf("%lld", x); }
void _W(const double &x) { printf("%.16f", x); }
void _W(const char &x) { putchar(x); }
void _W(const char *x) { printf("%s", x); }
template<class T,class U> void _W(const pair<T,U> &x) {_W(x.F); putchar(' '); _W(x.S);}
template<class T> void _W(const vector<T> &x) { for (auto i = x.begin(); i != x.end(); _W(*i++)) if (i != x.cbegin()) putchar(' '); }
void W() {}
template<class T, class... U> void W(const T &head, const U &... tail) { _W(head); putchar(sizeof...(tail) ? ' ' : '\n'); W(tail...); }
struct Node {
int val, sz, son[2];
} p[maxn << 4];
int rt[maxn];
int tot = 0;
void insert(int &k, int dep, int x) {
if (!k) k = ++tot;
if (dep > 21) return;
p[k].val = x;
p[k].sz = 1;
insert(p[k].son[x & 1], dep + 1, x >> 1);
}
void pushup(int x) {
p[x].val = ((p[p[x].son[0]].val ^ p[p[x].son[1]].val) << 1) | (p[p[x].son[1]].sz & 1);
p[x].sz = p[p[x].son[0]].sz + p[p[x].son[1]].sz;
}
int merge(int x, int y, int dep) {
if (!x || !y) return x | y;
if (dep > 21) return 0;
p[x].son[0] = merge(p[x].son[0], p[y].son[0], dep + 1);
p[x].son[1] = merge(p[x].son[1], p[y].son[1], dep + 1);
pushup(x);
if (!p[x].sz) x = 0;
return x;
}
vector<int> edge[maxn];
int a[maxn];
void rotate(int x) {
if (x == 0) return;
rotate(p[x].son[1]);
swap(p[x].son[0], p[x].son[1]);
pushup(x);
}
ll res = 0;
void dfs(int u) {
insert(rt[u], 0, a[u]);
for (auto v : edge[u]) {
dfs(v);
rotate(rt[v]);
merge(rt[u], rt[v], 0);
}
// cerr << u << " " << p[rt[u]].val << endl;
res += p[rt[u]].val;
}
int main(int argc, char* argv[]) {
int n;
R(n);
for (int i = 1; i <= n; ++i) R(a[i]);
for (int i = 1; i < n; ++i) {
int fa;
R(fa);
// cerr << "edge " << fa << " " << i + 1 << endl;
edge[fa].push_back(i + 1);
}
dfs(1);
W(res);
return 0;
}