Loading

动态 DP 学习笔记

1 前言

动态 DP,简称 DDP。用于处理树上带修的简单 DP 问题。

前置知识:

  1. 树链剖分
  2. 线段树维护矩阵
  3. 树形 DP

2 基本做法

上模板题:P4719 【模板】"动态 DP"&动态树分治

如果不带修,就是简单的树上 DP。

\(f_{i,0}\) 表示不选 \(i\) 点的最大权值,\(f_{i,1}\) 表示选 \(i\) 点并且的最大权值。

考虑到每次修改只会影响一条到根的链的 DP 值,需要快速修改链上的 DP 值。

第一步:动态 DP 首先套了树剖。

那么设 \(g_{i,0}\) 表示不选 \(i\) 点并且不考虑重儿子的最大权值,\(g_{i,1}\) 表示选 \(i\) 点并且不考虑重儿子的最大权值。

那么 \(f\) 的转移可以写成:

\[f_{i,0}=g_{i,0}+\max(f_{son_u,0}, f_{son_u,1})\\ f_{i,1}=g_{i,1}+f_{son_u,0} \]

其中 \(son_u\) 表示 \(u\) 的重儿子。

第二步:将转移写成矩阵转移,可以是普通矩阵也可以是广义的。

先把转移写成相同的形式:

\[f_{i,0}=\max(f_{son_i,0}+g_{i,0},f_{son_i,1}+g_{i,0})\\ f_{i,1}=\max(g_{i,1}+f_{son_i,0},-\infty) \]

构造矩阵:

\[\begin{bmatrix}f_{son_i,0},f_{son_i,1}\end{bmatrix}\cdot\begin{bmatrix}g_{i,0}& g_{i,1}\\g_{i,0}&-\infty\end{bmatrix}=\begin{bmatrix}f_{i,0},f_{i,1}\end{bmatrix} \]

那么现在的转移可以看做自底向上做矩阵乘法,并且转移只与重儿子有关,轻儿子的贡献已经看做转移矩阵。

第三步:线段树维护转移矩阵。

由于树剖的性质,每一个节点的 DP 值可以看作其所在链的链底到自身的矩阵的并,并且链底的转移矩阵就是 DP 值。

矩阵不满足交换律,所以将矩阵写成转移矩阵在前的形式。

\[\cdot\begin{bmatrix}g_{i,0}& g_{i,0}\\g_{i,1}&-\infty\end{bmatrix}\begin{bmatrix}f_{son_i,0}\\f_{son_i,1}\end{bmatrix}=\begin{bmatrix}f_{i,0}\\f_{i,1}\end{bmatrix} \]

这样子修改就可以只修改每个链的链底。相当于平均了复杂度,现在询问和修改都是 \(O(\log n)\)

总复杂度带有矩阵的常数,大概是 \(O(2^3n\log^2n)\)

#include <bits/stdc++.h>
#define pii std::pair<int, int>
#define fi first
#define se second
#define mk std::make_pair
#define pb push_back

using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;
const i64 iinf = 0x3f3f3f3f, linf = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10;
int n, m;
int a[N];
std::vector<int> e[N]; //题目输入
int tot;
int son[N], sz[N];
int id[N], top[N], dfn[N], end[N], fa[N]; //树剖相关
int f[N][2]; //dp
struct mat {
    int m[2][2];
    void clr() {
        for(int i = 0; i < 2; i++) for(int j = 0; j < 2; j++) m[i][j] = 0;
    }
    friend mat operator * (mat a, mat b) {
        mat ret; ret.clr();
        for(int i = 0; i < 2; i++) {
            for(int j = 0; j < 2; j++) {
                for(int k = 0; k < 2; k++) {
                    ret.m[i][j] = std::max(ret.m[i][j], a.m[i][k] + b.m[k][j]);
                }
            }
        }
        return ret;
    }
    void print() {
        std::cout << m[0][0] << " " << m[0][1] << "\n";
        std::cout << m[1][0] << " " << m[1][1] << "\n";
    }
} g[N]; //矩阵结构体
struct seg {
    mat v[N << 2];
    void pushup(int u) {v[u] = v[u << 1] * v[u << 1 | 1];}
    void build(int u, int l, int r) {
        if(l == r) {
            v[u] = g[dfn[l]];
            // v[u].print();
            return; 
        }
        int mid = (l + r) >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
    void upd(int u, int l, int r, int x) {
        if(l == r) {v[u] = g[dfn[l]]; return;}
        int mid = (l + r) >> 1;
        if(x <= mid) upd(u << 1, l, mid, x);
        else upd(u << 1 | 1, mid + 1, r, x);
        pushup(u);
    } 
    mat qry(int u, int l, int r, int L, int R) {
        if(L <= l && r <= R) return v[u];
        int mid = (l + r) >> 1;
        if(R <= mid) return qry(u << 1, l, mid, L, R);
        if(L > mid) return qry(u << 1 | 1, mid + 1, r, L, R);
        return qry(u << 1, l, mid, L, R) * qry(u << 1 | 1, mid + 1, r, L, R);
    }
} T; //维护每个节点矩阵的线段树
void dfs1(int u, int f) {
    fa[u] = f, sz[u] = 1;
    for(int v : e[u]) {
        if(v == f) continue;
        dfs1(v, u);
        sz[u] += sz[v];
        if(sz[son[u]] < sz[v]) son[u] = v;
    }
}
void dfs2(int u, int topf) {
    top[u] = topf;
    id[u] = ++tot; dfn[tot] = u;
    end[topf] = tot;
    f[u][0] = 0, f[u][1] = a[u];
    g[u].m[0][0] = g[u].m[0][1] = 0;
    g[u].m[1][0] = f[u][1], g[u].m[1][1] = -iinf;
    if(!son[u]) return;
    dfs2(son[u], topf);
    f[u][0] += std::max(f[son[u]][0], f[son[u]][1]);
    f[u][1] += f[son[u]][0];
    for(int v : e[u]) {
        if(v == son[u] || v == fa[u]) continue;
        dfs2(v, v);
        f[u][0] += std::max(f[v][0], f[v][1]);
        f[u][1] += f[v][0];
        g[u].m[0][0] += std::max(f[v][0], f[v][1]);
        g[u].m[0][1] = g[u].m[0][0];
        g[u].m[1][0] += f[v][0];
    }

}
void solve(int u, int x) {
    g[u].m[1][0] += x - a[u];
    a[u] = x;
    while(u) {
        mat lst = T.qry(1, 1, n, id[top[u]], end[top[u]]); //一个节点的 dp 值为其所在链的链底到自身的矩乘
        T.upd(1, 1, n, id[u]);
        mat cur = T.qry(1, 1, n, id[top[u]], end[top[u]]);
        u = fa[top[u]];
        g[u].m[0][0] += std::max(cur.m[0][0], cur.m[1][0]) - std::max(lst.m[0][0], lst.m[1][0]);
        g[u].m[0][1] = g[u].m[0][0];
        g[u].m[1][0] += cur.m[0][0] - lst.m[0][0];
        // g[u].print();
    }
    mat ans = T.qry(1, 1, n, id[1], end[1]);
    std::cout << std::max(ans.m[0][0], ans.m[1][0]) << "\n";
    return;
}
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    
    std::cin >> n >> m;
    for(int i = 1; i <= n; i++) std::cin >> a[i];
    for(int i = 1; i < n; i++) {
        int u, v;
        std::cin >> u >> v;
        e[u].pb(v), e[v].pb(u);
    }
    dfs1(1, 0), dfs2(1, 1);
    T.build(1, 1, n);
    while(m--) {
        int x, y;
        std::cin >> x >> y;
        solve(x, y);
    }

    return 0;
}
posted @ 2024-11-23 16:04  Fire_Raku  阅读(8)  评论(0编辑  收藏  举报