Loading

动态 dp

前置知识 - 矩阵

动态 dp

广义矩阵乘法

例题

桌面上摆放着 \(n\) 颗糖果,第 \(i\) 颗糖果的美味值是 \(a_i\)。你可以从桌上拿走一些糖果,但是你不能拿走相邻的两颗糖果。求出你所能拿走的糖果的最大美味值之和。


我们先定义一个普通的 dp 状态:

\(dp_{i, 0 / 1}\) 表示考虑前 \(i\) 颗糖果,第 \(i\) 颗糖果拿或不拿所能得到的美味值的最大值总和。

显然的,我们可以写出转移:

\[\begin{cases} dp_{i, 0} = \max(dp_{i - 1, 0}, dp_{i - 1, 1}) \\ dp_{i, 1} = dp_{i - 1, 0} + a_i\end{cases} \]

所以,\(dp_{i + 1, 0}, dp_{i + 1, 1}\) 的转移是这样的:

\[\begin{cases} dp_{i + 1, 0} = \max(dp_{i, 0}, dp_{i, 1}) \\ dp_{i + 1, 1} = dp_{i, 0} + a_{i + 1}\end{cases} \]

我们尝试将上式代入下式,可以得到:

\[\begin{cases} dp_{i + 1, 0} = \max(\max(dp_{i - 1, 0}, dp_{i - 1, 1}), dp_{i - 1, 0} + a_i) \\ dp_{i + 1, 1} = \max(dp_{i - 1, 0}, dp_{i - 1, 1}) + a_{i + 1}\end{cases} \]

我们将他们都变成只取 \(\max\) 的形式:

\[\begin{cases} dp_{i + 1, 0} = \max(dp_{i - 1, 0}, dp_{i - 1, 1}, dp_{i - 1, 0} + a_i) \\ dp_{i + 1, 1} = \max(dp_{i - 1, 0} + a_{i + 1}, dp_{i - 1, 1} + a_{i + 1})\end{cases} \]

我们可以发现,当 \(a_i < 0\) 时,必然不会选择这颗糖,所以,我们可以简化一下式子:

\[\begin{cases} dp_{i, 0} = \max(dp_{i - 1, 0}, dp_{i - 1, 1}) \\ dp_{i, 1} = \max(dp_{i - 1, 0} + a_i, dp_{i - 1, 1} - \infty)\end{cases} \]

\[\begin{cases} dp_{i + 1, 0} = \max(dp_{i - 1, 1}, dp_{i - 1, 0} + a_i) \\ dp_{i + 1, 1} = \max(dp_{i - 1, 0} + a_{i + 1}, dp_{i - 1, 1} + a_{i + 1})\end{cases} \]

Max(min) - plus 矩阵运算

我们将上面两个式子对应的矩阵写下来:

\[\begin{cases} dp_{i, 0} = \max(dp_{i - 1, 0}, dp_{i - 1, 1}) \\ dp_{i, 1} = \max(dp_{i - 1, 0} + a_i, dp_{i - 1, 1} - \infty)\end{cases} \ \ \ \rightarrow \ \ \ \begin{bmatrix} 0 \ \ \ \ \ \ \ \ 0 \\ a_i \ -\infty\end{bmatrix} \]

\[\begin{cases} dp_{i + 1, 0} = \max(dp_{i - 1, 1}, dp_{i - 1, 0} + a_i) \\ dp_{i + 1, 1} = \max(dp_{i - 1, 0} + a_{i + 1}, dp_{i - 1, 1} + a_{i + 1})\end{cases} \ \ \ \ \rightarrow \ \ \ \ \begin{bmatrix} a_i \ \ \ \ \ \ \ \ 0 \\ a_{i + 1} \ \ \ a_{i + 1}\end{bmatrix} \]

我们定义 Max(min) - plus 矩阵运算是如下的运算:

\[\begin{cases} y_1 = \max(x_1 + k_{1, 1}, x_2 + k_{1, 2}, \dots) \\ y_2 = \max(x_1 + k_{2, 1}, x_2 + k_{2, 2}, \dots) \\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \vdots\end{cases} = \begin{bmatrix} k_{1, 1} \ \ \ k_{1, 2} \ \ \dots \\ k_{2, 1} \ \ \ k_{2, 2} \ \ \dots \\ \vdots\end{bmatrix} \]

\[\begin{bmatrix} a \ \ \ \ b \\ c \ \ \ \ d \end{bmatrix} \cdot \begin{bmatrix} e \ \ \ \ f \\ g \ \ \ \ h\end{bmatrix} = \begin{bmatrix} \max(a + e,, b + g) \ \ \ \ max(a + f, b + h) \\ \max(c + e, d + g) \ \ \ \max(b + f, d + h) \end{bmatrix} \]


在这种运算下,我们会发现,\(\begin{bmatrix} 0 \ \ \ \ \ \ \ \ 0 \\ a_{i + 1} \ -\infty\end{bmatrix} \cdot \begin{bmatrix} 0 \ \ \ \ \ \ \ \ 0 \\ a_i \ -\infty\end{bmatrix} = \begin{bmatrix} a_i \ \ \ \ \ \ \ \ 0 \\ a_{i + 1} \ \ \ a_{i + 1}\end{bmatrix}\)

因此,我们可以将 dp 式子看作矩阵连乘的结果。

所以,当参与 dp 的元素改变时,我们只需要修改其对应的矩阵,并用数据结构维护矩阵连乘的结果,实现快速 dp。

CSES-1724

题意

给一个有 \(n\) 个点和 \(m\) 条有向边的图,每条边都有边权,请求出从点 \(1\) 到点 \(n\) 恰好经过 \(k\) 条边的最小路径长度。

思路

我们可以想出一个 dp 状态,\(dp_{i, j}\) 表示从点 \(1\) 出发,经过 \(j\) 条边到达点 \(i\) 的最小路径长度。

但是,这道题的 \(k\) 太大了,我们不能暴力转移。

首先,我们会发现,每次转移的边都是同样的,也就是说,如果在只走 \(j - 1\) 条边时,可以从 \(u\) 转移到 \(v\),那么在走 \(j\) 条边时,也是可以从 \(u\) 转移到 \(v\) 的,所以,我们可以考虑使用广义矩阵乘法加速这个过程。

所以,我们会发现,只有存在一条 \(u\)\(v\) 的边时,才可以从 \(u\) 转移到 \(v\),因此,我们将初始的矩阵中每条边对应的 \(a_{u, v}\) 的权值改成它们之间的边权。

之后直接矩阵快速幂即可。

代码

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

const int N = 110;
const ll INF = 2e18;

struct matrix {
    ll a[N][N];
    matrix operator * (const matrix &x) {
        matrix ret;
        for (int i = 1; i < N; i++) {
            for (int j = 1; j < N; j++) {
                ret.a[i][j] = INF;
                for (int k = 1; k < N; k++) {
                    ret.a[i][j] = min(ret.a[i][j], a[i][k] + x.a[k][j]);
                }
            }
        }
        return ret;
    }
} a;

int n, m, k;

matrix qpow(matrix x, int y) {
    matrix ret;
    for (int i = 1; i < N; i++) {
        for (int j = 1; j < N; j++) ret.a[i][j] = INF;
        ret.a[i][i] = 0;
    }
    while (y) {
        if (y & 1) ret = ret * x;
        x = x * x, y >>= 1;
    }
    return ret;
}

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m >> k;
    for (int i = 1; i < N; i++) {
        for (int j = 1; j < N; j++) a.a[i][j] = INF;
    }
    while (m--) {
        int u, v, w; cin >> u >> v >> w;
        a.a[u][v] = min(a.a[u][v], 0ll + w);
    }
    a = qpow(a, k);
    cout << (a.a[1][n] == INF ? -1 : a.a[1][n]);
    return 0;
}

洛谷 P4719

题意

给定一棵 \(n\) 个点的树,点带点权。

\(m\) 次操作,每次操作给定 \(x, y\),表示修改点 \(x\) 的权值为 \(y\)

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

思路

首先,我们考虑不带修改要怎么做:

我们设 \(dp_{i, 0 / 1}\) 表示当前考虑到第 \(i\) 个点,第 \(i\) 个点不选或选的权值总和的最大值。

那么对于每一个 \(u\),都有:

\[dp_{u, 0} = \sum\limits_{(u, v) \in E} \max(dp_{v, 0}, dp_{v, 1}) \\ dp_{u, 1} = a_u + \sum \limits_{(u, v) \in E} dp_{v, 1} \]

那么,显然的,我们需要修改,又需要维护矩阵的乘积,可以考虑树链剖分。

我们先将整棵树剖成许多条重链。

然后,由于树链剖分的性质:每条重链上的 \(dfs\) 序是连续的。

因此,我们可以先用线段树维护矩阵的乘积,对于每一条重链,将这条重链的乘积更新到这条重链的链头的父亲上,从下往上依次计算。

然而,我们每次更新某个点的权值,都最多只会影响 \(\log n\) 条重链。

也就是说,每次最多只需要更新 \(\log n\) 的点的权值。

因此,时间复杂度为 \(O(n \log ^ 2 n)\)

代码

#include <bits/stdc++.h>

using namespace std;
using pii = pair<int, int>;

const int N = 1e5 + 10, INF = 2e7;

struct matrix {
    int a[2][2];
    matrix operator * (const matrix &x) {
        matrix ret;
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                ret.a[i][j] = -INF;
                for (int k = 0; k < 2; k++) {
                    ret.a[i][j] = max(ret.a[i][j], a[i][k] + x.a[k][j]);
                }
            }
        }
        return ret;
    }
} tr[4 * N];

int n, m, a[N];
pii val[N];
vector<int> g[N];
int fa[N], son[N], sz[N], dep[N];
int tot, top[N], dfn[N], rnk[N];
int l[N], r[N];

void dfs1(int u, int f) {
    fa[u] = f, sz[u] = 1, son[u] = -1, dep[u] = dep[f] + 1;
    for (int v : g[u]) {
        if (v != f) {
            dfs1(v, u);
            if (son[u] == -1 || sz[v] > sz[son[u]]) son[u] = v;
        }
    }
}

void dfs2(int u, int tp) {
    dfn[u] = ++tot, rnk[tot] = u, top[u] = tp, r[tp] = tot;
    if (!l[tp]) l[tp] = tot;
    if (son[u] != -1) dfs2(son[u], tp);
    for (int v : g[u]) {
        if (v != fa[u] && v != son[u]) dfs2(v, v);
    }
}

matrix build(int i, int l, int r) {
    if (l == r) return tr[i] = {0, 0, a[rnk[l]], -INF};
    int mid = (l + r) >> 1;
    return tr[i] = build(i * 2, l, mid) * build(i * 2 + 1, mid + 1, r);
}

void modify(int i, int l, int r, int pos, int x, int y) {
    if (l == r) {
        tr[i].a[0][0] += x, tr[i].a[0][1] += x, tr[i].a[1][0] += y;
        return ;
    }
    int mid = (l + r) >> 1;
    pos <= mid ? modify(i * 2, l, mid, pos, x, y) : modify(i * 2 + 1, mid + 1, r, pos, x, y);
    tr[i] = tr[i * 2] * tr[i * 2 + 1];
}

matrix query(int i, int l, int r, int ql, int qr) {
    if (qr < l || ql > r) return {0, -INF, -INF, 0};
    if (ql <= l && r <= qr) return tr[i];
    int mid = (l + r) >> 1;
    return query(i * 2, l, mid, ql, qr) * query(i * 2 + 1, mid + 1, r, ql, qr);
}

void Modify(int u, int x) {
    int t = u;
    while (top[t] != 1) {
        auto [x, y] = val[top[t]];
        modify(1, 1, n, dfn[fa[top[t]]], -x, -y), t = fa[top[t]];
    }

    modify(1, 1, n, dfn[u], 0, x - a[u]), a[u] = x;
    while (top[u] != 1) {
        matrix k = query(1, 1, n, l[top[u]], r[top[u]]);
        int x = max({k.a[0][0], k.a[0][1], k.a[1][0], k.a[1][1]}), y = max(k.a[0][0], k.a[0][1]);
        val[top[u]] = {x, y}, modify(1, 1, n, dfn[fa[top[u]]], x, y), u = fa[top[u]];
    }
}

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1, u, v; i < n; i++) {
        cin >> u >> v;
        g[u].push_back(v), g[v].push_back(u);
    }
    dfs1(1, 0), dfs2(1, 1), build(1, 1, n);
    for (int i = n; i >= 1; i--) {
        if (l[top[rnk[i]]] == i && i != 1) {
            matrix k = query(1, 1, n, i, r[rnk[i]]);
            int x = max({k.a[0][0], k.a[0][1], k.a[1][0], k.a[1][1]}), y = max(k.a[0][0], k.a[0][1]);
            val[rnk[i]] = {x, y}, modify(1, 1, n, dfn[fa[rnk[i]]], x, y);
        }
    }
    while (m--) {
        int pos, x; cin >> pos >> x, Modify(pos, x);
        matrix ans = query(1, 1, n, l[1], r[1]);
        cout << max({ans.a[0][0], ans.a[0][1], ans.a[1][0], ans.a[1][1], 0}) << '\n';
    }
    return 0;
}
posted @ 2024-08-19 21:10  chengning0909  阅读(4)  评论(0编辑  收藏  举报