P4719 【模板】"动态 DP"&动态树分治
题目描述
给定一棵 n 个点的树,点带点权。
有 m 次操作,每次操作给定 x,y,表示修改点 x 的权值为 y。
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
输入格式
第一行有两个整数,分别表示结点个数 n 和操作个数 m。
第二行有 n 个整数,第 i 个整数表示节点 iii 的权值 ai。
接下来 (n−1) 行,每行两个整数 u,v,表示存在一条连接 u 与 v 的边。
接下来 m 行,每行两个整数 x,y,表示一次操作,修改点 x 的权值为 y。
输出格式
对于每次操作,输出一行一个整数表示答案。
输入输出样例
输入 #1
10 10
-11 80 -99 -76 56 38 92 -51 -34 47
2 1
3 1
4 3
5 2
6 2
7 1
8 2
9 4
10 7
9 -44
2 -17
2 98
7 -58
8 48
3 99
8 -61
9 76
9 14
10 93
输出 #1
186
186
190
145
189
288
244
320
258
304
说明/提示
数据规模与约定
- 对于 100%1 的数据,保证 1≤n,m≤\(10^5\),1≤u,v,x≤n,\(-10^2 \leq a_i, y \leq 10^2\)。
动态DP板子题。
我们首先可以推出DP方程:\(dpi,0=∑max(dpj,0,dpj,1)\)
\(dp_{i, 1} = \sum_{j} dp_{j, 0} + a_i\)
对于每一次修改,我们要\(O(n)\)的改一条链,很不妙,所以用树链剖分,\(O(logn)\)的修改一条链。
我们用\(f_{i, 0/1}\)表示以\(i\)为根的这颗树内,轻儿子的ans。
所以DP方程还可以写成:\(dp_{i, 0} = f_{j, 0} + max(dp_{son, 1}, dp_{son, 0})\)
\(dp_{i, 1} = f_{i, 1} + dp_{son, 0}\)
这里\(son\)是轻儿子。
这样我们就可以直接维护转移矩阵,修改一条链了。
转移矩阵:
\[\left[
\begin{array}{}
f[x][0] & f[x][0]\\
f[x][1] & -\infty\\
\end{array}
\right] * \left[
\begin{array}{}
dp[son][0]\\
dp[son][1]\\
\end{array}
\right] = \left[
\begin{array}{}
dp[x][0]\\
dp[x][1]\\
\end{array}
\right]
\]
修改一下矩阵乘法规则,把加法改为取\(max\)
#include <bits/stdc++.h>
#define ls(o) (o << 1)
#define rs(o) (o << 1 | 1)
#define mid ((l + r) >> 1)
using namespace std;
inline long long read() {
long long s = 0, f = 1; char ch;
while(!isdigit(ch = getchar())) (ch == '-') && (f = -f);
for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48));
return s * f;
}
const int N = 1e5 + 5, inf = 1e9;
int n, m, cnt, tot;
int fa[N], id[N], val[N], end[N], dfn[N], top[N], siz[N], hav[N], head[N];
int f[N][2], dp[N][2];
struct edge { int to, nxt; } e[N << 1];
struct mat {
long long v[2][2];
mat() {
for(int i = 0;i <= 1; i++)
for(int j = 0;j <= 1; j++)
v[i][j] = -inf;
}
} t[N << 2], nod[N];
mat operator * (const mat &a, const mat &b) {
mat c;
for(int i = 0;i <= 1; i++)
for(int j = 0;j <= 1; j++)
for(int k = 0;k <= 1; k++)
c.v[i][j] = max(c.v[i][j], a.v[i][k] + b.v[k][j]);
return c;
}
void add(int x, int y) {
e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y;
}
void get_tree(int x, int Fa) {
siz[x] = 1; fa[x] = Fa;
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == Fa) continue;
get_tree(y, x);
siz[x] += siz[y];
if(siz[y] > siz[hav[x]]) hav[x] = y;
}
}
void get_top(int x, int topic) {
dfn[x] = ++tot; id[tot] = x; top[x] = topic;
if(hav[x] == 0) { end[topic] = tot; return ; }
get_top(hav[x], topic);
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa[x] || y == hav[x]) continue;
get_top(y, y);
}
}
void dfs(int x) {
dp[x][1] = f[x][1] = val[x];
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to; if(y == fa[x]) continue;
dfs(y);
dp[x][0] += max(dp[y][0], dp[y][1]);
dp[x][1] += dp[y][0];
if(y != hav[x]) {
f[x][0] += max(dp[y][0], dp[y][1]);
f[x][1] += dp[y][0];
}
}
}
void up(int o) {
t[o] = t[ls(o)] * t[rs(o)];
}
void build(int o, int l, int r) {
if(l == r) {
int x = id[l];
t[o].v[0][0] = t[o].v[0][1] = f[x][0];
t[o].v[1][0] = f[x][1];
nod[l] = t[o];
return ;
}
build(ls(o), l, mid); build(rs(o), mid + 1, r);
up(o);
}
void change(int o, int l, int r, int x) {
if(l == r) { t[o] = nod[l]; return ; }
if(x <= mid) change(ls(o), l, mid, x);
if(x > mid) change(rs(o), mid + 1, r, x);
up(o);
}
mat get_ans(int o, int l, int r, int x, int y) {
if(x <= l && y >= r) return t[o];
if(y <= mid) return get_ans(ls(o), l, mid, x, y);
if(x > mid) return get_ans(rs(o), mid + 1, r, x, y);
return get_ans(ls(o), l, mid, x, y) * get_ans(rs(o), mid + 1, r, x, y);
}
void update_tree(int x, int y) {
nod[dfn[x]].v[1][0] += y - val[x];
val[x] = y;
while(x) {
mat Old, New;
int a = x;
x = top[x];
Old = get_ans(1, 1, n, dfn[x], end[x]);
change(1, 1, n, dfn[a]);
New = get_ans(1, 1, n, dfn[x], end[x]);
int tmp = dfn[fa[x]];
nod[tmp].v[0][0] = nod[tmp].v[0][1] += max(New.v[0][0], New.v[1][0]) - max(Old.v[0][0], Old.v[1][0]);
nod[tmp].v[1][0] += New.v[0][0] - 0Old.v[0][0];
x = fa[x];
}
}
int main() {
n = read(); m = read();
for(int i = 1;i <= n; i++) val[i] = read();
for(int i = 1, x, y;i <= n - 1; i++) {
x = read(); y = read();
add(x, y); add(y, x);
}
get_tree(1, 0); get_top(1, 1);
dfs(1); build(1, 1, n);
for(int i = 1, x, y;i <= m; i++) {
x = read(); y = read();
update_tree(x, y);
mat ans = get_ans(1, 1, n, dfn[1], end[1]);
printf("%lld\n", max(ans.v[0][0], ans.v[1][0]));
}
return 0;
}