【模板】"动态 DP"&动态树分治
动态dp
#include<bits/stdc++.h>
#define LL long long
#define RG register
using namespace std;
// Reads the next (optionally negative) decimal integer from stdin into x.
// Any characters before the number are skipped. If the input is exhausted before a
// digit or '-' appears, x is left as 0 instead of looping forever on EOF.
template<class T> inline void read(T &x) {
    x = 0;
    int c = getchar();  // int, not char: it must be able to hold EOF
    bool neg = false;
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') neg = true, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    if (neg) x = -x;
}
// Writes integer x to stdout in decimal (no trailing separator).
// Digits are peeled off with a sign-aware remainder instead of negating x first,
// so the most negative value of T (e.g. LLONG_MIN) is printed without overflow.
template<class T> inline void write(T x) {
    if (x == 0) { putchar('0'); return; }
    if (x < 0) putchar('-');
    int digits[40], len = 0;  // 40 digits covers every builtin integer type
    while (x != 0) {
        T r = x % 10;  // since C++11 '%' truncates toward zero: r lies in (-10, 10)
        digits[len++] = static_cast<int>(r < 0 ? -r : r);
        x /= 10;
    }
    while (len > 0) putchar('0' + digits[--len]);
}
// N bounds vertex ids; inf is used negated as the max-plus "minus infinity".
const int N = 100010, inf = 2147483647;
// Adjacency list kept as linked lists of arcs in g[]: an arc stores its head vertex
// and the index of the previous arc leaving the same tail (0 ends the list).
struct node {
    int to, nxt;
} g[N << 1];
// last[x]: newest arc out of x; gl: arcs used so far; v[x]: weight of x; n: vertex count.
int last[N], gl, v[N], n;
// Inserts undirected edge {x, y} as two opposite arcs, x -> y first, then y -> x.
void add(int x, int y) {
    const int tail[2] = {x, y}, head[2] = {y, x};
    for (int k = 0; k < 2; ++k) {
        ++gl;
        g[gl] = node{head[k], last[tail[k]]};
        last[tail[k]] = gl;
    }
}
// 2x2 matrix over the max-plus semiring ("+" is max, "*" is +).
struct Matrix {
    long long s[2][2];
    // Max-plus product: result[i][j] = max over k of s[i][k] + A.s[k][j].
    Matrix operator * (const Matrix &A) const {
        Matrix res;
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                res.s[i][j] = std::max(s[i][0] + A.s[0][j], s[i][1] + A.s[1][j]);
        return res;
    }
} t[N << 2], tmp[N];  // t: segment-tree nodes; tmp[p]: cached leaf matrix for dfs position p
// Heavy-light decomposition state:
//   siz[u] subtree size, son[u] heavy child (0 if none), fa[u] parent,
//   top[u] head of u's heavy chain, ed[u] last vertex of that chain,
//   pos[u] dfs position of u, dfn[p] vertex at position p, cnt positions assigned so far.
int dfn[N], siz[N], son[N], top[N], cnt, fa[N], pos[N], ed[N];
// First HLD pass: computes parent, subtree size and heavy child for the subtree of u.
// ff is u's parent and is skipped while scanning neighbours.
void dfs1(int u, int ff) {
    siz[u] = 1;
    for (int e = last[u]; e != 0; e = g[e].nxt) {
        const int child = g[e].to;
        if (child == ff) continue;
        fa[child] = u;
        dfs1(child, u);
        siz[u] += siz[child];
        if (siz[son[u]] < siz[child]) son[u] = child;  // siz[0] == 0, so the first child always wins
    }
}
// Second HLD pass: assigns dfs positions, visiting the heavy child first so every heavy
// chain occupies a contiguous range of positions. topf is the head of u's chain.
void dfs2(int u, int topf) {
    top[u] = topf;
    pos[u] = ++cnt;
    dfn[cnt] = u;
    if (son[u] == 0) {  // no children: the chain ends at u
        ed[u] = u;
        return;
    }
    dfs2(son[u], topf);
    ed[u] = ed[son[u]];
    for (int e = last[u]; e != 0; e = g[e].nxt) {
        const int child = g[e].to;
        if (child != fa[u] && child != son[u]) dfs2(child, child);  // light child starts a new chain
    }
}
LL f[N][2];
void dp(int u, int ff) {
f[u][1] = v[u];
for (int i = last[u]; i; i = g[i].nxt) {
int v = g[i].to; if (v == ff) continue;
dp(v, u);
f[u][0] += max(f[v][1], f[v][0]);
f[u][1] += f[v][0];
}
return ;
}
//---------------------
#define lson (o << 1)
#define rson (o << 1 | 1)
void build(int o, int l, int r) {
if (l == r) {
int u = dfn[l], g0 = 0, g1 = v[u];
for (int i = last[u]; i; i = g[i].nxt)
if (g[i].to != son[u] && g[i].to != fa[u])
g0 += max(f[g[i].to][0], f[g[i].to][1]), g1 += f[g[i].to][0];
// printf("%d %d %d %d %d\n", u, g0, g1, son[u], fa[u]);
tmp[l] = t[o] = (Matrix) {g0, g0, g1, -inf};
return ;
}
int mid = (l + r) >> 1;
build(lson, l, mid), build(rson, mid + 1, r);
t[o] = t[lson] * t[rson];
}
// Point update: copies the cached leaf matrix tmp[p] into the leaf for position p,
// then recomputes the products on the path back to the root.
void Modify(int o, int l, int r, int p) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        if (p > mid) Modify(rson, mid + 1, r, p);
        else Modify(lson, l, mid, p);
        t[o] = t[lson] * t[rson];
    } else {
        t[o] = tmp[p];
    }
}
// Returns the ordered product of the leaf matrices at positions [L, R] (assumes L <= R).
Matrix query(int o, int l, int r, int L, int R) {
    if (L <= l && r <= R) return t[o];
    const int mid = (l + r) >> 1;
    if (R <= mid) return query(lson, l, mid, L, R);
    if (mid < L) return query(rson, mid + 1, r, L, R);
    const Matrix left = query(lson, l, mid, L, R);
    const Matrix right = query(rson, mid + 1, r, L, R);
    return left * right;
}
Matrix getans(int x) { return query(1, 1, n, pos[top[x]], pos[ed[x]]); }
void Modify(int u, int w) {
tmp[pos[u]].s[1][0] += w - v[u]; v[u] = w;
while (u) {
Matrix a = getans(u); Modify(1, 1, n, pos[u]); Matrix b = getans(u);
u = fa[top[u]]; if (!u) break;
tmp[pos[u]].s[0][1] = (tmp[pos[u]].s[0][0] += max(b.s[0][0], b.s[1][0]) - max(a.s[0][0], a.s[1][0]));
tmp[pos[u]].s[1][0] += b.s[0][0] - a.s[0][0];
}
}
//---------------------------
int main() {
int m, x, w; Matrix ans;
read(n), read(m);
for (int i = 1; i <= n; i++) read(v[i]);
for (int i = 1, x, y; i < n; i++) { read(x), read(y); add(x, y); }
dfs1(1, 0); dfs2(1, 1); dp(1, 0); build(1, 1, n);
// ans = getans(1);
// printf("%lld\n", max(ans.s[0][0], ans.s[1][0]));
while (m--) {
read(x), read(w);
Modify(x, w);
ans = getans(1);
printf("%lld\n", max(ans.s[0][0], ans.s[1][0]));
}
return 0;
}
整体\(dp\)
整体\(dp\)大概是对时间建一棵线段树,叶子节点表示该时刻的\(dp\)答案,类似线段树分治。
每个转移(例如某个点的权值在一段时间区间内生效)就相当于给这段时间区间打一个标记。
对于树,我们用线段树合并将子树\(dp\)合并。
这个时候我们需要用一些方式维护\(DP\)。
还是用动态\(dp\)的矩阵方式维护\(DP\)
\(\begin{bmatrix}f_{v,0}&f_{v,1}\end{bmatrix}\times \begin{bmatrix}f_{u,0}&f_{u,1}\\f_{u,0}&-\infty\end{bmatrix}=\begin{bmatrix}f_{u',0}&f_{u',1}\end{bmatrix}\)
但是这样的话,矩阵每转移一次就会变一次,就不能快速合并儿子。
我们可以把转移拆成两步,先把每个儿子的向量变换一下:
\(\begin{bmatrix}f_{u,0}&f_{u,1}\end{bmatrix}\times \begin{bmatrix}0&0\\0&-\infty\end{bmatrix}=\begin{bmatrix}max(f_{u,0},f_{u,1})&f_{u,0}\end{bmatrix}\)
然后
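\(\begin{bmatrix}\sum max(f_{v,0},f_{v,1})&\sum f_{v,0}\end{bmatrix}\times \begin{bmatrix}0&-\infty\\0&w\end{bmatrix}=\begin{bmatrix}f_{u,0}&f_{u,1}\end{bmatrix}\)(第一列算出的是两项的 \(max\),由于 \(\sum max(f_{v,0},f_{v,1})\ge\sum f_{v,0}\),它正好等于 \(f_{u,0}\);第二列得到 \(\sum f_{v,0}+w=f_{u,1}\))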
这样就可以线段树合并求出第一个矩阵,然后算出\(u\)的\(DP\)值。
还有一个问题就是,区间加的线段树怎么合并?
其实,我们每次合并两个节点时先\(pushdown\)一下,然后如果其中一个已经是叶子(没有儿子,整段区间共用同一个值),就把它的值直接乘到另一个节点的标记上即可。
#include<bits/stdc++.h>
#define mp make_pair
#define LL long long
using namespace std;
// Reads and returns the next (optionally negative) decimal integer from stdin.
// Any characters before the number are skipped. If the input is exhausted before a
// digit or '-' appears, 0 is returned instead of looping forever on EOF.
template<class T> T gi() {
    T x = 0;
    bool neg = false;
    int c = getchar();  // int, not char: it must be able to hold EOF
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') neg = true, c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    return neg ? -x : x;
}
// N bounds vertex ids; -inf is the max-plus "minus infinity" (kept near -1e9 so int sums don't overflow).
const int N = 1e5 + 10, inf = 1e9 + 7;
int n, m;  // vertex count, number of time steps
std::vector<int> e[N];  // adjacency lists
// Records the undirected edge {x, y} in both endpoints' lists.
void add(int x, int y) {
    e[x].push_back(y);
    e[y].push_back(x);
}
// rt[u]: root of u's dynamic segment tree over time; tot: nodes allocated so far.
int rt[N], tot;
// 2x2 max-plus matrix. Entries are int: with inf = 1e9 + 7, a single "-inf + -inf"
// still fits in int, so keep other values small.
struct Matrix {
    int a[2][2];
    int *operator [] (int x) { return a[x]; }
    // Max-plus product: result[i][j] = max over k of a[i][k] + z.a[k][j].
    Matrix operator * (const Matrix &z) const {
        Matrix res;
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                res.a[i][j] = std::max(a[i][0] + z.a[0][j], a[i][1] + z.a[1][j]);
        return res;
    }
    // Entry-wise inequality.
    bool operator != (const Matrix &z) const {
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                if (a[i][j] != z.a[i][j]) return true;
        return false;
    }
} t[N << 6], I;  // t: segment-tree node matrices/tags; I: max-plus identity (set in main)
// Diagonal max-plus matrix diag(x, y): a row vector [p, q] times it gives [p + x, q + y]
// (as long as p and q stay far above -inf).
Matrix mk(int x, int y) {
    Matrix d;
    d.a[0][0] = x;
    d.a[1][1] = y;
    d.a[0][1] = d.a[1][0] = -inf;
    return d;
}
// ch[o][0/1]: children of segment-tree node o (0 = absent). The global w[] is not used by the code shown.
int ch[N << 6][2], w[N];
// Weight w of a vertex, in effect over the time interval [l, r].
struct node { int l, r, w; };
// q[x]: the successive weight intervals of vertex x.
std::vector<node> q[N];
void pushdown(int o) {
if (t[o] != I) {
if (!ch[o][0])
t[ch[o][0] = ++tot] = t[o];
else t[ch[o][0]] = t[ch[o][0]] * t[o];
if (!ch[o][1])
t[ch[o][1] = ++tot] = t[o];
else t[ch[o][1]] = t[ch[o][1]] * t[o];
t[o] = I;
}
}
// Merges segment tree y into x and returns the surviving root.
// The values of the two trees are added position-wise.
// A childless node carries one value for its whole range, so it is absorbed by right-multiplying
// the other node's matrix by diag(row 0 of that value).
int merge(int x, int y) {
    if (x == 0) return y;
    if (y == 0) return x;
    auto childless = [](int o) { return ch[o][0] == 0 && ch[o][1] == 0; };
    if (childless(x)) std::swap(x, y);
    if (childless(y)) {
        t[x] = t[x] * mk(t[y][0][0], t[y][0][1]);
        return x;
    }
    pushdown(x);
    pushdown(y);
    for (int side = 0; side < 2; ++side) ch[x][side] = merge(ch[x][side], ch[y][side]);
    return x;
}
void Modify(int o, int l, int r, int L, int R, int k) {
if (L <= l && r <= R) return (void) (t[o] = t[o] * (Matrix) {{{0, -inf}, {0, k}}});
int mid = (l + r) >> 1; pushdown(o);
if (L <= mid) Modify(ch[o][0], l, mid, L, R, k);
if (R > mid) Modify(ch[o][1], mid + 1, r, L, R, k);
return ;
}
void dfs(int u, int ff) {
t[rt[u] = ++tot] = (Matrix) {{{0, 0}, {0, 0}}};
for (auto v : e[u])
if (v != ff)
dfs(v, u), rt[u] = merge(rt[u], rt[v]);
for (auto i : q[u]) if (i.l <= i.r) Modify(rt[u], 1, m, i.l, i.r, i.w);
t[rt[u]] = t[rt[u]] * (Matrix) {{{0, 0}, {0, -inf}}};
}
// Walks the tree in time order, pushing tags down. At each time step it prints entry [0][0]:
// the root's max(f0, f1) at that time.
void dfs2(int o, int l, int r) {
    if (l != r) {
        pushdown(o);
        const int mid = (l + r) >> 1;
        dfs2(ch[o][0], l, mid);
        dfs2(ch[o][1], mid + 1, r);
        return;
    }
    printf("%d\n", t[o][0][0]);
}
// Input:
//   - n and m;
//   - the n initial weights;
//   - n-1 edges;
//   - m updates "x y": from time i on, vertex x has weight y.
// Each weight becomes an interval of times, and the answer for every time 1..m is printed.
int main() {
    n = gi<int>();
    m = gi<int>();
    I = mk(0, 0);
    for (int i = 1; i <= n; ++i) q[i].push_back(node{1, m, gi<int>()});
    for (int i = 1; i < n; ++i) add(gi<int>(), gi<int>());
    for (int i = 1; i <= m; ++i) {
        const int x = gi<int>();
        const int y = gi<int>();
        q[x].back().r = i - 1;  // previous weight of x ends just before time i
        q[x].push_back(node{i, m, y});
    }
    dfs(1, 0);
    dfs2(rt[1], 1, m);
    return 0;
}