"华为智联杯"无线程序设计大赛 B题 异或和之和题解

题目链接:CF1 或者 CF2

讲解链接:b站

可以先看b站教学,再来详细地读题解。询问本质上就是所以简单路径的和,简单路径的权值为点权异或

不带修

先考虑不带修怎么做。从点分治的方向来看,我们常常需要维护从分治中心出发的 链信息。这个链信息常常我们用桶来保存,比如存储每个前缀异或值出现的次数。显然直接去做异或贡献计算较为困难,因为我们拿到了一个分治中心的所有前缀异或链信息,不知道相互之间该如何匹配而不是暴力匹配。那么我们转化一下,拆位以后考虑每一位的贡献,那么每一位就只有 \(0\)\(1\),考虑前缀异或为 \(1\)\(0\) 的两个桶,这样就方便分讨匹对了。由于异或涉及到拼起来时,会把根异或掉,因为是这是点权链异或。所以我考虑的是,每个前缀异或都不包含当前的分治中心,这样一来就可以如下分讨:

  1. 我们第一个是需要分讨 拼链 情况:

关于点分治,我们知道对于当前的分治中心会将剩余的所有路径分为两类,无论哪一类都需要通过分治中心。第一类则为 拼链,拼链表示,将 不同子树 的两条链进行匹配,形成一条路径,即 \((u,v)\) 路径,\(u\)\(v\) 属于两个不同的子树。那么分讨一下每一位,对于当前位,如果分治中心为 \(0\),那么显然 \(u\)\(v\) 对应的前缀异或情况为:\((1,1)\)\((0,0)\)。如果为 \(1\),则是 \((0,1)\)

  1. 我们第二个分讨的则是 独立链 情况:

意思就是,\((u,v)\) 其中一个为分治中心,另一个为其他子树,即从分治中心出发的一条链。那么分讨一下,如果分治中心为:\(x\),那么另一条前缀异或为:\(x \oplus 1\)

那么其实不带修的情况,我们可以枚举每棵子树分别算两种贡献。

关于点分治和点分树的常见容斥套路:

  1. 维护以每个分治中心出发的链信息。

  2. 维护对于每个分治中心来说,每个子树方向的链信息。

其中第二点如果照字面意思来看,我们需要维护一个 \(hash\),表式 \(center\Rightarrow son\) 方向的桶信息,这个显然是常数过大的。我们常见写法基于,一个父亲有多个儿子,但一个儿子只有一个父亲,我们这个映射可以通过儿子来找父亲。但又因为我们枚举贡献时,其实都是枚举分治中心,并不是相邻的儿子节点,所以这个映射我们可以改为从儿子分治中心到父亲分治中心,这样一来 \(cnt[son]\),就表示 \(fa[son]\)\(son\) 方向的子树信息,注意这里的 \(son\) 是分治中心之间的关系,并不是原树之间的子树关系。

那么维护好这两个东西以后,我们就可以使用父节点的所有信息减去某一棵子树方向上的信息,从而拿到除开这棵子树外的所有其他子树的信息,这样可以保证当前子树只会和其他子树进行匹配,而不会和自身匹配出不合法的路径信息。

注意到我们枚举了每棵子树和其他子树匹配信息,这个过程算了两次,比如前面编号 \(1\)\(3\) 在枚举第一棵子树时算过了,又在枚举第三棵子树时重新算了一次,所以拼链结果记得除以 \(2\),而独立链则不会,直接加上去即可。

带修

先考虑建出点分树作为带修的贡献修改。我们知道一旦一个分治中心被修改了,会影响:

  1. 以当前分治中心的贡献。

  2. 这个分治中心所有父节点分治中心关于当前分治中心的贡献。

原理:修改了一个点,所有经过这个点的路径都会被更改,那么我们知道从当前这个点作为的分治中心往下的儿子节点时,这个点已经不存在了,所以只有可能在父节点分治中心里。你需要知道的是,每个分治中心的贡献表示的都是通过这个分治中心的路径,而现在我们只需要枚举父节点,观察每个父节点和当前节点的共同贡献,其实就是修改既过当前节点又过父节点分治中心的路径贡献。

以当前分治中心的贡献

这部分的贡献我们思考下怎么来。我们先考虑暴力的,暴力的做法显然是枚举每棵子树,然后考虑这棵子树和其他子树的匹配情况算出拼链,然后再累计上独立链的情况。这个复杂度显然无法接受,我们来考虑下一个问题,我们现在 已知什么 ?我们现在知道从当前点作为分治中心的所有的 \(cnt0\)\(cnt1\),我们如果直接分讨根节点,注意我们这些链都是不通过当前根节点的。

先考虑拼链贡献,假如根节点是 \(0\)

那么我们匹配是:\(cnt0 \times cnt1\),这个显然是错误的,因为如果一条 \(0\) 和 一条 \(1\) 在同一个子树中,这条拼链路径就是不合法贡献。但是我们发现也仅仅只有这部分贡献算重了:每棵子树中的 \(0\)\(1\) 的拼配路径,那么如果我们能记录出每棵子树的这样的不合法路径,然后累计到父节点上,这样我们就可以容斥地去掉所有子树中不合法的路径了。我们需要维护每个父节点的所有子树的 \(01\) 路径匹配数量和。

考虑是 \(1\) 的情况类似:

我们重复算了 \(00\)\(11\) 在每棵子树中,我们同时维护这两部分的桶信息,然后我们注意到 \(cnt0\times cnt0\) 这种还出现了相同路径匹配相同路径,所以其实这里我们应该用组合数的思想:\(comb(cnt0,2)\),从所有的 \(0\) 路径中取两条路径出来匹配,这样就不存在 \((i,j)\)\((j,i)\) 重复计算了。

独立链那就简单了,本质上来说就是从父节点出发的 \(0\)\(1\) 链,分讨下父节点情况,加上 \(x \oplus 1\) 的桶信息即可。

这个分治中心所有父节点分治中心关于当前分治中心的贡献

我们知道点分树的深度是 \(\log{n}\) 的,父节点的数量也即为 \(\log{n}\),所以我们枚举每个父节点的贡献先去掉,然后考虑修改以后新的上述的两类贡献再加回来。

我们定义下:

\(cnt[x][pos][0/1]\) 表示点 \(x\) 作为分治中心,拆位以后的第 \(pos\) 位,\(0/1\) 的前缀异或链的数量。

\(cntFa[x][pos][0/1]\) 表示点 \(fa[x]\),注意这里的 \(fa\) 是点分树上的,并不是原树上的,\(fa[x]\)\(x\) 方向上的子树的 \(cnt\) 情况,即在这个方向上的前缀异或链的桶信息,实际上假设当前这个方向上 \(fa[x]\) 原树上相邻的儿子为:\(son\),则其实为 \(cnt[Fa[x]][pos][0/1][son]\),以 \(son\) 为起点的所有链信息。

\(x\) 为当前分治中心。\(fa[x]\) 为它上一个分治中心,即从 \(son\) 出发的所有前缀异或情况。

那么 \(cnt[fa[x]][pos][0/1]-cntFa[x][pos][0/1]\) 就可以表示出除了当前这个方向上的其他子树信息。而我们修改时也对应的这个修改,我们显然先要去掉当前子树方向上的信息,然后修改完以后再加回来。

难点:cntFa 的容斥修改

我们观察到 \(cnt\) 的修改很简单,\(x\) 发生变化,但 \(cnt[x]\) 是不包含 \(x\) 为根的前缀异或路径,所以并没有发生修改,而它的父节点的分治中心的:\(cnt[fa[x]]\) 我们首先应该先去掉当前的 \(cntFa[x]\) 再考虑新的 \(cntFa[x]\) 是什么情况再加回去即可。同时,如果我们维护好新的 \(cnt\)\(cntFa\) 也能顺理成章地算出关于 \(cnt00\)\(cnt01\)\(cnt11\) 的情况。

现在我们来考虑 \(cntFa\) 怎么变的,对于某个 \(fa\) 来说:

关于 \(x\) 发生变化以后即:\(0\rightarrow 1\) 或者 \(1 \rightarrow 0\),即只有从 \(fa\) 为根,过 \(x\) 的路径才发生了变化。那么没过的就没发生变化,而且这个变化很简单,就是原来为 \(1\) 的路径变成 \(0\),而 \(0\) 变成了 \(1\),我们只要能算出从 \(fa\)\(x\) 的所有路径的桶信息:\(cnt0\)\(cnt1\) 在原来的 \(cntFa[x]\) 中去掉这部分,然后交换二者再加回来就正确了。我们现在考虑怎么计算 \(cnt0\)\(cnt1\)

注意到一点,我们 \(cnt[x]\) 的信息,一定会存在一个分支指向 \(fa\),如上图的蓝色点所示,所以我们其实是需要知道一个信息:关于 \(x\) 到它每个 \(fa\) 方向上的 \(son\)。因为我们只有拿到这个 \(son\) ,才能 \(cnt[x]-cntFa[son]\) 去掉这部分不正确的贡献。

\(x \rightarrow fa\)\(son\)

这个 \(son\) 如图所示,貌似只需要满足三点共线即:\(son\)\(fa\)\(x\) 之间就行了?这个很好确认,我们利用树上前缀和算出 \((x,son)+(son,fa)-1=(x,fa)\) 即可,注意是点权,所以需要去掉一个重复算的点 \(son\)

其实还有一种情况:

点分树中一个很经典的问题,初学者中经常会犯的问题,\(son\) 并不一定会和 \(x\) 相连,图上这种情况我们要找的 \(son\) 也需要确定,显然这个也好找,就是在第一点不符合的情况下,\((fa,x)+(x,son)-1 \neq (fa,son)\) 即不是 \(son1\) 的情况。

其实本质上来说,不是 \(son1\) 的情况就是上述两种情况了。所以我们可以枚举 \(son\),算出这个 \(son\) 对于 \(x\) 来说的所有 \(fa\) 的情况。预处理的复杂度显然枚举是 \(n\log{n}\),然后判断涉及到找 \(LCA\) 所以在没有做欧拉序的预处理下,这个复杂度即为:\(n\log^2{n}\)

cntFa 正确更新

我们对于当前的 \(fa\) 需要考虑在点分树上的其他分治中心对它的影响。显然是我们要枚举一遍 \(faSon[x]\) 的,表示 \(x \rightarrow fa\) 的点分树上其他的分治中心,因为我们知道,\(cnt[x]\) 并不会包含经过 \(faSon[x]\) 的,所以我们需要拼:

如图,左边是原树,右边是点分树,所以我们应当在 \(faSon\) 中去掉 \(x\) 方向上的贡献。最后其实我们需要计算出 \((x,fa)\) 这条路径异或值或者 \((fa,faSon)\) 路径异或值,这个单点修加路径异或,我们可以使用树链剖分+线段树/树状数组实现。这里的复杂度显然是 \(\log^2{n}\log{W}\),枚举 \(fa\),对于每个 \(fa\) 再枚举 \(faSon\)

树状数组版本代码
#include <bits/stdc++.h>

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

// #define isPbdsFile

#ifdef isPbdsFile

#include <bits/extc++.h>

#else

#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>

#endif

using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

template <typename T>
int disc(T* a, int n)
{
    return unique(a + 1, a + n + 1) - (a + 1);
}

template <typename T>
T lowBit(T x)
{
    return x & -x;
}

template <typename T>
T Rand(T l, T r)
{
    static mt19937 Rand(time(nullptr));
    uniform_int_distribution<T> dis(l, r);
    return dis(Rand);
}

template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
    return (a % b + b) % b;
}

template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
    a %= c;
    T1 ans = 1;
    for (; b; b >>= 1, (a *= a) %= c) if (b & 1) (ans *= a) %= c;
    return modt(ans, c);
}

template <typename T>
void read(T& x)
{
    x = 0;
    T sign = 1;
    char ch = getchar();
    while (!isdigit(ch))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (isdigit(ch))
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    x *= sign;
}

template <typename T, typename... U>
void read(T& x, U&... y)
{
    read(x);
    read(y...);
}

template <typename T>
void write(T x)
{
    if (typeid(x) == typeid(char)) return;
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 ^ 48);
}

template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
    write(x), putchar(c);
    write(c, y...);
}


template <typename T11, typename T22, typename T33>
struct T3
{
    T11 one;
    T22 tow;
    T33 three;

    bool operator<(const T3 other) const
    {
        if (one == other.one)
        {
            if (tow == other.tow) return three < other.three;
            return tow < other.tow;
        }
        return one < other.one;
    }

    T3()
    {
        one = tow = three = 0;
    }

    T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
    {
    }
};

template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
    if (x < y) x = y;
}

template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
    if (x > y) x = y;
}

constexpr int N = 1e5 + 10;
constexpr int MX = 1e8;
constexpr int T = log2(MX) + 1;
vector<int> child[N];
int a[N];
int n, q;
ll res[T + 1], tmpRes[T + 1];
bool val[N][T + 1];

struct
{
    int fa[N], deep[N], son[N], siz[N];

    void dfs1(const int curr, const int pa)
    {
        deep[curr] = deep[fa[curr] = pa] + 1;
        siz[curr] = 1;
        for (const int nxt : child[curr])
        {
            if (nxt == pa) continue;
            dfs1(nxt, curr);
            siz[curr] += siz[nxt];
            if (siz[nxt] > siz[son[curr]]) son[curr] = nxt;
        }
    }

    int idx[N], rev[N], top[N], cnt;

    void dfs2(const int curr, const int root)
    {
        top[curr] = root;
        idx[rev[curr] = ++cnt] = curr;
        if (!son[curr]) return;
        dfs2(son[curr], root);
        for (const int nxt : child[curr]) if (nxt != fa[curr] and nxt != son[curr]) dfs2(nxt, nxt);
    }

    int bit[N];

    void update(int x, const int val)
    {
        const int v = a[idx[x]] ^ val;
        while (x <= n) bit[x] ^= v, x += lowBit(x);
    }

    int query(int x) const
    {
        int ans = 0;
        while (x) ans ^= bit[x], x -= lowBit(x);
        return ans;
    }

    int query(const int l, const int r) const
    {
        return query(r) ^ query(l - 1);
    }

    void build()
    {
        forn(i, 1, n)
        {
            bit[i] ^= a[idx[i]];
            const int j = i + lowBit(i);
            if (j <= n) bit[j] ^= bit[i];
        }
    }

    void treeUpdate(const int curr, const int val)
    {
        update(rev[curr], val);
    }

    int lca(int x, int y) const
    {
        while (top[x] != top[y])
        {
            if (deep[top[x]] < deep[top[y]]) swap(x, y);
            x = fa[top[x]];
        }
        if (deep[x] > deep[y]) swap(x, y);
        return x;
    }

    int treeQuery(int u, int v) const
    {
        int ans = 0;
        while (top[u] != top[v])
        {
            if (deep[top[u]] < deep[top[v]]) swap(u, v);
            ans ^= query(rev[top[u]], rev[u]);
            u = fa[top[u]];
        }
        if (deep[u] > deep[v]) swap(u, v);
        return ans ^ query(rev[u], rev[v]);
    }

    int deepDist(const int x, const int y) const
    {
        const int LCA = lca(x, y);
        return deep[x] + deep[y] - deep[LCA] - deep[fa[LCA]];
    }

    bool isSeg(const int x, const int mid, const int y) const
    {
        return deepDist(x, mid) + deepDist(mid, y) - 1 == deepDist(x, y);
    }

    void init()
    {
        dfs1(1, 0);
        dfs2(1, 1);
        build();
    }
} HLD;

inline ll getAns()
{
    ll ans = 0;
    forn(i, 0, T) ans += res[i] * (1LL << i);
    return ans;
}

struct
{
    bool del[N];
    int deep[N], fa[N];
    int sumSize, maxSon, root;
    int siz[N];
    ll cnt[N][T + 1][2], cntFa[N][T + 1][2]; //根,这个点的父亲到它这个方向的子树的桶信息
    ll c00[N][T + 1], c11[N][T + 1], c01[N][T + 1]; //每个点的所有子树的01和、11和、01和
    int center[N][T + 1];
    ll dist[N][T + 1];

    static ll comb2(const ll v)
    {
        return v * (v - 1) / 2;
    }

    void dfs(const int curr, const int pa)
    {
        siz[curr] = 1;
        for (const int nxt : child[curr]) if (nxt != pa and !del[nxt]) dfs(nxt, curr), siz[curr] += siz[nxt];
    }

    void buildCntFa(const int curr, const int pa, const int top, const bool xorSum, const int pos)
    {
        ++cntFa[top][pos][xorSum];
        for (const int nxt : child[curr])
        {
            if (nxt == pa or del[nxt]) continue;
            buildCntFa(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
        }
    }

    void buildCnt(const int curr, const int pa, const int top, const bool xorSum, const int pos)
    {
        for (const int nxt : child[curr])
        {
            if (nxt == pa or del[nxt]) continue;
            ++cnt[top][pos][xorSum ^ val[nxt][pos]];
            buildCnt(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
        }
    }

    void makeRoot(const int curr, const int pa)
    {
        siz[curr] = 1;
        int currSize = 0;
        for (const int nxt : child[curr])
        {
            if (nxt == pa or del[nxt]) continue;
            makeRoot(nxt, curr);
            siz[curr] += siz[nxt];
            uMax(currSize, siz[nxt]);
        }
        uMax(currSize, sumSize - siz[curr]);
        if (currSize < maxSon) maxSon = currSize, root = curr;
    }


    void build(const int curr)
    {
        del[curr] = true;
        forn(pos, 0, T) buildCnt(curr, 0, curr, false, pos);
        for (const int nxt : child[curr])
        {
            if (del[nxt]) continue;
            sumSize = maxSon = siz[nxt];
            makeRoot(nxt, 0);
            forn(pos, 0, T) buildCntFa(nxt, curr, root, val[nxt][pos], pos);
            const int rt = root;
            dfs(root, 0);
            fa[root] = curr;
            deep[root] = deep[curr] + 1;
            forn(pos, 0, T)
            {
                const int idx = val[curr][pos];
                tmpRes[pos] += (cnt[curr][pos][0] - cntFa[rt][pos][0]) * cntFa[rt][pos][idx ^ 1];
                tmpRes[pos] += (cnt[curr][pos][1] - cntFa[rt][pos][1]) * cntFa[rt][pos][idx];
                res[pos] += cntFa[rt][pos][idx ^ 1];
            }
            forn(pos, 0, T)
            {
                c00[curr][pos] += comb2(cntFa[rt][pos][0]);
                c01[curr][pos] += cntFa[rt][pos][0] * cntFa[rt][pos][1];
                c11[curr][pos] += comb2(cntFa[rt][pos][1]);
            }
            build(root);
        }
    }

    ll v0[T + 1], v1[T + 1];

    void get01(const int curr, const int old, const int Fa)
    {
        const int c = center[curr][deep[curr] - deep[Fa]];
        forn(i, 0, T)
        {
            const int oldIdx = old >> i & 1;
            v0[i] = cnt[curr][i][oldIdx] - cntFa[c][i][oldIdx] + (oldIdx == 0);
            v1[i] = cnt[curr][i][oldIdx ^ 1] - cntFa[c][i][oldIdx ^ 1] + (oldIdx == 1);
        }
        for (int nxt = curr; fa[nxt] != Fa; nxt = fa[nxt])
        {
            if (HLD.isSeg(Fa, curr, fa[nxt]))
            {
                const ll dst = dist[curr][deep[curr] - deep[fa[nxt]]] ^ a[fa[nxt]] ^ a[curr];
                forn(i, 0, T)
                {
                    const int oldIdx = old >> i & 1;
                    const int idx = dst >> i & 1;
                    const int need = oldIdx ^ idx;
                    v0[i] += cnt[fa[nxt]][i][need] - cntFa[nxt][i][need] + (need == 0);
                    v1[i] += cnt[fa[nxt]][i][need ^ 1] - cntFa[nxt][i][need ^ 1] + (need == 1);
                }
            }
        }
    }

    void Update(const int curr, const int v)
    {
        forn(i, 0, T) res[i] -= a[curr] >> i & 1;
        forn(i, 0, T)
        {
            if (a[curr] >> i & 1)
            {
                //00 11
                res[i] -= comb2(cnt[curr][i][0]) - c00[curr][i];
                res[i] -= comb2(cnt[curr][i][1]) - c11[curr][i];
                res[i] -= cnt[curr][i][0];
            }
            else
            {
                //01
                res[i] -= cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
                res[i] -= cnt[curr][i][1];
            }
        }
        for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
        {
            dist[curr][deep[curr] - deep[fa[nxt]]] = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
            forn(i, 0, T)
            {
                res[i] -= (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
                res[i] -= (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
                res[i] -= cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
            }
        }
        HLD.treeUpdate(curr, v);
        for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
        {
            const int pNew = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
            const int pOld = dist[curr][deep[curr] - deep[fa[nxt]]];
            get01(curr, pOld, fa[nxt]);
            forn(i, 0, T)
            {
                const int oldIdx = pOld >> i & 1;
                const int newIdx = pNew >> i & 1;
                if (oldIdx != newIdx)
                {
                    c00[fa[nxt]][i] -= comb2(cntFa[nxt][i][0]);
                    c01[fa[nxt]][i] -= cntFa[nxt][i][0] * cntFa[nxt][i][1];
                    c11[fa[nxt]][i] -= comb2(cntFa[nxt][i][1]);
                    forn(j, 0, 1) cnt[fa[nxt]][i][j] -= cntFa[nxt][i][j];
                    cntFa[nxt][i][0] += v1[i] - v0[i];
                    cntFa[nxt][i][1] += v0[i] - v1[i];
                    forn(j, 0, 1) cnt[fa[nxt]][i][j] += cntFa[nxt][i][j];
                    c00[fa[nxt]][i] += comb2(cntFa[nxt][i][0]);
                    c01[fa[nxt]][i] += cntFa[nxt][i][0] * cntFa[nxt][i][1];
                    c11[fa[nxt]][i] += comb2(cntFa[nxt][i][1]);
                }
                res[i] += (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
                res[i] += (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
                res[i] += cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
            }
        }
        a[curr] = v;
        forn(i, 0, T) res[i] += a[curr] >> i & 1;
        forn(i, 0, T) val[curr][i] = v >> i & 1;
        forn(i, 0, T)
        {
            if (a[curr] >> i & 1)
            {
                //00 11
                res[i] += comb2(cnt[curr][i][0]) - c00[curr][i];
                res[i] += comb2(cnt[curr][i][1]) - c11[curr][i];
                res[i] += cnt[curr][i][0];
            }
            else
            {
                //01
                res[i] += cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
                res[i] += cnt[curr][i][1];
            }
        }
    }

    void init()
    {
        sumSize = maxSon = n;
        makeRoot(1, 0);
        dfs(root, 0);
        build(root);
        HLD.init();
        forn(i, 1, n)
        {
            const int curr = fa[i];
            for (int nxt = fa[curr]; nxt; nxt = fa[nxt])
            {
                if (!HLD.isSeg(i, curr, nxt))
                {
                    center[curr][deep[curr] - deep[nxt]] = i;
                }
            }
        }
    }
} pointTree;

inline void solve()
{
    cin >> n;
    forn(i, 1, n)
    {
        cin >> a[i];
        forn(j, 0, T) val[i][j] = a[i] >> j & 1, res[j] += val[i][j];
    }
    forn(i, 1, n-1)
    {
        int u, v;
        cin >> u >> v;
        child[u].push_back(v);
        child[v].push_back(u);
    }
    pointTree.init();
    forn(i, 0, T) res[i] += tmpRes[i] >> 1;
    cout << getAns() << endl;
    cin >> q;
    while (q--)
    {
        int p, v;
        cin >> p >> v;
        pointTree.Update(p, v);
        cout << getAns() << endl;
    }
}

signed int main()
{
    // MyFile
    Spider
    //------------------------------------------------------
    // clock_t start = clock();
    int test = 1;
    //    read(test);
    // cin >> test;
    forn(i, 1, test) solve();
    //    while (cin >> n, n)solve();
    //    while (cin >> test)solve();
    // clock_t end = clock();
    // cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}
线段树版本代码
#include <bits/stdc++.h>

// #pragma GCC optimize(2)
// #pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// #pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

// #define isPbdsFile

#ifdef isPbdsFile

#include <bits/extc++.h>

#else

#include <ext/pb_ds/priority_queue.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/trie_policy.hpp>
#include <ext/pb_ds/tag_and_trait.hpp>
#include <ext/pb_ds/hash_policy.hpp>
#include <ext/pb_ds/list_update_policy.hpp>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/exception.hpp>
#include <ext/rope>

#endif

using namespace std;
using namespace __gnu_cxx;
using namespace __gnu_pbds;
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef tuple<int, int, int> tii;
typedef tuple<ll, ll, ll> tll;
typedef unsigned int ui;
typedef unsigned long long ull;
#define hash1 unordered_map
#define hash2 gp_hash_table
#define hash3 cc_hash_table
#define stdHeap std::priority_queue
#define pbdsHeap __gnu_pbds::priority_queue
#define sortArr(a, n) sort(a+1,a+n+1)
#define all(v) v.begin(),v.end()
#define yes cout<<"YES"
#define no cout<<"NO"
#define Spider ios_base::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
#define MyFile freopen("..\\input.txt", "r", stdin),freopen("..\\output.txt", "w", stdout);
#define forn(i, a, b) for(int i = a; i <= b; i++)
#define forv(i, a, b) for(int i=a;i>=b;i--)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define endl '\n'
//用于Miller-Rabin
[[maybe_unused]] static int Prime_Number[13] = {0, 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37};

template <typename T>
int disc(T* a, int n)
{
    return unique(a + 1, a + n + 1) - (a + 1);
}

template <typename T>
T lowBit(T x)
{
    return x & -x;
}

template <typename T>
T Rand(T l, T r)
{
    static mt19937 Rand(time(nullptr));
    uniform_int_distribution<T> dis(l, r);
    return dis(Rand);
}

template <typename T1, typename T2>
T1 modt(T1 a, T2 b)
{
    return (a % b + b) % b;
}

template <typename T1, typename T2, typename T3>
T1 qPow(T1 a, T2 b, T3 c)
{
    a %= c;
    T1 ans = 1;
    for (; b; b >>= 1, (a *= a) %= c) if (b & 1) (ans *= a) %= c;
    return modt(ans, c);
}

template <typename T>
void read(T& x)
{
    x = 0;
    T sign = 1;
    char ch = getchar();
    while (!isdigit(ch))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (isdigit(ch))
    {
        x = (x << 3) + (x << 1) + (ch ^ 48);
        ch = getchar();
    }
    x *= sign;
}

template <typename T, typename... U>
void read(T& x, U&... y)
{
    read(x);
    read(y...);
}

template <typename T>
void write(T x)
{
    if (typeid(x) == typeid(char)) return;
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 ^ 48);
}

template <typename C, typename T, typename... U>
void write(C c, T x, U... y)
{
    write(x), putchar(c);
    write(c, y...);
}


template <typename T11, typename T22, typename T33>
struct T3
{
    T11 one;
    T22 tow;
    T33 three;

    bool operator<(const T3 other) const
    {
        if (one == other.one)
        {
            if (tow == other.tow) return three < other.three;
            return tow < other.tow;
        }
        return one < other.one;
    }

    T3()
    {
        one = tow = three = 0;
    }

    T3(T11 one, T22 tow, T33 three) : one(one), tow(tow), three(three)
    {
    }
};

template <typename T1, typename T2>
void uMax(T1& x, T2 y)
{
    if (x < y) x = y;
}

template <typename T1, typename T2>
void uMin(T1& x, T2 y)
{
    if (x > y) x = y;
}

constexpr int N = 1e5 + 10;
constexpr int MX = 1e8;
constexpr int T = log2(MX) + 1;
vector<int> child[N];
int a[N];
int n, q;
ll res[T + 1], tmpRes[T + 1];
bool val[N][T + 1];

struct
{
    int fa[N], deep[N], son[N], siz[N];

    void dfs1(const int curr, const int pa)
    {
        deep[curr] = deep[fa[curr] = pa] + 1;
        siz[curr] = 1;
        for (const int nxt : child[curr])
        {
            if (nxt == pa) continue;
            dfs1(nxt, curr);
            siz[curr] += siz[nxt];
            if (siz[nxt] > siz[son[curr]]) son[curr] = nxt;
        }
    }

    int idx[N], rev[N], top[N], cnt;

    void dfs2(const int curr, const int root)
    {
        top[curr] = root;
        idx[rev[curr] = ++cnt] = curr;
        if (!son[curr]) return;
        dfs2(son[curr], root);
        for (const int nxt : child[curr]) if (nxt != fa[curr] and nxt != son[curr]) dfs2(nxt, nxt);
    }

    int xorSum[N << 2];

    void pushUp(const int curr)
    {
        xorSum[curr] = xorSum[ls(curr)] ^ xorSum[rs(curr)];
    }

    void update(const int curr, const int pos, const int val, const int l = 1, const int r = n)
    {
        if (l == r)
        {
            xorSum[curr] = val;
            return;
        }
        const int mid = l + r >> 1;
        if (pos <= mid) update(ls(curr), pos, val, l, mid);
        else update(rs(curr), pos, val, mid + 1, r);
        pushUp(curr);
    }

    int query(const int curr, const int l, const int r, const int s = 1, const int e = n)
    {
        if (l <= s and e <= r) return xorSum[curr];
        const int mid = s + e >> 1;
        int ans = 0;
        if (l <= mid) ans ^= query(ls(curr), l, r, s, mid);
        if (r > mid) ans ^= query(rs(curr), l, r, mid + 1, e);
        return ans;
    }

    void build(const int curr = 1, const int l = 1, const int r = n)
    {
        if (l == r)
        {
            xorSum[curr] = a[idx[l]];
            return;
        }
        const int mid = l + r >> 1;
        build(ls(curr), l, mid);
        build(rs(curr), mid + 1, r);
        pushUp(curr);
    }

    void treeUpdate(const int curr, const int val)
    {
        update(1, rev[curr], val);
    }

    int lca(int x, int y) const
    {
        while (top[x] != top[y])
        {
            if (deep[top[x]] < deep[top[y]]) swap(x, y);
            x = fa[top[x]];
        }
        if (deep[x] > deep[y]) swap(x, y);
        return x;
    }

    int treeQuery(int u, int v)
    {
        int ans = 0;
        while (top[u] != top[v])
        {
            if (deep[top[u]] < deep[top[v]]) swap(u, v);
            ans ^= query(1, rev[top[u]], rev[u]);
            u = fa[top[u]];
        }
        if (deep[u] > deep[v]) swap(u, v);
        return ans ^ query(1, rev[u], rev[v]);
    }

    int deepDist(const int x, const int y) const
    {
        const int LCA = lca(x, y);
        return deep[x] + deep[y] - deep[LCA] - deep[fa[LCA]];
    }

    bool isSeg(const int x, const int mid, const int y) const
    {
        return deepDist(x, mid) + deepDist(mid, y) - 1 == deepDist(x, y);
    }

    void init()
    {
        dfs1(1, 0);
        dfs2(1, 1);
        build();
    }
} HLD;

inline ll getAns()
{
    ll ans = 0;
    forn(i, 0, T) ans += res[i] * (1LL << i);
    return ans;
}

struct
{
    bool del[N];
    int deep[N], fa[N];
    int sumSize, maxSon, root;
    int siz[N];
    ll cnt[N][T + 1][2], cntFa[N][T + 1][2]; //根,这个点的父亲到它这个方向的子树的桶信息
    ll c00[N][T + 1], c11[N][T + 1], c01[N][T + 1]; //每个点的所有子树的01和、11和、01和
    int center[N][T + 1];
    ll dist[N][T + 1];

    static ll comb2(const ll v)
    {
        return v * (v - 1) / 2;
    }

    void dfs(const int curr, const int pa)
    {
        siz[curr] = 1;
        for (const int nxt : child[curr]) if (nxt != pa and !del[nxt]) dfs(nxt, curr), siz[curr] += siz[nxt];
    }

    void buildCntFa(const int curr, const int pa, const int top, const bool xorSum, const int pos)
    {
        ++cntFa[top][pos][xorSum];
        for (const int nxt : child[curr])
        {
            if (nxt == pa or del[nxt]) continue;
            buildCntFa(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
        }
    }

    void buildCnt(const int curr, const int pa, const int top, const bool xorSum, const int pos)
    {
        for (const int nxt : child[curr])
        {
            if (nxt == pa or del[nxt]) continue;
            ++cnt[top][pos][xorSum ^ val[nxt][pos]];
            buildCnt(nxt, curr, top, xorSum ^ val[nxt][pos], pos);
        }
    }

    void makeRoot(const int curr, const int pa)
    {
        siz[curr] = 1;
        int currSize = 0;
        for (const int nxt : child[curr])
        {
            if (nxt == pa or del[nxt]) continue;
            makeRoot(nxt, curr);
            siz[curr] += siz[nxt];
            uMax(currSize, siz[nxt]);
        }
        uMax(currSize, sumSize - siz[curr]);
        if (currSize < maxSon) maxSon = currSize, root = curr;
    }


    void build(const int curr)
    {
        del[curr] = true;
        forn(pos, 0, T) buildCnt(curr, 0, curr, false, pos);
        for (const int nxt : child[curr])
        {
            if (del[nxt]) continue;
            sumSize = maxSon = siz[nxt];
            makeRoot(nxt, 0);
            forn(pos, 0, T) buildCntFa(nxt, curr, root, val[nxt][pos], pos);
            const int rt = root;
            dfs(root, 0);
            fa[root] = curr;
            deep[root] = deep[curr] + 1;
            forn(pos, 0, T)
            {
                const int idx = val[curr][pos];
                tmpRes[pos] += (cnt[curr][pos][0] - cntFa[rt][pos][0]) * cntFa[rt][pos][idx ^ 1];
                tmpRes[pos] += (cnt[curr][pos][1] - cntFa[rt][pos][1]) * cntFa[rt][pos][idx];
                res[pos] += cntFa[rt][pos][idx ^ 1];
            }
            forn(pos, 0, T)
            {
                c00[curr][pos] += comb2(cntFa[rt][pos][0]);
                c01[curr][pos] += cntFa[rt][pos][0] * cntFa[rt][pos][1];
                c11[curr][pos] += comb2(cntFa[rt][pos][1]);
            }
            build(root);
        }
    }

    ll v0[T + 1], v1[T + 1];


    void get01(const int curr, const int old, const int Fa)
    {
        const int c = center[curr][deep[curr] - deep[Fa]];
        forn(i, 0, T)
        {
            const int oldIdx = old >> i & 1;
            v0[i] = cnt[curr][i][oldIdx] - cntFa[c][i][oldIdx] + (oldIdx == 0);
            v1[i] = cnt[curr][i][oldIdx ^ 1] - cntFa[c][i][oldIdx ^ 1] + (oldIdx == 1);
        }
        for (int nxt = curr; fa[nxt] != Fa; nxt = fa[nxt])
        {
            if (HLD.isSeg(fa[nxt], curr, Fa))
            {
                const ll dst = dist[curr][deep[curr] - deep[fa[nxt]]] ^ a[fa[nxt]] ^ a[curr];
                forn(i, 0, T)
                {
                    const int oldIdx = old >> i & 1;
                    const int idx = dst >> i & 1;
                    const int need = oldIdx ^ idx;
                    v0[i] += cnt[fa[nxt]][i][need] - cntFa[nxt][i][need] + (need == 0);
                    v1[i] += cnt[fa[nxt]][i][need ^ 1] - cntFa[nxt][i][need ^ 1] + (need == 1);
                }
            }
        }
    }

    void Update(const int curr, const int v)
    {
        forn(i, 0, T) res[i] -= a[curr] >> i & 1;
        forn(i, 0, T)
        {
            if (a[curr] >> i & 1)
            {
                //00 11
                res[i] -= comb2(cnt[curr][i][0]) - c00[curr][i];
                res[i] -= comb2(cnt[curr][i][1]) - c11[curr][i];
                res[i] -= cnt[curr][i][0];
            }
            else
            {
                //01
                res[i] -= cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
                res[i] -= cnt[curr][i][1];
            }
        }
        for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
        {
            dist[curr][deep[curr] - deep[fa[nxt]]] = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
            forn(i, 0, T)
            {
                res[i] -= (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
                res[i] -= (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
                res[i] -= cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
            }
        }
        HLD.treeUpdate(curr, v);
        for (int nxt = curr; fa[nxt]; nxt = fa[nxt])
        {
            const int pNew = HLD.treeQuery(curr, fa[nxt]) ^ a[fa[nxt]];
            const int pOld = dist[curr][deep[curr] - deep[fa[nxt]]];
            get01(curr, pOld, fa[nxt]);
            forn(i, 0, T)
            {
                const int oldIdx = pOld >> i & 1;
                const int newIdx = pNew >> i & 1;
                if (oldIdx != newIdx)
                {
                    c00[fa[nxt]][i] -= comb2(cntFa[nxt][i][0]);
                    c01[fa[nxt]][i] -= cntFa[nxt][i][0] * cntFa[nxt][i][1];
                    c11[fa[nxt]][i] -= comb2(cntFa[nxt][i][1]);
                    forn(j, 0, 1) cnt[fa[nxt]][i][j] -= cntFa[nxt][i][j];
                    cntFa[nxt][i][0] += v1[i] - v0[i];
                    cntFa[nxt][i][1] += v0[i] - v1[i];
                    forn(j, 0, 1) cnt[fa[nxt]][i][j] += cntFa[nxt][i][j];
                    c00[fa[nxt]][i] += comb2(cntFa[nxt][i][0]);
                    c01[fa[nxt]][i] += cntFa[nxt][i][0] * cntFa[nxt][i][1];
                    c11[fa[nxt]][i] += comb2(cntFa[nxt][i][1]);
                }
                res[i] += (cnt[fa[nxt]][i][0] - cntFa[nxt][i][0]) * cntFa[nxt][i][val[fa[nxt]][i] ^ 1];
                res[i] += (cnt[fa[nxt]][i][1] - cntFa[nxt][i][1]) * cntFa[nxt][i][val[fa[nxt]][i]];
                res[i] += cntFa[nxt][i][a[fa[nxt]] >> i & 1 ^ 1];
            }
        }
        a[curr] = v;
        forn(i, 0, T) res[i] += a[curr] >> i & 1;
        forn(i, 0, T) val[curr][i] = v >> i & 1;
        forn(i, 0, T)
        {
            if (a[curr] >> i & 1)
            {
                //00 11
                res[i] += comb2(cnt[curr][i][0]) - c00[curr][i];
                res[i] += comb2(cnt[curr][i][1]) - c11[curr][i];
                res[i] += cnt[curr][i][0];
            }
            else
            {
                //01
                res[i] += cnt[curr][i][0] * cnt[curr][i][1] - c01[curr][i];
                res[i] += cnt[curr][i][1];
            }
        }
    }

    void init()
    {
        sumSize = maxSon = n;
        makeRoot(1, 0);
        dfs(root, 0);
        build(root);
        HLD.init();
        forn(i, 1, n)
        {
            const int curr = fa[i];
            for (int nxt = fa[curr]; nxt; nxt = fa[nxt])
            {
                if (!HLD.isSeg(i, curr, nxt)) center[curr][deep[curr] - deep[nxt]] = i;
            }
        }
    }
} pointTree;

inline void solve()
{
    cin >> n;
    forn(i, 1, n)
    {
        cin >> a[i];
        forn(j, 0, T) val[i][j] = a[i] >> j & 1, res[j] += val[i][j];
    }
    forn(i, 1, n-1)
    {
        int u, v;
        cin >> u >> v;
        child[u].push_back(v);
        child[v].push_back(u);
    }
    pointTree.init();
    forn(i, 0, T) res[i] += tmpRes[i] >> 1;
    cout << getAns() << endl;
    cin >> q;
    while (q--)
    {
        int p, v;
        cin >> p >> v;
        pointTree.Update(p, v);
        cout << getAns() << endl;
    }
}

signed int main()
{
    // MyFile
    Spider
    //------------------------------------------------------
    // clock_t start = clock();
    int test = 1;
    //    read(test);
    // cin >> test;
    forn(i, 1, test) solve();
    //    while (cin >> n, n)solve();
    //    while (cin >> test)solve();
    // clock_t end = clock();
    // cerr << "time = " << double(end - start) / CLOCKS_PER_SEC << "s" << endl;
}

\[预处理时间复杂度:预处理树剖+点分树的建立+拆位:O(n\log{n}\log{W}) \]

\[查询复杂度:枚举父分治中心+枚举每个儿子中心对父中心+拆位:O(q\log^2{n}\log{W}) \]

posted @ 2024-06-19 15:45  Athanasy  阅读(55)  评论(0编辑  收藏  举报