动态 DP 学习笔记
1 前言
动态 DP,简称 DDP。用于处理树上带修的简单 DP 问题。
前置知识:
- 树链剖分
- 线段树维护矩阵
- 树形 DP
2 基本做法
如果不带修,就是简单的树上 DP。
设 \(f_{i,0}\) 表示不选 \(i\) 点的最大权值,\(f_{i,1}\) 表示选 \(i\) 点并且的最大权值。
考虑到每次修改只会影响一条到根的链的 DP 值,需要快速修改链上的 DP 值。
第一步:动态 DP 首先套了树剖。
那么设 \(g_{i,0}\) 表示不选 \(i\) 点并且不考虑重儿子的最大权值,\(g_{i,1}\) 表示选 \(i\) 点并且不考虑重儿子的最大权值。
那么 \(f\) 的转移可以写成:
\[f_{i,0}=g_{i,0}+\max(f_{son_u,0}, f_{son_u,1})\\
f_{i,1}=g_{i,1}+f_{son_u,0}
\]
其中 \(son_u\) 表示 \(u\) 的重儿子。
第二步:将转移写成矩阵转移,可以是普通矩阵也可以是广义的。
先把转移写成相同的形式:
\[f_{i,0}=\max(f_{son_i,0}+g_{i,0},f_{son_i,1}+g_{i,0})\\
f_{i,1}=\max(g_{i,1}+f_{son_i,0},-\infty)
\]
构造矩阵:
\[\begin{bmatrix}f_{son_i,0},f_{son_i,1}\end{bmatrix}\cdot\begin{bmatrix}g_{i,0}& g_{i,1}\\g_{i,0}&-\infty\end{bmatrix}=\begin{bmatrix}f_{i,0},f_{i,1}\end{bmatrix}
\]
那么现在的转移可以看做自底向上做矩阵乘法,并且转移只与重儿子有关,轻儿子的贡献已经看做转移矩阵。
第三步:线段树维护转移矩阵。
由于树剖的性质,每一个节点的 DP 值可以看作其所在链的链底到自身的矩阵的并,并且链底的转移矩阵就是 DP 值。
矩阵不满足交换律,所以将矩阵写成转移矩阵在前的形式。
\[\cdot\begin{bmatrix}g_{i,0}& g_{i,0}\\g_{i,1}&-\infty\end{bmatrix}\begin{bmatrix}f_{son_i,0}\\f_{son_i,1}\end{bmatrix}=\begin{bmatrix}f_{i,0}\\f_{i,1}\end{bmatrix}
\]
这样子修改就可以只修改每个链的链底。相当于平均了复杂度,现在询问和修改都是 \(O(\log n)\)。
总复杂度带有矩阵的常数,大概是 \(O(2^3n\log^2n)\)。
#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define mk std::make_pair
#define pb push_back
using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10;
int n, m;
int a[N];
std::vector<int> e[N]; //题目输入
int tot;
int son[N], sz[N];
int id[N], top[N], dfn[N], end[N], fa[N]; //树剖相关
int f[N][2]; //dp
struct mat {
int m[2][2];
void clr() {
for(int i = 0; i < 2; i++) for(int j = 0; j < 2; j++) m[i][j] = 0;
}
friend mat operator * (mat a, mat b) {
mat ret; ret.clr();
for(int i = 0; i < 2; i++) {
for(int j = 0; j < 2; j++) {
for(int k = 0; k < 2; k++) {
ret.m[i][j] = std::max(ret.m[i][j], a.m[i][k] + b.m[k][j]);
}
}
}
return ret;
}
void print() {
std::cout << m[0][0] << " " << m[0][1] << "\n";
std::cout << m[1][0] << " " << m[1][1] << "\n";
}
} g[N]; //矩阵结构体
struct seg {
mat v[N << 2];
void pushup(int u) {v[u] = v[u << 1] * v[u << 1 | 1];}
void build(int u, int l, int r) {
if(l == r) {
v[u] = g[dfn[l]];
// v[u].print();
return;
}
int mid = (l + r) >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void upd(int u, int l, int r, int x) {
if(l == r) {v[u] = g[dfn[l]]; return;}
int mid = (l + r) >> 1;
if(x <= mid) upd(u << 1, l, mid, x);
else upd(u << 1 | 1, mid + 1, r, x);
pushup(u);
}
mat qry(int u, int l, int r, int L, int R) {
if(L <= l && r <= R) return v[u];
int mid = (l + r) >> 1;
if(R <= mid) return qry(u << 1, l, mid, L, R);
if(L > mid) return qry(u << 1 | 1, mid + 1, r, L, R);
return qry(u << 1, l, mid, L, R) * qry(u << 1 | 1, mid + 1, r, L, R);
}
} T; //维护每个节点矩阵的线段树
void dfs1(int u, int f) {
fa[u] = f, sz[u] = 1;
for(int v : e[u]) {
if(v == f) continue;
dfs1(v, u);
sz[u] += sz[v];
if(sz[son[u]] < sz[v]) son[u] = v;
}
}
void dfs2(int u, int topf) {
top[u] = topf;
id[u] = ++tot; dfn[tot] = u;
end[topf] = tot;
f[u][0] = 0, f[u][1] = a[u];
g[u].m[0][0] = g[u].m[0][1] = 0;
g[u].m[1][0] = f[u][1], g[u].m[1][1] = -iinf;
if(!son[u]) return;
dfs2(son[u], topf);
f[u][0] += std::max(f[son[u]][0], f[son[u]][1]);
f[u][1] += f[son[u]][0];
for(int v : e[u]) {
if(v == son[u] || v == fa[u]) continue;
dfs2(v, v);
f[u][0] += std::max(f[v][0], f[v][1]);
f[u][1] += f[v][0];
g[u].m[0][0] += std::max(f[v][0], f[v][1]);
g[u].m[0][1] = g[u].m[0][0];
g[u].m[1][0] += f[v][0];
}
}
void solve(int u, int x) {
g[u].m[1][0] += x - a[u];
a[u] = x;
while(u) {
mat lst = T.qry(1, 1, n, id[top[u]], end[top[u]]); //一个节点的 dp 值为其所在链的链底到自身的矩乘
T.upd(1, 1, n, id[u]);
mat cur = T.qry(1, 1, n, id[top[u]], end[top[u]]);
u = fa[top[u]];
g[u].m[0][0] += std::max(cur.m[0][0], cur.m[1][0]) - std::max(lst.m[0][0], lst.m[1][0]);
g[u].m[0][1] = g[u].m[0][0];
g[u].m[1][0] += cur.m[0][0] - lst.m[0][0];
// g[u].print();
}
mat ans = T.qry(1, 1, n, id[1], end[1]);
std::cout << std::max(ans.m[0][0], ans.m[1][0]) << "\n";
return;
}
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cin >> n >> m;
for(int i = 1; i <= n; i++) std::cin >> a[i];
for(int i = 1; i < n; i++) {
int u, v;
std::cin >> u >> v;
e[u].pb(v), e[v].pb(u);
}
dfs1(1, 0), dfs2(1, 1);
T.build(1, 1, n);
while(m--) {
int x, y;
std::cin >> x >> y;
solve(x, y);
}
return 0;
}