Day 13 - 树形 DP 与换根 DP

树形 DP

树形 \(\text{DP}\),即在树上进行的 \(\text{DP}\)。由于树固有的递归性质,树形 \(\text{DP}\) 一般都是递归进行的。

基础

以下面这道题为例,介绍一下树形 \(\text{DP}\) 的一般过程。

例题 洛谷 P1352 没有上司的舞会

某大学有 \(n\) 个职员,编号为 \(1 \sim N\)。他们之间有从属关系,也就是说他们的关系就像一棵以校长为根的树,父结点就是子结点的直接上司。现在有个周年庆宴会,宴会每邀请来一个职员都会增加一定的快乐指数 \(a_i\),但是呢,如果某个职员的直接上司来参加舞会了,那么这个职员就无论如何也不肯来参加舞会了。所以,请你编程计算,邀请哪些职员可以使快乐指数最大,求最大的快乐指数。

我们设 \(f(i,0/1)\) 代表以 \(i\) 为根的子树的最优解(第二维的值为 0 代表 \(i\) 不参加舞会的情况,1 代表 \(i\) 参加舞会的情况)。

对于每个状态,都存在两种决策(其中下面的 \(x\) 都是 \(i\) 的儿子):

  • 上司不参加舞会时,下属可以参加,也可以不参加,此时有 \(f(i,0) = \sum\max \{f(x,1),f(x,0)\}\)
  • 上司参加舞会时,下属都不会参加,此时有 \(f(i,1) = \sum{f(x,0)} + a_i\)

我们可以通过 \(\text{DFS}\),在返回上一层时更新当前结点的最优解。

#include<iostream>
#include<vector>
#include<cmath>
using namespace std;
#define MAXN 6005

long long n, a[MAXN], f[MAXN][2];
vector<long long> v[MAXN];

void dfs(long long x, long long fa) {
    f[x][0] = 0; f[x][1] = a[x];
    for(int i = 0; i < v[x].size(); i ++) {
        int p = v[x][i];
        if(p == fa) continue;
        dfs(p, x); f[x][0] += max(f[p][0], f[p][1]);
        f[x][1] += f[p][0];
    }
    return;
}

int main() {
    cin >> n;
    for(int i = 1; i <= n; i ++) cin >> a[i];
    for(int i = 1, fr, to; i < n; i ++) {
        cin >> fr >> to;
        v[fr].push_back(to), v[to].push_back(fr);
    }
    dfs(1, 0);
    cout << max(f[1][0], f[1][1]);
    return 0;
}

习题

树上背包

树上的背包问题,简单来说就是背包问题与树形 \(\text{DP}\) 的结合。

例题 洛谷 P2014 CTSC1997 选课

现在有 \(n\) 门课程,第 \(i\) 门课程的学分为 \(a_i\),每门课程有零门或一门先修课,有先修课的课程需要先学完其先修课,才能学习该课程。

一位学生要学习 \(m\) 门课程,求其能获得的最多学分数。

数据范围:\(n,m \leq 300\)

每门课最多只有一门先修课的特点,与有根树中一个点最多只有一个父亲结点的特点类似。

因此可以想到根据这一性质建树,从而所有课程组成了一个森林的结构。为了方便起见,我们可以新增一门 \(0\) 学分的课程(设这个课程的编号为 \(0\)),作为所有无先修课课程的先修课,这样我们就将森林变成了一棵以 \(0\) 号课程为根的树。

我们设 \(f(u,i,j)\) 表示以 \(u\) 号点为根的子树中,已经遍历了 \(u\) 号点的前 \(i\) 棵子树,选了 \(j\) 门课程的最大学分。

转移的过程结合了树形 \(\text{DP}\) 和 背包 \(\text{DP}\)的特点,我们枚举 \(u\) 点的每个子结点 \(v\),同时枚举以 \(v\) 为根的子树选了几门课程,将子树的结果合并到 \(u\) 上。

记点 \(x\) 的儿子个数为 \(s_x\),以 \(x\) 为根的子树大小为 \(\textit{siz_x}\),可以写出下面的状态转移方程:

\[f(u,i,j)=\max_{v,k \leq j,k \leq \textit{siz_v}} f(u,i-1,j-k)+f(v,s_v,k) \]

注意上面状态转移方程中的几个限制条件,这些限制条件确保了一些无意义的状态不会被访问到。

\(f\) 的第二维可以很轻松地用滚动数组的方式省略掉,注意这时需要倒序枚举 \(j\) 的值。

可以证明,该做法的时间复杂度为 \(O(nm)^1\)

参考代码:

#include<iostream>
#include<vector>
using namespace std;
#define MAXN 305

int f[MAXN][MAXN], s[MAXN], n, m;
vector<int> e[MAXN];

int dfs(int u) {
    int p = 1;
    f[u][1] = s[u];
    for(auto v : e[u]) {
        int siz = dfs(v);
        for(int i = min(p, m + 1); i; i--)
            for(int j = 1; j <= siz && i + j <= m + 1; j++)
                f[u][i + j] = max(f[u][i + j], f[u][i] + f[v][j]);
        p += siz;
    }
    return p;
}

int main() {
    cin >> n >> m;
    for(int i = 1, k; i <= n; i++) {
        cin >> k >> s[i];
        e[k].push_back(i);
    }
    dfs(0);
    cout << f[0][m + 1] << "\n";
    return 0;
}

习题

参考资料与注释

\(^1\): 子树合并背包类型的 dp 的复杂度证明 - LYD729 的 CSDN 博客

换根 DP

树形 \(\text{DP}\) 中的换根 \(\text{DP}\) 问题又被称为二次扫描,通常不会指定根结点,并且根结点的变化会对一些值,例如子结点深度和、点权和等产生影响。

通常需要两次 \(\text{DFS}\),第一次 \(\text{DFS}\) 预处理诸如深度,点权和之类的信息,在第二次 \(\text{DFS}\) 开始运行换根动态规划。

接下来以一些例题来带大家熟悉这个内容。

例题 [POI2008]STA-Station

给定一个 \(n\) 个点的树,请求出一个结点,使得以这个结点为根时,所有结点的深度之和最大。

不妨令 \(u\) 为当前结点,\(v\) 为当前结点的子结点。首先需要用 \(s_i\) 来表示以 \(i\) 为根的子树中的结点个数,并且有 \(s_u=1+\sum s_v\)。显然需要一次 \(\text{DFS}\) 来计算所有的 \(s_i\),这次的 \(\text{DFS}\) 就是预处理,我们得到了以某个结点为根时其子树中的结点总数。

考虑状态转移,这里就是体现"换根"的地方了。令 \(f_u\) 为以 \(u\) 为根时,所有结点的深度之和。

\(f_v\leftarrow f_u\) 可以体现换根,即以 \(u\) 为根转移到以 \(v\) 为根。显然在换根的转移过程中,以 \(v\) 为根或以 \(u\) 为根会导致其子树中的结点的深度产生改变。具体表现为:

  • 所有在 \(v\) 的子树上的结点深度都减少了一,那么总深度和就减少了 \(s_v\)

  • 所有不在 \(v\) 的子树上的结点深度都增加了一,那么总深度和就增加了 \(n-s_v\)

根据这两个条件就可以推出状态转移方程 \(f_v = f_u - s_v + n - s_v=f_u + n - 2 \times s_v\)

于是在第二次 \(\text{DFS}\) 遍历整棵树并状态转移 \(f_v=f_u + n - 2 \times s_v\),那么就能求出以每个结点为根时的深度和了。最后只需要遍历一次所有根结点深度和就可以求出答案。

参考代码:

#include <bits/stdc++.h>
using namespace std;

int head[1000010 << 1], tot;
long long n, sz[1000010], dep[1000010];
long long f[1000010];

struct node {
  int to, next;
} e[1000010 << 1];

void add(int u, int v) {  // 建图
  e[++tot] = {v, head[u]};
  head[u] = tot;
}

void dfs(int u, int fa) {  // 预处理dfs
  sz[u] = 1;
  dep[u] = dep[fa] + 1;
  for (int i = head[u]; i; i = e[i].next) {
    int v = e[i].to;
    if (v != fa) {
      dfs(v, u);
      sz[u] += sz[v];
    }
  }
}

void get_ans(int u, int fa) {  // 第二次dfs换根dp
  for (int i = head[u]; i; i = e[i].next) {
    int v = e[i].to;
    if (v != fa) {
      f[v] = f[u] - sz[v] * 2 + n;
      get_ans(v, u);
    }
  }
}

int main() {
  scanf("%lld", &n);
  int u, v;
  for (int i = 1; i <= n - 1; i++) {
    scanf("%d%d", &u, &v);
    add(u, v);
    add(v, u);
  }
  dfs(1, 1);
  for (int i = 1; i <= n; i++) f[1] += dep[i];
  get_ans(1, 1);
  long long int ans = -1;
  int id;
  for (int i = 1; i <= n; i++) {  // 统计答案
    if (f[i] > ans) {
      ans = f[i];
      id = i;
    }
  }
  printf("%d\n", id);
  return 0;
}

习题

字典树

定义

字典树,英文名 \(\text{trie}\)。顾名思义,就是一个像字典一样的树。

引入

先放一张图:

可以发现,这棵字典树用边来代表字母,而从根结点到树上某一结点的路径就代表了一个字符串。举个例子,\(1\to4\to 8\to 12\) 表示的就是字符串 caa

\(\text{trie}\) 的结构非常好懂,我们用 \(\delta(u,c)\) 表示结点 \(u\)\(c\) 字符指向的下一个结点,或着说是结点 \(u\) 代表的字符串后面添加一个字符 \(c\) 形成的字符串的结点。(\(c\) 的取值范围和字符集大小有关,不一定是 \(0\sim 26\)。)

有时需要标记插入进 \(\text{trie}\) 的是哪些字符串,每次插入完成时在这个字符串所代表的节点处打上标记即可。

实现

放一个结构体封装的模板:

struct trie {
  int nex[100000][26], cnt;
  bool exist[100000];  // 该结点结尾的字符串是否存在

  void insert(char *s, int l) {  // 插入字符串
    int p = 0;
    for (int i = 0; i < l; i++) {
      int c = s[i] - 'a';
      if (!nex[p][c]) nex[p][c] = ++cnt;  // 如果没有,就添加结点
      p = nex[p][c];
    }
    exist[p] = 1;
  }

  bool find(char *s, int l) {  // 查找字符串
    int p = 0;
    for (int i = 0; i < l; i++) {
      int c = s[i] - 'a';
      if (!nex[p][c]) return 0;
      p = nex[p][c];
    }
    return exist[p];
  }
};

应用

检索字符串

字典树最基础的应用——查找一个字符串是否在「字典」中出现过。

于是他错误的点名开始了

给你 \(n\) 个名字串,然后进行 \(m\) 次点名,每次你需要回答「名字不存在」、「第一次点到这个名字」、「已经点过这个名字」之一。

\(1\le n\le 10^4\)\(1\le m\le 10^5\),所有字符串长度不超过 \(50\)

题解:

对所有名字建 \(\text{trie}\),再在 \(\text{trie}\) 中查询字符串是否存在、是否已经点过名,第一次点名时标记为点过名。

参考代码:

#include <cstdio>

const int N = 500010;

char s[60];
int n, m, ch[N][26], tag[N], tot = 1;

int main() {
  scanf("%d", &n);

  for (int i = 1; i <= n; ++i) {
    scanf("%s", s + 1);
    int u = 1;
    for (int j = 1; s[j]; ++j) {
      int c = s[j] - 'a';
      if (!ch[u][c])
        ch[u][c] =
            ++tot;  // 如果这个节点的子节点中没有这个字符,添加上并将该字符的节点号记录为++tot
      u = ch[u][c];  // 往更深一层搜索
    }
    tag[u] = 1;  // 最后一个字符为节点 u 的名字未被访问到记录为 1
  }

  scanf("%d", &m);

  while (m--) {
    scanf("%s", s + 1);
    int u = 1;
    for (int j = 1; s[j]; ++j) {
      int c = s[j] - 'a';
      u = ch[u][c];
      if (!u) break;  // 不存在对应字符的出边说明名字不存在
    }
    if (tag[u] == 1) {
      tag[u] = 2;  // 最后一个字符为节点 u 的名字已经被访问
      puts("OK");
    } else if (tag[u] == 2)  // 已经被访问,重复访问
      puts("REPEAT");
    else
      puts("WRONG");
  }

  return 0;
}

AC 自动机

\(\text{trie}\)\(\text{AC}\) 自动机的一部分。

维护异或极值

将数的二进制表示看做一个字符串,就可以建出字符集为 \(\{0,1\}\)\(\text{trie}\) 树。

BZOJ1954 最长异或路径

给你一棵带边权的树,求 \((u, v)\) 使得 \(u\)\(v\) 的路径上的边权异或和最大,输出这个最大值。这里的异或和指的是所有边权的异或。

点数不超过 \(10^5\),边权在 \([0,2^{31})\) 内。

题解:

随便指定一个根 \(root\),用 \(T(u, v)\) 表示 \(u\)\(v\) 之间的路径的边权异或和,那么 \(T(u,v)=T(root, u)\oplus T(root,v)\),因为 \(\text{LCA}\) 以上的部分异或两次抵消了。

那么,如果将所有 \(T(root, u)\) 插入到一棵 \(\text{trie}\) 中,就可以对每个 \(T(root, u)\) 快速求出和它异或和最大的 \(T(root, v)\)

\(\text{trie}\) 的根开始,如果能向和 \(T(root, u)\) 的当前位不同的子树走,就向那边走,否则没有选择。

贪心的正确性:如果这么走,这一位为 \(1\);如果不这么走,这一位就会为 \(0\)。而高位是需要优先尽量大的。

参考代码:

#include <algorithm>
#include <cstdio>
using namespace std;

const int N = 100010;

int head[N], nxt[N << 1], to[N << 1], weight[N << 1], cnt;
int n, dis[N], ch[N << 5][2], tot = 1, ans;

void insert(int x) {
  for (int i = 30, u = 1; i >= 0; --i) {
    int c = ((x >> i) & 1);  // 二进制一位一位向下取
    if (!ch[u][c]) ch[u][c] = ++tot;
    u = ch[u][c];
  }
}

void get(int x) {
  int res = 0;
  for (int i = 30, u = 1; i >= 0; --i) {
    int c = ((x >> i) & 1);
    if (ch[u][c ^ 1]) {  // 如果能向和当前位不同的子树走,就向那边走
      u = ch[u][c ^ 1];
      res |= (1 << i);
    } else
      u = ch[u][c];
  }
  ans = max(ans, res);  // 更新答案
}

void add(int u, int v, int w) {  // 建边
  nxt[++cnt] = head[u];
  head[u] = cnt;
  to[cnt] = v;
  weight[cnt] = w;
}

void dfs(int u, int fa) {
  insert(dis[u]);
  get(dis[u]);
  for (int i = head[u]; i; i = nxt[i]) {  // 遍历子节点
    int v = to[i];
    if (v == fa) continue;
    dis[v] = dis[u] ^ weight[i];
    dfs(v, u);
  }
}

int main() {
  scanf("%d", &n);

  for (int i = 1; i < n; ++i) {
    int u, v, w;
    scanf("%d%d%d", &u, &v, &w);
    add(u, v, w);  // 双向边
    add(v, u, w);
  }

  dfs(1, 0);

  printf("%d", ans);

  return 0;
}

维护异或和

\(01-\text{trie}\) 是指字符集为 \(\{0,1\}\)\(\text{trie}\)\(01-\text{trie}\) 可以用来维护一些数字的异或和,支持修改(删除 + 重新插入),和全局加一(即:让其所维护所有数值递增 1,本质上是一种特殊的修改操作)。

如果要维护异或和,需要按值从低位到高位建立 \(\text{trie}\)

一个约定:文中说当前节点 往上 指当前节点到根这条路径,当前节点 往下 指当前结点的子树。

插入 & 删除

如果要维护异或和,我们 只需要 知道某一位上 01 个数的 奇偶性 即可,也就是对于数字 1 来说,当且仅当这一位上数字 1 的个数为奇数时,这一位上的数字才是 1,请时刻记住这段文字:如果只是维护异或和,我们只需要知道某一位上 1 的数量即可,而不需要知道 \(\text{trie}\) 到底维护了哪些数字。

对于每一个节点,我们需要记录以下三个量:

  • ch[o][0/1] 指节点 o 的两个儿子,ch[o][0] 指下一位是 0,同理 ch[o][1] 指下一位是 1
  • w[o] 指节点 o 到其父亲节点这条边上数值的数量(权值)。每插入一个数字 xx 二进制拆分后在 \(\text{trie}\) 上 路径的权值都会 +1
  • xorv[o] 指以 o 为根的子树维护的异或和。

具体维护结点的代码如下所示。

void maintain(int o) {
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
  // w[o] = w[o] & 1;
  // 只需知道奇偶性即可,不需要具体的值。当然这句话删掉也可以,因为上文就只利用了他的奇偶性。
}

插入和删除的代码非常相似。

需要注意的地方就是:

  • 这里的 MAXH\(\text{trie}\) 的深度,也就是强制让每一个叶子节点到根的距离为 MAXH。对于一些比较小的值,可能有时候不需要建立这么深(例如:如果插入数字 4,分解成二进制后为 100,从根开始插入 001 这三位即可),但是我们强制插入 MAXH 位。这样做的目的是为了便于全局 +1 时处理进位。例如:如果原数字是 311),递增之后变成 4100),如果当初插入 3 时只插入了 2 位,那这里的进位就没了。

  • 插入和删除,只需要修改叶子节点的 w[] 即可,在回溯的过程中一路维护即可。

实现:

namespace trie {
const int MAXH = 21;
int ch[_ * (MAXH + 1)][2], w[_ * (MAXH + 1)], xorv[_ * (MAXH + 1)];
int tot = 0;

int mknode() {
  ++tot;
  ch[tot][1] = ch[tot][0] = w[tot] = xorv[tot] = 0;
  return tot;
}

void maintain(int o) {
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
  w[o] = w[o] & 1;
}

void insert(int &o, int x, int dp) {
  if (!o) o = mknode();
  if (dp > MAXH) return (void)(w[o]++);
  insert(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

void erase(int o, int x, int dp) {
  if (dp > 20) return (void)(w[o]--);
  erase(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}
}  // namespace trie

全局加一

所谓全局加一就是指,让这棵 \(\text{trie}\) 中所有的数值 +1

形式化的讲,设 \(\text{trie}\) 中维护的数值有 \(V_1, V_2, V_3 \dots V_n\), 全局加一后 其中维护的值应该变成 \(V_1+1, V_2+1, V_3+1 \dots V_n+1\)

void addall(int o) {
  swap(ch[o][0], ch[o][1]);
  if (ch[o][0]) addall(ch[o][0]);
  maintain(o);
}
过程

我们思考一下二进制意义下 +1 是如何操作的。

我们只需要从低位到高位开始找第一个出现的 0,把它变成 1,然后这个位置后面的 1 都变成 0 即可。

下面给出几个例子感受一下:(括号内的数字表示其对应的十进制数字)

1000(8)  + 1 = 1001(9)  ;
10011(19) + 1 = 10100(20) ;
11111(31) + 1 = 100000(32);
10101(21) + 1 = 10110(22) ;
100000000111111(16447) + 1 = 100000001000000(16448);

对应 \(\text{trie}\) 的操作,其实就是交换其左右儿子,顺着 交换后0 边往下递归操作即可。

回顾一下 w[o] 的定义:w[o] 指节点 o 到其父亲节点这条边上数值的数量(权值)。

有没有感觉这个定义有点怪呢?如果在父亲结点存储到两个儿子的这条边的边权也许会更接近于习惯。但是在这里,在交换左右儿子的时候,在儿子结点存储到父亲这条边的距离,显然更加方便。

01-trie 合并

指的是将上述的两个 \(01-\text{trie}\) 进行合并,同时合并维护的信息。

可能关于合并 \(\text{trie}\) 的文章比较少,其实合并 \(\text{trie}\) 和合并线段树的思路非常相似,可以搜索「合并线段树」来学习如何合并 \(\text{trie}\)

其实合并 \(\text{trie}\) 非常简单,就是考虑一下我们有一个 int merge(int a, int b) 函数,这个函数传入两个 \(\text{trie}\) 树位于同一相对位置的结点编号,然后合并完成后返回合并完成的结点编号。

过程

考虑怎么实现?

分三种情况:

  • 如果 a 没有这个位置上的结点,新合并的结点就是 b
  • 如果 b 没有这个位置上的结点,新合并的结点就是 a
  • 如果 a,b 都存在,那就把 b 的信息合并到 a 上,新合并的结点就是 a,然后递归操作处理 a 的左右儿子。

提示:如果需要的合并是将 a,b 合并到一棵新树上,这里可以新建结点,然后合并到这个新结点上,这里的代码实现仅仅是将 b 的信息合并到 a 上。

实现

int merge(int a, int b) {
  if (!a) return b;  // 如果 a 没有这个位置上的结点,返回 b
  if (!b) return a;  // 如果 b 没有这个位置上的结点,返回 a
  /*
    如果 `a`, `b` 都存在,
    那就把 `b` 的信息合并到 `a` 上。
  */
  w[a] = w[a] + w[b];
  xorv[a] ^= xorv[b];
  /* 不要使用 maintain(),
    maintain() 是合并a的两个儿子的信息
    而这里需要 a b 两个节点进行信息合并
   */
  ch[a][0] = merge(ch[a][0], ch[b][0]);
  ch[a][1] = merge(ch[a][1], ch[b][1]);
  return a;
}

其实 \(\text{trie}\) 都可以合并,换句话说,\(\text{trie}\) 合并不仅仅限于 \(01-\text{trie}\)

【luogu-P6018】【Ynoi2010】Fusion tree

给你一棵 \(n\) 个结点的树,每个结点有权值。\(m\) 次操作。

需要支持以下操作。

  • 将树上与一个节点 \(x\) 距离为 \(1\) 的节点上的权值 \(+1\)。这里树上两点间的距离定义为从一点出发到另外一点的最短路径上边的条数。
  • 在一个节点 \(x\) 上的权值 \(-v\)
  • 询问树上与一个节点 \(x\) 距离为 \(1\) 的所有节点上的权值的异或和。

对于 \(100\%\) 的数据,满足 \(1\le n \le 5\times 10^5\)\(1\le m \le 5\times 10^5\)\(0\le a_i \le 10^5\)\(1 \le x \le n\)\(opt\in\{1,2,3\}\)

保证任意时刻每个节点的权值非负。

题解:

每个结点建立一棵 \(\text{trie}\) 维护其儿子的权值,\(\text{trie}\) 应该支持全局加一。
可以使用在每一个结点上设置懒标记来标记儿子的权值的增加量。

参考代码:

#include <bits/stdc++.h>
using namespace std;
const int _ = 5e5 + 10;

namespace trie {
const int _n = _ * 25;
int rt[_];
int ch[_n][2];
int w[_n];  //`w[o]` 指节点 `o` 到其父亲节点这条边上数值的数量(权值)。
int xorv[_n];
int tot = 0;

void maintain(int o) {  // 维护w数组和xorv(权值的异或)数组
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
}

int mknode() {  // 创造一个新的节点
  ++tot;
  ch[tot][0] = ch[tot][1] = 0;
  w[tot] = 0;
  return tot;
}

void insert(int &o, int x, int dp) {  // x是权重,dp是深度
  if (!o) o = mknode();
  if (dp > 20) return (void)(w[o]++);
  insert(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

void erase(int o, int x, int dp) {
  if (dp > 20) return (void)(w[o]--);
  erase(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

void addall(int o) {  // 对所有节点+1即将所有节点的ch[o][1]和ch[o][0]交换
  swap(ch[o][1], ch[o][0]);
  if (ch[o][0]) addall(ch[o][0]);
  maintain(o);
}
}  // namespace trie

int head[_];

struct edges {
  int node;
  int nxt;
} edge[_ << 1];

int tot = 0;

void add(int u, int v) {
  edge[++tot].nxt = head[u];
  head[u] = tot;
  edge[tot].node = v;
}

int n, m;
int rt;
int lztar[_];
int fa[_];

void dfs0(int o, int f) {  // 得到fa数组
  fa[o] = f;
  for (int i = head[o]; i; i = edge[i].nxt) {  // 遍历子节点
    int node = edge[i].node;
    if (node == f) continue;
    dfs0(node, o);
  }
}

int V[_];

int get(int x) { return (fa[x] == -1 ? 0 : lztar[fa[x]]) + V[x]; }  // 权值函数

int main() {
  cin >> n >> m;
  for (int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    add(u, v);  // 双向建边
    add(rt = v, u);
  }
  dfs0(rt, -1);  // rt是随机的一个点
  for (int i = 1; i <= n; i++) {
    cin >> V[i];
    if (fa[i] != -1) trie::insert(trie::rt[fa[i]], V[i], 0);
  }
  while (m--) {
    int opt, x;
    cin >> opt >> x;
    if (opt == 1) {
      lztar[x]++;
      if (x != rt) {
        if (fa[fa[x]] != -1) trie::erase(trie::rt[fa[fa[x]]], get(fa[x]), 0);
        V[fa[x]]++;
        if (fa[fa[x]] != -1)
          trie::insert(trie::rt[fa[fa[x]]], get(fa[x]), 0);  // 重新插入
      }
      trie::addall(trie::rt[x]);  // 对所有节点+1
    } else if (opt == 2) {
      int v;
      cin >> v;
      if (x != rt) trie::erase(trie::rt[fa[x]], get(x), 0);
      V[x] -= v;
      if (x != rt) trie::insert(trie::rt[fa[x]], get(x), 0);  // 重新插入
    } else {
      int res = 0;
      res = trie::xorv[trie::rt[x]];
      res ^= get(fa[x]);
      printf("%d\n", res);
    }
  }
  return 0;
}

【luogu-P6623】【省选联考 2020 A 卷】树

给定一棵 \(n\) 个结点的有根树 \(T\),结点从 \(1\) 开始编号,根结点为 \(1\) 号结点,每个结点有一个正整数权值 \(v_i\)
\(x\) 号结点的子树内(包含 \(x\) 自身)的所有结点编号为 \(c_1,c_2,\dots,c_k\),定义 \(x\) 的价值为:
\(val(x)=(v_{c_1}+d(c_1,x)) \oplus (v_{c_2}+d(c_2,x)) \oplus \cdots \oplus (v_{c_k}+d(c_k, x))\) 其中 \(d(x,y)\)
表示树上 \(x\) 号结点与 \(y\) 号结点间唯一简单路径所包含的边数,\(d(x,x) = 0\)\(\oplus\) 表示异或运算。
请你求出 \(\sum\limits_{i=1}^n val(i)\) 的结果。

题解:

考虑每个结点对其所有祖先的贡献。
每个结点建立 \(\text{trie}\),初始先只存这个结点的权值,然后从底向上合并每个儿子结点上的 \(\text{trie}\),然后再全局加一,完成后统计答案。

参考代码:

const int _ = 526010;
int n;
int V[_];
int debug = 0;

namespace trie {
const int MAXH = 21;
int ch[_ * (MAXH + 1)][2], w[_ * (MAXH + 1)], xorv[_ * (MAXH + 1)];
int tot = 0;

int mknode() {
  ++tot;
  ch[tot][1] = ch[tot][0] = w[tot] = xorv[tot] = 0;
  return tot;
}

void maintain(int o) {
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
  w[o] = w[o] & 1;
}

void insert(int &o, int x, int dp) {
  if (!o) o = mknode();
  if (dp > MAXH) return (void)(w[o]++);
  insert(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

int merge(int a, int b) {
  if (!a) return b;
  if (!b) return a;
  w[a] = w[a] + w[b];
  xorv[a] ^= xorv[b];
  ch[a][0] = merge(ch[a][0], ch[b][0]);
  ch[a][1] = merge(ch[a][1], ch[b][1]);
  return a;
}

void addall(int o) {
  swap(ch[o][0], ch[o][1]);
  if (ch[o][0]) addall(ch[o][0]);
  maintain(o);
}
}  // namespace trie

int rt[_];
long long Ans = 0;
vector<int> E[_];

void dfs0(int o) {
  for (int i = 0; i < E[o].size(); i++) {
    int node = E[o][i];
    dfs0(node);
    rt[o] = trie::merge(rt[o], rt[node]);
  }
  trie::addall(rt[o]);
  trie::insert(rt[o], V[o], 0);
  Ans += trie::xorv[rt[o]];
}

int main() {
  n = read();
  for (int i = 1; i <= n; i++) V[i] = read();
  for (int i = 2; i <= n; i++) E[read()].push_back(i);
  dfs0(1);
  printf("%lld", Ans);
  return 0;
}

可持久化字典树

参见 可持久化字典树

posted @ 2024-07-20 07:47  So_noSlack  阅读(11)  评论(0编辑  收藏  举报