树形DP

树形 DP 即在树上进行的 DP。

常见的两种转移方向:

  • 父节点 子节点:如求节点深度,dpu=dpfa+1
  • 子节点 父节点:如求子树大小,dpu=1+dpv

例题:P5658 [CSP-S2019] 括号树

分析:本题 n 小的数据点保证为链,直接枚举 i,代表从根节点到 i 号节点。枚举 [1,i] 中的子区间左右端点 [l,r],判断该子串是否符合括号匹配。

参考代码
#include <cstdio>
typedef long long LL;
const int N = 500005;
char s[N];
int f[N];
bool check(int l, int r) {
int left = 0;
for (int i = l; i <= r; i++) {
if (s[i] == '(') left++;
else if (left == 0) return false;
else left--;
}
return left == 0;
}
int main()
{
int n; scanf("%d%s", &n, s + 1);
for (int i = 2; i <= n; i++) scanf("%d", &f[i]);
LL ans = 0;
for (int i = 2; i <= n; i++) {
// 1~i
LL cnt = 0;
for (int l = 1; l <= i; l++) {
for (int r = l; r <= i; r++) {
// [l, r]
if (check(l, r)) {
cnt++;
}
}
}
ans ^= cnt * i;
}
printf("%lld\n", ans);
return 0;
}

实际得分 20 分。

考虑数据为链但 n5×105 的问题做法,此时可以看作是一个线性序列上的问题。

考虑用 dpi 表示以 si 结尾的合法括号串,如果 si 是左括号,显然 dpi=0;而如果 si 是右括号,这时实际上需要找到与该右括号对应匹配的左括号(这个问题可以借助一个栈来实现),则该左括号到当前的右括号构成了一个合法括号串,而实际上如果这个左括号的前一位是一个合法括号串的结尾,那么之前的合法括号串拼上刚匹配的这个括号串也是一个合法括号串,因此这时 dpi=dppre1+1,这里的 pre 代表当前右括号匹配的左括号的位置。

题目要求计算的是合法括号子串的数量,因此只需计算 dp 结果的前缀和即为前 i 个字符形成的字符串中合法括号子串的数量。

参考代码
#include <cstdio>
#include <stack>
using namespace std;
typedef long long LL;
const int N = 500005;
char s[N];
int f[N];
LL dp[N]; // 以s[i]结尾的括号子串数量
LL sum[N]; // 1~i中的括号子串数量,即dp的前缀和
int main()
{
int n; scanf("%d%s", &n, s + 1);
for (int i = 2; i <= n; i++) scanf("%d", &f[i]);
stack<int> stk; // 记录左括号的位置
LL ans = 0;
for (int i = 1; i <= n; i++) {
if (s[i] == '(') {
stk.push(i);
} else if (!stk.empty()) {
int pre = stk.top();
stk.pop();
dp[i] = dp[pre - 1] + 1;
}
sum[i] = sum[i - 1] + dp[i];
ans ^= sum[i] * i;
}
printf("%lld\n", ans);
return 0;
}

实际得分 55 分。

把处理链的思路转化到任意树上。

其中 dpsum 的计算方式可以类推过来,只不过链上通过“减一”表达上一个位置的方式对应到树上要变成“父节点”。因此原来的计算式子需要调整一下:

  • dpi=dppre1+1dpi=dpfpre+1
  • sumi=sumi1+dpisumi=sumfi+dpi

除此以外,还需要解决树上的括号栈的递归与回溯问题。发生回溯后,栈里的信息可能会和当前状态不匹配。比如某个节点(左括号)有多棵子树,进入其中一棵子树之后,该子树中的右括号匹配掉了这个左括号(出栈),而接下来再进入下一棵子树时这个左括号依然需要在栈中。

因此回溯时,我们要执行当时递归时相反的操作。比如,当前节点是右括号,如果此时栈不为空,栈会弹出一个元素以匹配当前右括号。我们可以记录这个信息,在最后回溯前把它重新压入栈中,保持状态的一致性。

参考代码
#include <cstdio>
#include <vector>
#include <stack>
using namespace std;
typedef long long LL;
const int N = 500005;
char s[N];
vector<int> tree[N];
stack<int> stk;
int f[N];
LL dp[N], sum[N];
void dfs(int cur, int fa) {
int tmp = 0;
if (s[cur] == '(') {
stk.push(cur); tmp = -1;
dp[cur] = 0;
} else if (stk.empty()) {
dp[cur] = 0;
} else {
tmp = stk.top(); stk.pop();
dp[cur] = dp[f[tmp]] + 1;
}
sum[cur] = sum[fa] + dp[cur];
for (int to : tree[cur]) dfs(to, cur);
if (tmp == -1) stk.pop();
else if (tmp > 0) stk.push(tmp);
}
int main()
{
int n;
scanf("%d%s", &n, s + 1);
for (int i = 2; i <= n; i++) {
scanf("%d", &f[i]);
tree[f[i]].push_back(i);
}
dfs(1, 0);
LL ans = 0;
for (int i = 1; i <= n; i++) ans ^= (sum[i] * i);
printf("%lld\n", ans);
return 0;
}

例题:P7073 [CSP-J2020] 表达式

分析:对于初始情况,可以通过后缀表达式与栈建立二叉树,通过树上 DP 进行结果计算。

对于 30% 的数据,可以修改后重新进行 DP。

参考代码
#include <cstdio>
#include <stack>
using std::stack;
const int N = 1000005;
const int OFFSET = 100000;
char s[10];
int dp[N], lc[N], rc[N], t[N];
void dfs(int u) {
if (lc[u]) dfs(lc[u]);
if (rc[u]) dfs(rc[u]);
if (t[u] == 1) dp[u] = !dp[lc[u]];
else if (t[u] == 2) dp[u] = dp[lc[u]] & dp[rc[u]];
else if (t[u] == 3) dp[u] = dp[lc[u]] | dp[rc[u]];
}
int main()
{
int n, len = OFFSET;
stack<int> stk;
while (true) {
scanf("%s", s);
if (s[0] >= '0' && s[0] <= '9') {
sscanf(s, "%d", &n);
break;
}
if (s[0] == 'x') {
int xid; sscanf(s + 1, "%d", &xid);
stk.push(xid);
} else if (s[0] == '!') {
t[++len] = 1;
lc[len] = stk.top();
stk.pop();
stk.push(len);
} else {
t[++len] = (s[0] == '&' ? 2 : 3);
rc[len] = stk.top(); stk.pop();
lc[len] = stk.top(); stk.pop();
stk.push(len);
}
}
for (int i = 1; i <= n; i++) scanf("%d", &dp[i]);
int q; scanf("%d", &q);
for (int i = 1; i <= q; i++) {
int idx; scanf("%d", &idx);
dp[idx] = !dp[idx];
dfs(len);
printf("%d\n", dp[len]);
dp[idx] = !dp[idx];
}
return 0;
}

而要想通过所有的数据点,必须一次知道每个变量改变后的算式结果,或者说改变这一项会不会改变算式结果,可以从根节点开始讨论,弄清楚当这一棵子树的值改变时计算结果是否会改变,只递归进入会对算式结果产生改变的子树,到达叶节点时对相应项进行标记。

对于操作符为 ! 的情况,一定进入它的子节点。

对于操作符为 & 的情况,如果两个子节点的计算结果均为 1,则不管哪棵子树的计算结果发生变化都会对当前节点产生影响,则两边都要递归下去,如果一个值为 1,另一个为 0,则进入值为 0 的子树,如果两个值都为 0,则不用继续递归。

对于操作符为 | 的情况,如果两个子节点的计算结果均为 0,则不管哪棵子树的计算结果发生变化都会对当前节点产生影响,则两边都要递归下去,如果一个值为 1,另一个为 0,则进入值为 1 的子树,如果两个值都为 1,则不用继续递归。

对于每次询问,如果该项无标记,则答案为一开始的计算结果,否则为原始计算结果取反。

参考代码
#include <cstdio>
#include <stack>
using std::stack;
const int N = 1000005;
const int OFFSET = 100000;
char s[10];
int dp[N], lc[N], rc[N], t[N], flag[N];
void dfs(int u) {
if (lc[u]) dfs(lc[u]);
if (rc[u]) dfs(rc[u]);
if (t[u] == 1) dp[u] = !dp[lc[u]];
else if (t[u] == 2) dp[u] = dp[lc[u]] & dp[rc[u]];
else if (t[u] == 3) dp[u] = dp[lc[u]] | dp[rc[u]];
}
void calc(int u) {
if (t[u] == 0) {
flag[u] = 1; return;
}
if (t[u] == 1) {
calc(lc[u]);
} else {
int l = dp[lc[u]], r = dp[rc[u]];
if (t[u] == 2) {
if (l == 1 && r == 1) {
calc(lc[u]); calc(rc[u]);
} else if (l == 1 && r == 0) {
calc(rc[u]);
} else if (l == 0 && r == 1) {
calc(lc[u]);
}
} else {
if (l == 0 && r == 0) {
calc(lc[u]); calc(rc[u]);
} else if (l == 1 && r == 0) {
calc(lc[u]);
} else if (l == 0 && r == 1) {
calc(rc[u]);
}
}
}
}
int main()
{
int n, len = OFFSET;
stack<int> stk;
while (true) {
scanf("%s", s);
if (s[0] >= '0' && s[0] <= '9') {
sscanf(s, "%d", &n);
break;
}
if (s[0] == 'x') {
int xid; sscanf(s + 1, "%d", &xid);
stk.push(xid);
} else if (s[0] == '!') {
t[++len] = 1;
lc[len] = stk.top();
stk.pop();
stk.push(len);
} else {
t[++len] = (s[0] == '&' ? 2 : 3);
rc[len] = stk.top(); stk.pop();
lc[len] = stk.top(); stk.pop();
stk.push(len);
}
}
for (int i = 1; i <= n; i++) scanf("%d", &dp[i]);
dfs(len);
calc(len);
int q; scanf("%d", &q);
for (int i = 1; i <= q; i++) {
int idx; scanf("%d", &idx);
printf("%d\n", dp[len] ^ flag[idx]);
}
return 0;
}

例题:P4084 [USACO17DEC] Barn Painting G

分析:状态设计比较直接,设 dpu,c 表示以 u 为根节点的子树,节点 u 的颜色为 c 的方案数,即对于所有初始状态,dpu,1=dpu,2=dpu,3=1,如果某个节点被上了指定的颜色,那么该节点的状态中另外两种上色状态方案数为 0

对于每个节点,由于不能与子节点颜色相同,则有:

  • dpu,1=vsonu(dpv,2+dpv,3)
  • dpu,2=vsonu(dpv,1+dpv,3)
  • dpu,3=vsonu(dpv,1+dpv,2)
参考代码
#include <cstdio>
#include <vector>
using namespace std;
const int N = 100005;
const int MOD = 1000000007;
vector<int> tree[N];
int c[N], dp[N][4];
void dfs(int u, int fa) {
int ans1 = 1, ans2 = 1, ans3 = 1;
for (int v : tree[u]) {
if (v == fa) continue;
dfs(v, u);
ans1 = 1ll * (dp[v][2] + dp[v][3]) % MOD * ans1 % MOD;
ans2 = 1ll * (dp[v][1] + dp[v][3]) % MOD * ans2 % MOD;
ans3 = 1ll * (dp[v][1] + dp[v][2]) % MOD * ans3 % MOD;
}
if (c[u] == 0 || c[u] == 1) dp[u][1] = ans1;
if (c[u] == 0 || c[u] == 2) dp[u][2] = ans2;
if (c[u] == 0 || c[u] == 3) dp[u][3] = ans3;
}
int main()
{
int n, k;
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++) {
int x, y;
scanf("%d%d", &x, &y);
tree[x].push_back(y); tree[y].push_back(x);
}
while (k--) {
int b;
scanf("%d", &b); scanf("%d", &c[b]);
}
dfs(1, 0);
printf("%d\n", ((dp[1][1] + dp[1][2]) % MOD + dp[1][3]) % MOD);
return 0;
}

例题:P3576 [POI2014] MRO-Ant colony

分析:从叶子节点往上推到食蚁兽所在的边不好做,但是食蚁兽在那条边上捕食的一定是正好 k 只蚂蚁。

所以可以从食蚁兽所在的边开始推,把这条边的两端点连上一个虚点 0 作为根,容易发现如果最后正好有 k 只蚂蚁爬到根,那么在每一条边上可行的蚂蚁数量都是连续区间。

不妨设 lu 表示到达 u 这个点,沿着 u 往父节点爬时该边可行蚂蚁数量的最小值(闭区间),ru 表示最大值(开区间),则有 lu=lfa×(degfa1)ru=rfa×(degfa1),初始值 l0=kr0=k+1

最后对于每个叶节点,看有多少蚁群满足数量在 [lu,ru) 之间,这个二分查找即可,统计数量后乘 k 即为答案。

参考代码
#include <cstdio>
#include <vector>
#include <algorithm>
using std::vector;
using std::sort;
using std::lower_bound;
using ll = long long;
const int N = 1000005;
int m[N], n, g, k;
ll l[N], r[N], ans;
vector<int> tree[N];
void dfs(int u, int fa) {
int deg = tree[u].size();
if (deg == 1) {
ans += lower_bound(m + 1, m + g + 1, r[u]) - lower_bound(m + 1, m + g + 1, l[u]);
return;
}
for (int v : tree[u]) {
if (v == fa) continue;
l[v] = l[u] * (deg - 1); r[v] = r[u] * (deg - 1);
dfs(v, u);
}
}
int main()
{
scanf("%d%d%d", &n, &g, &k);
for (int i = 1; i <= g; i++) scanf("%d", &m[i]);
sort(m + 1, m + g + 1);
int a, b; scanf("%d%d", &a, &b);
tree[0].push_back(a); tree[0].push_back(b);
tree[a].push_back(0); tree[b].push_back(0);
for (int i = 2; i < n; i++) {
int a, b; scanf("%d%d", &a, &b);
tree[a].push_back(b); tree[b].push_back(a);
}
l[0] = k; r[0] = k + 1;
dfs(0, 0);
printf("%lld\n", ans * k);
return 0;
}

习题:P2899 [USACO08JAN] Cell Phone Network G

解题思路

注意本题和 P2016 战略游戏 的区别,战略游戏是选择一些点从而覆盖所有的边,本题是选择一些点从而覆盖所有的点。

在战略游戏中,一条边可能会被两端的点覆盖到,因此对于每个点对应的子树需要设计两个状态(选/不选)。类似地,在本题中,我们可以要分三种状态:

  • dpu,0 表示 u 被自己覆盖的情况下对应子树的最少信号塔数量
  • dpu,1 表示 u 被子节点覆盖的情况下对应子树的最少信号塔数量
  • dpu,2 表示 u 被父节点覆盖的情况下对应子树的最少信号塔数量

则有状态转移:

  • dpu,0=vsonumin{dpv,0,dpv,1,dpv,2},因为 u 处自己放置了信号塔,因此子节点处放或不放都可以
  • dpu,1=dpv,0+vsonuvvmin{dpv,0,dpv,1},此时至少要有一个子节点放置信号塔,其他可放可不放,因此 v 应该是所有子节点 vdpv,0min{dpv,0,dpv,1} 最小的那个子节点;注意若 u 没有子树即 u 为叶子节点,此时 dpu,1=1
  • dpu,2=vsonumin{dpv,0,dpv,1},因为本节点处不放,靠父节点放置来覆盖,所以子节点中除了状态 2 以外都可以
参考代码
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 10005;
vector<int> tree[N];
int dp[N][3];
// dp[u][0]:u处放置
// dp[u][1]:u处依赖子节点放置
// dp[u][2]:u处依赖父节点放置
void dfs(int u, int fa) {
dp[u][0] = 1;
int best = -1;
for (int v : tree[u]) {
if (v == fa) continue;
dfs(v, u);
dp[u][0] += min(min(dp[v][0], dp[v][1]), dp[v][2]);
dp[u][2] += min(dp[v][0], dp[v][1]);
dp[u][1] += min(dp[v][0], dp[v][1]);
// 寻找必须要放置的那个子节点
int cur_diff = dp[v][0] - min(dp[v][0], dp[v][1]);
int best_diff = dp[best][0] - min(dp[best][0], dp[best][1]);
if (best == -1 || cur_diff < best_diff)
best = v;
}
if (best != -1) {
// 至少要在一个子节点处放置
dp[u][1] += dp[best][0] - min(dp[best][0], dp[best][1]);
} else {
dp[u][1] = 1; // 没有子树,必须放置
}
}
int main()
{
int n; scanf("%d", &n);
for (int i = 1; i < n; i++) {
int a, b; scanf("%d%d", &a, &b);
tree[a].push_back(b);
tree[b].push_back(a);
}
dfs(1, 0);
printf("%d\n", min(dp[1][0], dp[1][1]));
return 0;
}

习题:P3574 [POI2014] FAR-FarmCraft

解题思路

dpu 表示假如在 0 时刻到达 u,它的子树被安装好的时间。

发现对下面子树的遍历顺序会影响最终结果,考虑这个顺序,假设针对 u 的某两棵子树 v1v2

  • 假设先走 v1 再走 v2,则此时可能的完成时间是 max(1+dpv1,2×szv1+1+dpv2),前者表示 v1 那棵子树完成时间更晚,后者表示 v2 那棵子树完成时间更晚,此时要先走完 v1 子树再走到 v2 才能加 dpv2
  • 假设先走 v2 再走 v1,则此时可能的完成时间是 max(1+dpv2,2×szv2+1+dpv1)

显然我们希望 v1 子树和 v2 子树形成更好的遍历顺序,考虑按这上面的式子对子树排序。

注意,用上面的式子比大小对子节点排序需要证明 max(1+dpv1,2×szv1+1+dpv2)<max(1+dpv2,2×szv2+1+dpv1) 这个式子具有传递性。这是可以证明的:假设小于号前面的式子取到第一项,此时这个式子必然满足,因为小于号后面式子的第二项必然比它大,传递性显然成立;假如小于号前面的式子取到第二项,此时相当于需要 2×szv1+dpv2<2×szv2+dpv1,这个式子经过移项可以使得小于号左边只和 v1 有关,右边只和 v2 有关,因此传递性得证。

所以我们可以按这种方式对子树排序,按照子树的遍历依次更新 dpu,这里的转移式是 2×sum+1+dpv,其中 sum 代表在 v 这棵子树之前的子树大小总和。

注意最后答案是 dp12×(n1)+c1 的较大值,因为题目要求走一圈后回到点 1 才能开始给 1 装软件。

参考代码
#include <cstdio>
#include <vector>
#include <algorithm>
using std::vector;
using std::sort;
using std::max;
const int N = 5e5 + 5;
vector<int> tree[N];
int c[N], sz[N], n, dp[N];
void dfs(int u, int fa) {
dp[u] = c[u]; sz[u] = 1;
for (int v : tree[u]) {
if (v == fa) continue;
dfs(v, u);
sz[u] += sz[v];
}
sort(tree[u].begin(), tree[u].end(), [](int i, int j) {
int i_before_j = max(1 + dp[i], 2 * sz[i] + 1 + dp[j]);
int j_before_i = max(1 + dp[j], 2 * sz[j] + 1 + dp[i]);
return i_before_j < j_before_i;
});
int sum = 0;
for (int v : tree[u]) {
if (v == fa) continue;
dp[u] = max(dp[u], 2 * sum + 1 + dp[v]);
sum += sz[v];
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &c[i]);
}
for (int i = 1; i < n; i++) {
int a, b; scanf("%d%d", &a, &b);
tree[a].push_back(b); tree[b].push_back(a);
}
dfs(1, 0);
// 1号点要等回来才能装,所以要考虑2*(n-1)+c[1]
printf("%d\n", max(dp[1], 2 * (n - 1) + c[1]));
return 0;
}
posted @   RonChen  阅读(136)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列1:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
· 【杂谈】分布式事务——高大上的无用知识?
点击右上角即可分享
微信分享提示