"华为智联杯"无线程序设计大赛 B题 异或和之和题解
题目链接:CF1 或者 CF2
讲解链接:b站
可以先看b站教学,再来详细地读题解。询问本质上就是所以简单路径的和,简单路径的权值为点权异或。
不带修
先考虑不带修怎么做。从点分治的方向来看,我们常常需要维护从分治中心出发的 链信息。这个链信息常常我们用桶来保存,比如存储每个前缀异或值出现的次数。显然直接去做异或贡献计算较为困难,因为我们拿到了一个分治中心的所有前缀异或链信息,不知道相互之间该如何匹配而不是暴力匹配。那么我们转化一下,拆位以后考虑每一位的贡献,那么每一位就只有 \(0\) 和 \(1\),考虑前缀异或为 \(1\) 和 \(0\) 的两个桶,这样就方便分讨匹对了。由于异或涉及到拼起来时,会把根异或掉,因为是这是点权链异或。所以我考虑的是,每个前缀异或都不包含当前的分治中心,这样一来就可以如下分讨:
- 我们第一个是需要分讨 拼链 情况:
关于点分治,我们知道对于当前的分治中心会将剩余的所有路径分为两类,无论哪一类都需要通过分治中心。第一类则为 拼链,拼链表示,将 不同子树 的两条链进行匹配,形成一条路径,即 \((u,v)\) 路径,\(u\) 和 \(v\) 属于两个不同的子树。那么分讨一下每一位,对于当前位,如果分治中心为 \(0\),那么显然 \(u\) 和 \(v\) 对应的前缀异或情况为:\((1,1)\) 或 \((0,0)\)。如果为 \(1\),则是 \((0,1)\)。
- 我们第二个分讨的则是 独立链 情况:
意思就是,\((u,v)\) 其中一个为分治中心,另一个为其他子树,即从分治中心出发的一条链。那么分讨一下,如果分治中心为:\(x\),那么另一条前缀异或为:\(x \oplus 1\)。
那么其实不带修的情况,我们可以枚举每棵子树分别算两种贡献。
关于点分治和点分树的常见容斥套路:
-
维护以每个分治中心出发的链信息。
-
维护对于每个分治中心来说,每个子树方向的链信息。
其中第二点如果照字面意思来看,我们需要维护一个 \(hash\),表式 \(center\Rightarrow son\) 方向的桶信息,这个显然是常数过大的。我们常见写法基于,一个父亲有多个儿子,但一个儿子只有一个父亲,我们这个映射可以通过儿子来找父亲。但又因为我们枚举贡献时,其实都是枚举分治中心,并不是相邻的儿子节点,所以这个映射我们可以改为从儿子分治中心到父亲分治中心,这样一来 \(cnt[son]\),就表示 \(fa[son]\) 在 \(son\) 方向的子树信息,注意这里的 \(son\) 是分治中心之间的关系,并不是原树之间的子树关系。
那么维护好这两个东西以后,我们就可以使用父节点的所有信息减去某一棵子树方向上的信息,从而拿到除开这棵子树外的所有其他子树的信息,这样可以保证当前子树只会和其他子树进行匹配,而不会和自身匹配出不合法的路径信息。
注意到我们枚举了每棵子树和其他子树匹配信息,这个过程算了两次,比如前面编号 \(1\) 和 \(3\) 在枚举第一棵子树时算过了,又在枚举第三棵子树时重新算了一次,所以拼链结果记得除以 \(2\),而独立链则不会,直接加上去即可。
带修
先考虑建出点分树作为带修的贡献修改。我们知道一旦一个分治中心被修改了,会影响:
-
以当前分治中心的贡献。
-
这个分治中心所有父节点分治中心关于当前分治中心的贡献。
原理:修改了一个点,所有经过这个点的路径都会被更改,那么我们知道从当前这个点作为的分治中心往下的儿子节点时,这个点已经不存在了,所以只有可能在父节点分治中心里。你需要知道的是,每个分治中心的贡献表示的都是通过这个分治中心的路径,而现在我们只需要枚举父节点,观察每个父节点和当前节点的共同贡献,其实就是修改既过当前节点又过父节点分治中心的路径贡献。
以当前分治中心的贡献
这部分的贡献我们思考下怎么来。我们先考虑暴力的,暴力的做法显然是枚举每棵子树,然后考虑这棵子树和其他子树的匹配情况算出拼链,然后再累计上独立链的情况。这个复杂度显然无法接受,我们来考虑下一个问题,我们现在 已知什么 ?我们现在知道从当前点作为分治中心的所有的 \(cnt0\) 和 \(cnt1\),我们如果直接分讨根节点,注意我们这些链都是不通过当前根节点的。
先考虑拼链贡献,假如根节点是 \(0\):
那么我们匹配是:\(cnt0 \times cnt1\),这个显然是错误的,因为如果一条 \(0\) 和 一条 \(1\) 在同一个子树中,这条拼链路径就是不合法贡献。但是我们发现也仅仅只有这部分贡献算重了:每棵子树中的 \(0\) 与 \(1\) 的拼配路径,那么如果我们能记录出每棵子树的这样的不合法路径,然后累计到父节点上,这样我们就可以容斥地去掉所有子树中不合法的路径了。我们需要维护每个父节点的所有子树的 \(01\) 路径匹配数量和。
考虑是 \(1\) 的情况类似:
我们重复算了 \(00\) 和 \(11\) 在每棵子树中,我们同时维护这两部分的桶信息,然后我们注意到 \(cnt0\times cnt0\) 这种还出现了相同路径匹配相同路径,所以其实这里我们应该用组合数的思想:\(comb(cnt0,2)\),从所有的 \(0\) 路径中取两条路径出来匹配,这样就不存在 \((i,j)\) 与 \((j,i)\) 重复计算了。
独立链那就简单了,本质上来说就是从父节点出发的 \(0\) 和 \(1\) 链,分讨下父节点情况,加上 \(x \oplus 1\) 的桶信息即可。
这个分治中心所有父节点分治中心关于当前分治中心的贡献
我们知道点分树的深度是 \(\log{n}\) 的,父节点的数量也即为 \(\log{n}\),所以我们枚举每个父节点的贡献先去掉,然后考虑修改以后新的上述的两类贡献再加回来。
我们定义下:
\(cnt[x][pos][0/1]\) 表示点 \(x\) 作为分治中心,拆位以后的第 \(pos\) 位,\(0/1\) 的前缀异或链的数量。
\(cntFa[x][pos][0/1]\) 表示点 \(fa[x]\),注意这里的 \(fa\) 是点分树上的,并不是原树上的,\(fa[x]\) 在 \(x\) 方向上的子树的 \(cnt\) 情况,即在这个方向上的前缀异或链的桶信息,实际上假设当前这个方向上 \(fa[x]\) 原树上相邻的儿子为:\(son\),则其实为 \(cnt[Fa[x]][pos][0/1][son]\),以 \(son\) 为起点的所有链信息。
\(x\) 为当前分治中心。\(fa[x]\) 为它上一个分治中心,即从 \(son\) 出发的所有前缀异或情况。
那么 \(cnt[fa[x]][pos][0/1]-cntFa[x][pos][0/1]\) 就可以表示出除了当前这个方向上的其他子树信息。而我们修改时也对应的这个修改,我们显然先要去掉当前子树方向上的信息,然后修改完以后再加回来。
难点:cntFa 的容斥修改
我们观察到 \(cnt\) 的修改很简单,\(x\) 发生变化,但 \(cnt[x]\) 是不包含 \(x\) 为根的前缀异或路径,所以并没有发生修改,而它的父节点的分治中心的:\(cnt[fa[x]]\) 我们首先应该先去掉当前的 \(cntFa[x]\) 再考虑新的 \(cntFa[x]\) 是什么情况再加回去即可。同时,如果我们维护好新的 \(cnt\) 和 \(cntFa\) 也能顺理成章地算出关于 \(cnt00\)、\(cnt01\)、\(cnt11\) 的情况。
现在我们来考虑 \(cntFa\) 怎么变的,对于某个 \(fa\) 来说:
关于 \(x\) 发生变化以后即:\(0\rightarrow 1\) 或者 \(1 \rightarrow 0\),即只有从 \(fa\) 为根,过 \(x\) 的路径才发生了变化。那么没过的就没发生变化,而且这个变化很简单,就是原来为 \(1\) 的路径变成 \(0\),而 \(0\) 变成了 \(1\),我们只要能算出从 \(fa\) 过 \(x\) 的所有路径的桶信息:\(cnt0\) 与 \(cnt1\) 在原来的 \(cntFa[x]\) 中去掉这部分,然后交换二者再加回来就正确了。我们现在考虑怎么计算 \(cnt0\) 和 \(cnt1\)。
注意到一点,我们 \(cnt[x]\) 的信息,一定会存在一个分支指向 \(fa\),如上图的蓝色点所示,所以我们其实是需要知道一个信息:关于 \(x\) 到它每个 \(fa\) 方向上的 \(son\)。因为我们只有拿到这个 \(son\) ,才能 \(cnt[x]-cntFa[son]\) 去掉这部分不正确的贡献。
\(x \rightarrow fa\) 的 \(son\)
这个 \(son\) 如图所示,貌似只需要满足三点共线即:\(son\) 在 \(fa\) 与 \(x\) 之间就行了?这个很好确认,我们利用树上前缀和算出 \((x,son)+(son,fa)-1=(x,fa)\) 即可,注意是点权,所以需要去掉一个重复算的点 \(son\)。
其实还有一种情况:
点分树中一个很经典的问题,初学者中经常会犯的问题,\(son\) 并不一定会和 \(x\) 相连,图上这种情况我们要找的 \(son\) 也需要确定,显然这个也好找,就是在第一点不符合的情况下,\((fa,x)+(x,son)-1 \neq (fa,son)\) 即不是 \(son1\) 的情况。
其实本质上来说,不是 \(son1\) 的情况就是上述两种情况了。所以我们可以枚举 \(son\),算出这个 \(son\) 对于 \(x\) 来说的所有 \(fa\) 的情况。预处理的复杂度显然枚举是 \(n\log{n}\),然后判断涉及到找 \(LCA\) 所以在没有做欧拉序的预处理下,这个复杂度即为:\(n\log^2{n}\)。
cntFa 正确更新
我们对于当前的 \(fa\) 需要考虑在点分树上的其他分治中心对它的影响。显然是我们要枚举一遍 \(faSon[x]\) 的,表示 \(x \rightarrow fa\) 的点分树上其他的分治中心,因为我们知道,\(cnt[x]\) 并不会包含经过 \(faSon[x]\) 的,所以我们需要拼:
如图,左边是原树,右边是点分树,所以我们应当在 \(faSon\) 中去掉 \(x\) 方向上的贡献。最后其实我们需要计算出 \((x,fa)\) 这条路径异或值或者 \((fa,faSon)\) 路径异或值,这个单点修加路径异或,我们可以使用树链剖分+线段树/树状数组实现。这里的复杂度显然是 \(\log^2{n}\log{W}\),枚举 \(fa\),对于每个 \(fa\) 再枚举 \(faSon\)。
树状数组版本代码
#include <bits/stdc++.h>
// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
// #define isPbdsFile
#ifdef isPbdsFile
#include <bits/extc++.h>
#else
#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>
#endif
using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
template <typename T>
int disc(T* a, int n)
{
return unique(a + 1, a + n + 1) - (a + 1);
}
template <typename T>
T lowBit(T x)
{
return x & -x;
}
template <typename T>
T Rand(T l, T r)
{
static mt19937 Rand(time(nullptr));
uniform_int_distribution<T> dis(l, r);
return dis(Rand);
}
template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
return (a % b + b) % b;
}
template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
a %= c;
T1 ans = 1;
for (; b; b >>= 1, (a *= a) %= c) if (b & 1) (ans *= a) %= c;
return modt(ans, c);
}
template <typename T>
void read(T& x)
{
x = 0;
T sign = 1;
char ch = getchar();
while (!isdigit(ch))
{
if (ch == '-') sign = -1;
ch = getchar();
}
while (isdigit(ch))
{
x = (x << 3) + (x << 1) + (ch ^ 48);
ch = getchar();
}
x *= sign;
}
template <typename T, typename... U>
void read(T& x, U&... y)
{
read(x);
read(y...);
}
template <typename T>
void write(T x)
{
if (typeid(x) == typeid(char)) return;
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 ^ 48);
}
template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
write(x), putchar(c);
write(c, y...);
}
template <typename T11, typename T22, typename T33>
struct T3
{
T11 one;
T22 tow;
T33 three;
bool operator<(const T3 other) const
{
if (one == other.one)
{
if (tow == other.tow) return three < other.three;
return tow < other.tow;
}
return one < other.one;
}
T3()
{
one = tow = three = 0;
}
T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
{
}
};
template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
if (x < y) x = y;
}
template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
if (x > y) x = y;
}
constexpr int N = 1e5 + 10;
constexpr int MX = 1e8;
constexpr int T = log2(MX) + 1;
vector<int> child[N];
int a[N];
int n, q;
ll res[T + 1], tmpRes[T + 1];
bool val[N][T + 1];
struct
{
int fa[N], deep[N], son[N], siz[N];
void dfs1(const int curr, const int pa)
{
deep[curr] = deep[fa[curr] = pa] + 1;
siz[curr] = 1;
for (const int nxt : child[curr])
{
if (nxt == pa) continue;
dfs1(nxt, curr);
siz[curr] += siz[nxt];
if (siz[nxt] > siz[son[curr]]) son[curr] = nxt;
}
}
int idx[N], rev[N], top[N], cnt;
void dfs2(const int curr, const int root)
{
top[curr] = root;
idx[rev[curr] = ++cnt] = curr;
if (!son[curr]) return;
dfs2(son[curr], root);
for (const int nxt : child[curr]) if (nxt != fa[curr] and nxt != son[curr]) dfs2(nxt, nxt);
}
int bit[N];
void update(int x, const int val)
{
const int v = a[idx[x]] ^ val;
while (x <= n) bit[x] ^= v, x += lowBit(x);
}
int query(int x) const
{
int ans = 0;
while (x) ans ^= bit[x], x -= lowBit(x);
return ans;
}
int query(const int l, const int r) const
{
return query(r) ^ query(l - 1);
}
void build()
{
forn(i, 1, n)
{
bit[i] ^= a[idx[i]];
const int j = i + lowBit(i);
if (j <= n) bit[j] ^= bit[i];
}
}
void treeUpdate(const int curr, const int val)
{
update(rev[curr], val);
}
int lca(int x, int y) const
{
while (top[x] != top[y])
{
if (deep[top[x]] < deep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if (deep[x] > deep[y]) swap(x, y);
return x;
}
int treeQuery(int u, int v) const
{
int ans = 0;
while (top[u] != top[v])
{
if (deep[top[u]] < deep[top[v]]) swap(u, v);
ans ^= query(rev[top[u]], rev[u]);
u = fa[top[u]];
}
if (deep[u] > deep[v]) swap(u, v);
return ans ^ query(rev[u], rev[v]);
}
int deepDist(const int x, const int y) const
{
const int LCA = lca(x, y);
return deep[x] + deep[y] - deep[LCA] - deep[fa[LCA]];
}
bool isSeg(const int x, const int mid, const int y) const
{
return deepDist(x, mid) + deepDist(mid, y) - 1 == deepDist(x, y);
}
void init()
{
dfs1(1, 0);
dfs2(1, 1);
build();
}
} HLD;
inline ll getAns()
{
ll ans = 0;
forn(i, 0, T) ans += res[i] * (1LL << i);
return ans;
}
struct
{
bool del[N];
int deep[N], fa[N];
int sumSize, maxSon, root;
int siz[N];
ll cnt[N][T + 1][2], cntFa[N][T + 1][2]; //根,这个点的父亲到它这个方向的子树的桶信息
ll c00[N][T + 1], c11[N][T + 1], c01[N][T + 1]; //每个点的所有子树的01和、11和、01和
int center[N][T + 1];
ll dist[N][T + 1];
static ll comb2(const ll v)
{
return v * (v - 1) / 2;
}
void dfs(const int curr, const int pa)
{
siz[curr] = 1;
for (const int nxt : child[curr]) if (nxt != pa and !del[nxt]) dfs(nxt, curr), siz[curr] += siz[nxt];
}
void buildCntFa(const int curr, const int pa, const int top, const bool xorSum, const int pos)
{
++cntFa[top][pos][xorSum];
for (const int nxt : child[curr])
{
if (nxt == pa or del[nxt]) continue;
buildCntFa(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
}
}
void buildCnt(const int curr, const int pa, const int top, const bool xorSum, const int pos)
{
for (const int nxt : child[curr])
{
if (nxt == pa or del[nxt]) continue;
++cnt[top][pos][xorSum ^ val[nxt][pos]];
buildCnt(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
}
}
void makeRoot(const int curr, const int pa)
{
siz[curr] = 1;
int currSize = 0;
for (const int nxt : child[curr])
{
if (nxt == pa or del[nxt]) continue;
makeRoot(nxt, curr);
siz[curr] += siz[nxt];
uMax(currSize, siz[nxt]);
}
uMax(currSize, sumSize - siz[curr]);
if (currSize < maxSon) maxSon = currSize, root = curr;
}
void build(const int curr)
{
del[curr] = true;
forn(pos, 0, T) buildCnt(curr, 0, curr, false, pos);
for (const int nxt : child[curr])
{
if (del[nxt]) continue;
sumSize = maxSon = siz[nxt];
makeRoot(nxt, 0);
forn(pos, 0, T) buildCntFa(nxt, curr, root, val[nxt][pos], pos);
const int rt = root;
dfs(root, 0);
fa[root] = curr;
deep[root] = deep[curr] + 1;
forn(pos, 0, T)
{
const int idx = val[curr][pos];
tmpRes[pos] += (cnt[curr][pos][0] - cntFa[rt][pos][0]) * cntFa[rt][pos][idx ^ 1];
tmpRes[pos] += (cnt[curr][pos][1] - cntFa[rt][pos][1]) * cntFa[rt][pos][idx];
res[pos] += cntFa[rt][pos][idx ^ 1];
}
forn(pos, 0, T)
{
c00[curr][pos] += comb2(cntFa[rt][pos][0]);
c01[curr][pos] += cntFa[rt][pos][0] * cntFa[rt][pos][1];
c11[curr][pos] += comb2(cntFa[rt][pos][1]);
}
build(root);
}
}
ll v0[T + 1], v1[T + 1];
void get01(const int curr, const int old, const int Fa)
{
const int c = center[curr][deep[curr] - deep[Fa]];
forn(i, 0, T)
{
const int oldIdx = old >> i & 1;
v0[i] = cnt[curr][i][oldIdx] - cntFa[c][i][oldIdx] + (oldIdx == 0);
v1[i] = cnt[curr][i][oldIdx ^ 1] - cntFa[c][i][oldIdx ^ 1] + (oldIdx == 1);
}
for (int nxt = curr; fa[nxt] != Fa; nxt = fa[nxt])
{
if (HLD.isSeg(Fa, curr, fa[nxt]))
{
const ll dst = dist[curr][deep[curr] - deep[fa[nxt]]] ^ a[fa[nxt]] ^ a[curr];
forn(i, 0, T)
{
const int oldIdx = old >> i & 1;
const int idx = dst >> i & 1;
const int need = oldIdx ^ idx;
v0[i] += cnt[fa[nxt]][i][need] - cntFa[nxt][i][need] + (need == 0);
v1[i] += cnt[fa[nxt]][i][need ^ 1] - cntFa[nxt][i][need ^ 1] + (need == 1);
}
}
}
}
void Update(const int curr, const int v)
{
forn(i, 0, T) res[i] -= a[curr] >> i & 1;
forn(i, 0, T)
{
if (a[curr] >> i & 1)
{
//00 11
res[i] -= comb2(cnt[curr][i][0]) - c00[curr][i];
res[i] -= comb2(cnt[curr][i][1]) - c11[curr][i];
res[i] -= cnt[curr][i][0];
}
else
{
//01
res[i] -= cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
res[i] -= cnt[curr][i][1];
}
}
for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
{
dist[curr][deep[curr] - deep[fa[nxt]]] = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
forn(i, 0, T)
{
res[i] -= (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
res[i] -= (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
res[i] -= cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
}
}
HLD.treeUpdate(curr, v);
for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
{
const int pNew = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
const int pOld = dist[curr][deep[curr] - deep[fa[nxt]]];
get01(curr, pOld, fa[nxt]);
forn(i, 0, T)
{
const int oldIdx = pOld >> i & 1;
const int newIdx = pNew >> i & 1;
if (oldIdx != newIdx)
{
c00[fa[nxt]][i] -= comb2(cntFa[nxt][i][0]);
c01[fa[nxt]][i] -= cntFa[nxt][i][0] * cntFa[nxt][i][1];
c11[fa[nxt]][i] -= comb2(cntFa[nxt][i][1]);
forn(j, 0, 1) cnt[fa[nxt]][i][j] -= cntFa[nxt][i][j];
cntFa[nxt][i][0] += v1[i] - v0[i];
cntFa[nxt][i][1] += v0[i] - v1[i];
forn(j, 0, 1) cnt[fa[nxt]][i][j] += cntFa[nxt][i][j];
c00[fa[nxt]][i] += comb2(cntFa[nxt][i][0]);
c01[fa[nxt]][i] += cntFa[nxt][i][0] * cntFa[nxt][i][1];
c11[fa[nxt]][i] += comb2(cntFa[nxt][i][1]);
}
res[i] += (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
res[i] += (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
res[i] += cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
}
}
a[curr] = v;
forn(i, 0, T) res[i] += a[curr] >> i & 1;
forn(i, 0, T) val[curr][i] = v >> i & 1;
forn(i, 0, T)
{
if (a[curr] >> i & 1)
{
//00 11
res[i] += comb2(cnt[curr][i][0]) - c00[curr][i];
res[i] += comb2(cnt[curr][i][1]) - c11[curr][i];
res[i] += cnt[curr][i][0];
}
else
{
//01
res[i] += cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
res[i] += cnt[curr][i][1];
}
}
}
void init()
{
sumSize = maxSon = n;
makeRoot(1, 0);
dfs(root, 0);
build(root);
HLD.init();
forn(i, 1, n)
{
const int curr = fa[i];
for (int nxt = fa[curr]; nxt; nxt = fa[nxt])
{
if (!HLD.isSeg(i, curr, nxt))
{
center[curr][deep[curr] - deep[nxt]] = i;
}
}
}
}
} pointTree;
inline void solve()
{
cin >> n;
forn(i, 1, n)
{
cin >> a[i];
forn(j, 0, T) val[i][j] = a[i] >> j & 1, res[j] += val[i][j];
}
forn(i, 1, n-1)
{
int u, v;
cin >> u >> v;
child[u].push_back(v);
child[v].push_back(u);
}
pointTree.init();
forn(i, 0, T) res[i] += tmpRes[i] >> 1;
cout << getAns() << endl;
cin >> q;
while (q--)
{
int p, v;
cin >> p >> v;
pointTree.Update(p, v);
cout << getAns() << endl;
}
}
signed int main()
{
// MyFile
Spider
//------------------------------------------------------
// clock_t start = clock();
int test = 1;
// read(test);
// cin >> test;
forn(i, 1, test) solve();
// while (cin >> n, n)solve();
// while (cin >> test)solve();
// clock_t end = clock();
// cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}
线段树版本代码
#include <bits/stdc++.h>
// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")
// #define isPbdsFile
#ifdef isPbdsFile
#include <bits/extc++.h>
#else
#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>
#endif
using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};
template <typename T>
int disc(T* a, int n)
{
return unique(a + 1, a + n + 1) - (a + 1);
}
template <typename T>
T lowBit(T x)
{
return x & -x;
}
template <typename T>
T Rand(T l, T r)
{
static mt19937 Rand(time(nullptr));
uniform_int_distribution<T> dis(l, r);
return dis(Rand);
}
template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
return (a % b + b) % b;
}
template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
a %= c;
T1 ans = 1;
for (; b; b >>= 1, (a *= a) %= c) if (b & 1) (ans *= a) %= c;
return modt(ans, c);
}
template <typename T>
void read(T& x)
{
x = 0;
T sign = 1;
char ch = getchar();
while (!isdigit(ch))
{
if (ch == '-') sign = -1;
ch = getchar();
}
while (isdigit(ch))
{
x = (x << 3) + (x << 1) + (ch ^ 48);
ch = getchar();
}
x *= sign;
}
template <typename T, typename... U>
void read(T& x, U&... y)
{
read(x);
read(y...);
}
template <typename T>
void write(T x)
{
if (typeid(x) == typeid(char)) return;
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 ^ 48);
}
template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
write(x), putchar(c);
write(c, y...);
}
template <typename T11, typename T22, typename T33>
struct T3
{
T11 one;
T22 tow;
T33 three;
bool operator<(const T3 other) const
{
if (one == other.one)
{
if (tow == other.tow) return three < other.three;
return tow < other.tow;
}
return one < other.one;
}
T3()
{
one = tow = three = 0;
}
T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
{
}
};
template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
if (x < y) x = y;
}
template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
if (x > y) x = y;
}
constexpr int N = 1e5 + 10;
constexpr int MX = 1e8;
constexpr int T = log2(MX) + 1;
vector<int> child[N];
int a[N];
int n, q;
ll res[T + 1], tmpRes[T + 1];
bool val[N][T + 1];
struct
{
int fa[N], deep[N], son[N], siz[N];
void dfs1(const int curr, const int pa)
{
deep[curr] = deep[fa[curr] = pa] + 1;
siz[curr] = 1;
for (const int nxt : child[curr])
{
if (nxt == pa) continue;
dfs1(nxt, curr);
siz[curr] += siz[nxt];
if (siz[nxt] > siz[son[curr]]) son[curr] = nxt;
}
}
int idx[N], rev[N], top[N], cnt;
void dfs2(const int curr, const int root)
{
top[curr] = root;
idx[rev[curr] = ++cnt] = curr;
if (!son[curr]) return;
dfs2(son[curr], root);
for (const int nxt : child[curr]) if (nxt != fa[curr] and nxt != son[curr]) dfs2(nxt, nxt);
}
int xorSum[N << 2];
void pushUp(const int curr)
{
xorSum[curr] = xorSum[ls(curr)] ^ xorSum[rs(curr)];
}
void update(const int curr, const int pos, const int val, const int l = 1, const int r = n)
{
if (l == r)
{
xorSum[curr] = val;
return;
}
const int mid = l + r >> 1;
if (pos <= mid) update(ls(curr), pos, val, l, mid);
else update(rs(curr), pos, val, mid + 1, r);
pushUp(curr);
}
int query(const int curr, const int l, const int r, const int s = 1, const int e = n)
{
if (l <= s and e <= r) return xorSum[curr];
const int mid = s + e >> 1;
int ans = 0;
if (l <= mid) ans ^= query(ls(curr), l, r, s, mid);
if (r > mid) ans ^= query(rs(curr), l, r, mid + 1, e);
return ans;
}
void build(const int curr = 1, const int l = 1, const int r = n)
{
if (l == r)
{
xorSum[curr] = a[idx[l]];
return;
}
const int mid = l + r >> 1;
build(ls(curr), l, mid);
build(rs(curr), mid + 1, r);
pushUp(curr);
}
void treeUpdate(const int curr, const int val)
{
update(1, rev[curr], val);
}
int lca(int x, int y) const
{
while (top[x] != top[y])
{
if (deep[top[x]] < deep[top[y]]) swap(x, y);
x = fa[top[x]];
}
if (deep[x] > deep[y]) swap(x, y);
return x;
}
int treeQuery(int u, int v)
{
int ans = 0;
while (top[u] != top[v])
{
if (deep[top[u]] < deep[top[v]]) swap(u, v);
ans ^= query(1, rev[top[u]], rev[u]);
u = fa[top[u]];
}
if (deep[u] > deep[v]) swap(u, v);
return ans ^ query(1, rev[u], rev[v]);
}
int deepDist(const int x, const int y) const
{
const int LCA = lca(x, y);
return deep[x] + deep[y] - deep[LCA] - deep[fa[LCA]];
}
bool isSeg(const int x, const int mid, const int y) const
{
return deepDist(x, mid) + deepDist(mid, y) - 1 == deepDist(x, y);
}
void init()
{
dfs1(1, 0);
dfs2(1, 1);
build();
}
} HLD;
inline ll getAns()
{
ll ans = 0;
forn(i, 0, T) ans += res[i] * (1LL << i);
return ans;
}
struct
{
bool del[N];
int deep[N], fa[N];
int sumSize, maxSon, root;
int siz[N];
ll cnt[N][T + 1][2], cntFa[N][T + 1][2]; //根,这个点的父亲到它这个方向的子树的桶信息
ll c00[N][T + 1], c11[N][T + 1], c01[N][T + 1]; //每个点的所有子树的01和、11和、01和
int center[N][T + 1];
ll dist[N][T + 1];
static ll comb2(const ll v)
{
return v * (v - 1) / 2;
}
void dfs(const int curr, const int pa)
{
siz[curr] = 1;
for (const int nxt : child[curr]) if (nxt != pa and !del[nxt]) dfs(nxt, curr), siz[curr] += siz[nxt];
}
void buildCntFa(const int curr, const int pa, const int top, const bool xorSum, const int pos)
{
++cntFa[top][pos][xorSum];
for (const int nxt : child[curr])
{
if (nxt == pa or del[nxt]) continue;
buildCntFa(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
}
}
void buildCnt(const int curr, const int pa, const int top, const bool xorSum, const int pos)
{
for (const int nxt : child[curr])
{
if (nxt == pa or del[nxt]) continue;
++cnt[top][pos][xorSum ^ val[nxt][pos]];
buildCnt(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
}
}
void makeRoot(const int curr, const int pa)
{
siz[curr] = 1;
int currSize = 0;
for (const int nxt : child[curr])
{
if (nxt == pa or del[nxt]) continue;
makeRoot(nxt, curr);
siz[curr] += siz[nxt];
uMax(currSize, siz[nxt]);
}
uMax(currSize, sumSize - siz[curr]);
if (currSize < maxSon) maxSon = currSize, root = curr;
}
void build(const int curr)
{
del[curr] = true;
forn(pos, 0, T) buildCnt(curr, 0, curr, false, pos);
for (const int nxt : child[curr])
{
if (del[nxt]) continue;
sumSize = maxSon = siz[nxt];
makeRoot(nxt, 0);
forn(pos, 0, T) buildCntFa(nxt, curr, root, val[nxt][pos], pos);
const int rt = root;
dfs(root, 0);
fa[root] = curr;
deep[root] = deep[curr] + 1;
forn(pos, 0, T)
{
const int idx = val[curr][pos];
tmpRes[pos] += (cnt[curr][pos][0] - cntFa[rt][pos][0]) * cntFa[rt][pos][idx ^ 1];
tmpRes[pos] += (cnt[curr][pos][1] - cntFa[rt][pos][1]) * cntFa[rt][pos][idx];
res[pos] += cntFa[rt][pos][idx ^ 1];
}
forn(pos, 0, T)
{
c00[curr][pos] += comb2(cntFa[rt][pos][0]);
c01[curr][pos] += cntFa[rt][pos][0] * cntFa[rt][pos][1];
c11[curr][pos] += comb2(cntFa[rt][pos][1]);
}
build(root);
}
}
ll v0[T + 1], v1[T + 1];
void get01(const int curr, const int old, const int Fa)
{
const int c = center[curr][deep[curr] - deep[Fa]];
forn(i, 0, T)
{
const int oldIdx = old >> i & 1;
v0[i] = cnt[curr][i][oldIdx] - cntFa[c][i][oldIdx] + (oldIdx == 0);
v1[i] = cnt[curr][i][oldIdx ^ 1] - cntFa[c][i][oldIdx ^ 1] + (oldIdx == 1);
}
for (int nxt = curr; fa[nxt] != Fa; nxt = fa[nxt])
{
if (HLD.isSeg(fa[nxt], curr, Fa))
{
const ll dst = dist[curr][deep[curr] - deep[fa[nxt]]] ^ a[fa[nxt]] ^ a[curr];
forn(i, 0, T)
{
const int oldIdx = old >> i & 1;
const int idx = dst >> i & 1;
const int need = oldIdx ^ idx;
v0[i] += cnt[fa[nxt]][i][need] - cntFa[nxt][i][need] + (need == 0);
v1[i] += cnt[fa[nxt]][i][need ^ 1] - cntFa[nxt][i][need ^ 1] + (need == 1);
}
}
}
}
void Update(const int curr, const int v)
{
forn(i, 0, T) res[i] -= a[curr] >> i & 1;
forn(i, 0, T)
{
if (a[curr] >> i & 1)
{
//00 11
res[i] -= comb2(cnt[curr][i][0]) - c00[curr][i];
res[i] -= comb2(cnt[curr][i][1]) - c11[curr][i];
res[i] -= cnt[curr][i][0];
}
else
{
//01
res[i] -= cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
res[i] -= cnt[curr][i][1];
}
}
for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
{
dist[curr][deep[curr] - deep[fa[nxt]]] = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
forn(i, 0, T)
{
res[i] -= (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
res[i] -= (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
res[i] -= cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
}
}
HLD.treeUpdate(curr, v);
for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
{
const int pNew = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
const int pOld = dist[curr][deep[curr] - deep[fa[nxt]]];
get01(curr, pOld, fa[nxt]);
forn(i, 0, T)
{
const int oldIdx = pOld >> i & 1;
const int newIdx = pNew >> i & 1;
if (oldIdx != newIdx)
{
c00[fa[nxt]][i] -= comb2(cntFa[nxt][i][0]);
c01[fa[nxt]][i] -= cntFa[nxt][i][0] * cntFa[nxt][i][1];
c11[fa[nxt]][i] -= comb2(cntFa[nxt][i][1]);
forn(j, 0, 1) cnt[fa[nxt]][i][j] -= cntFa[nxt][i][j];
cntFa[nxt][i][0] += v1[i] - v0[i];
cntFa[nxt][i][1] += v0[i] - v1[i];
forn(j, 0, 1) cnt[fa[nxt]][i][j] += cntFa[nxt][i][j];
c00[fa[nxt]][i] += comb2(cntFa[nxt][i][0]);
c01[fa[nxt]][i] += cntFa[nxt][i][0] * cntFa[nxt][i][1];
c11[fa[nxt]][i] += comb2(cntFa[nxt][i][1]);
}
res[i] += (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
res[i] += (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
res[i] += cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
}
}
a[curr] = v;
forn(i, 0, T) res[i] += a[curr] >> i & 1;
forn(i, 0, T) val[curr][i] = v >> i & 1;
forn(i, 0, T)
{
if (a[curr] >> i & 1)
{
//00 11
res[i] += comb2(cnt[curr][i][0]) - c00[curr][i];
res[i] += comb2(cnt[curr][i][1]) - c11[curr][i];
res[i] += cnt[curr][i][0];
}
else
{
//01
res[i] += cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
res[i] += cnt[curr][i][1];
}
}
}
void init()
{
sumSize = maxSon = n;
makeRoot(1, 0);
dfs(root, 0);
build(root);
HLD.init();
forn(i, 1, n)
{
const int curr = fa[i];
for (int nxt = fa[curr]; nxt; nxt = fa[nxt])
{
if (!HLD.isSeg(i, curr, nxt)) center[curr][deep[curr] - deep[nxt]] = i;
}
}
}
} pointTree;
inline void solve()
{
cin >> n;
forn(i, 1, n)
{
cin >> a[i];
forn(j, 0, T) val[i][j] = a[i] >> j & 1, res[j] += val[i][j];
}
forn(i, 1, n-1)
{
int u, v;
cin >> u >> v;
child[u].push_back(v);
child[v].push_back(u);
}
pointTree.init();
forn(i, 0, T) res[i] += tmpRes[i] >> 1;
cout << getAns() << endl;
cin >> q;
while (q--)
{
int p, v;
cin >> p >> v;
pointTree.Update(p, v);
cout << getAns() << endl;
}
}
signed int main()
{
// MyFile
Spider
//------------------------------------------------------
// clock_t start = clock();
int test = 1;
// read(test);
// cin >> test;
forn(i, 1, test) solve();
// while (cin >> n, n)solve();
// while (cin >> test)solve();
// clock_t end = clock();
// cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}