Loading

“动态dp” 学习笔记

前置

这应该算一个神仙算法了吧,至少我<-这个蒟蒻学完是觉得这样的。

首现,你得会这些东西。

矩阵乘法树链剖分,矩阵乘法优化dp。

是不是有一种心态碎成渣渣的欲望。

模板题:

P4719 【模板】"动态 DP"&动态树分治

对于LCT的做法,我太弱了,不会LCT,所以可以去看别的文章。

思路

dp部分

首先,我们可以考虑怎么暴力做这道题目。

矩阵乘法+dp!!!

我们尝试来推一下这个式子:

我们设 \(f_{i , 1}\) 为选择 \(i\)\(i\) 的子树的最大权独立集的权值大小。

\(f_{i , 0}\) 为不选择 \(i\)\(i\) 的子树的最大权独立集的权值大小。

则有一个清晰易懂的式子:

\[f_{i , 0} = \sum_{son=1}\max(f_{son,0},f_{son,1}) \]

\[f_{i , 1} = \sum_{son=1}f_{son,0}+a_i \]

最后的答案就是 \(\max(f_{1, 0}, f_{1, 1})\)

然后,我们发现,这个东西带修以后,如果直接去跑,复杂度当场炸掉,有没有什么更快的方法呢。

这个时候,就需要引出我们的树链剖分了。

树链剖分部分

如果使用树链剖分,我们可以在 \(O(\log(n)^2)\) 的复杂度下,实现单点修改。

这个复杂度和普通的树链剖分是一样的。

至于复杂度的证明。

我们可一从dp本身的角度考虑。

由于dp最主要的一点就是无后效性,所以它其实和直接的用数据结构维护区间和差不多。

我们现在只需要考虑如何在线段树里\(O(1)\)的修改和查询。

矩阵部分

你可以看到,题解区对于这一块的详细讲解非常少,对于我这种对矩阵乘法不太熟的萌新非常不友好。

所以我在这里会详细的说明矩阵部分。

以便我自己理解(狗头)

我们刚刚讲到。

需要考虑如何在线段树里 \(O(1)\) 的修改和查询。

我们发现,可以用矩阵乘法进行优化。

\(g_{i,1}\) 表示 \(i\) 号点的所有轻儿子,都不取的最大权独立集;\(g_{i, 0}\) 表示 \(i\) 号点的所有轻儿子,可取可不取形成的最大权独立集。

那么刚刚那个式子就被简化成了:

\[f_{i,0} = \max(f_{j , 0}+g_{i,0},f_{j,1}+g_{i,0}) \]

\[f_{i,1} = \max(g_{i,1}+f_{j,0},-\infty) \]

这里参考了题解区题解

我们可以考虑重定义区间乘法。

即:

mat operator *(const mat &tmp) const
{
    mat res; res.init();
    for(int i = 0;i <= 1;i++)
        for(int j = 0;j <= 1;j++)
            for(int k = 0;k <= 1;k++)
                res.a[i][j] = max(res.a[i][j] , a[i][k] + tmp.a[k][j]);
    return res;
}

至于矩阵的推导式子:

\[\begin{vmatrix}f_{p,0}&f_{p,0}\\f_{p,1}&-\infty\end{vmatrix} = \begin{vmatrix}f_{l,0}&f_{l,0}\\f_{l,1}&-\infty\end{vmatrix} * \begin{vmatrix}f_{r,0}&f_{r,0}\\f_{r,1}&-\infty\end{vmatrix} \]

其中,\(p\) 为当前节点,\(l\) 为左子节点 \(r\) 为右子节点。

在建树的时候,将矩阵初始化为:

\[\begin{vmatrix}f_{i,0}&f_{i,0}\\f_{i,1}&-\infty\end{vmatrix} \]

然后直接用矩阵乘法进行 \(O(1)\) 转移。

当然,在树剖往上跳的过程中,重链头部的矩阵同样需要转移。

我们设两个转移矩阵为:

\[\begin{vmatrix}a_{1}&b_{1}\\c_{1}&d_{1}\end{vmatrix} \]

\[\begin{vmatrix}a_{2}&b_{2}\\c_{2}&d_{2}\end{vmatrix} \]

要被转移的矩阵为:

\[\begin{vmatrix}x&y\\z&u\end{vmatrix} \]

那么转移方程为:

\[\begin{vmatrix}x&y\\z&u\end{vmatrix} = \begin{vmatrix}x+\max(a_{2},c_{1})-\max(a_{1},c_{1})&x+\max(a_{2},c_{1})-\max(a_{1},c_{1})\\z+a_{2}-a_{1}&u\end{vmatrix}\]

至于为什么是两个转移矩阵,我会在代码部分说明。

这样,我们就成功的完成了 \(O(1)\) 的推导了(建议自己手推一下前两个)。

代码实现

代码实现比较难,我刷了整整一版提交记录,才卡了过去。

代码长达4.5KB(码风问题)。

细节看注释吧。

#include <bits/stdc++.h>
using namespace std;
const int inf = 99999999;
const int manx = 1e5 + 5;
int n , m , cnt , tot , a[manx] , nv[manx];
int dp[manx][3] , f[manx][3] ,  head[manx];
//dp数组

struct edge
{
    int to , nxt;
}e[manx * 2];
//连边

struct tree
{
    int siz , id , fa , top , son , dep , end;
}t[manx];
//树链剖分节点信息,注意这里需要多统计一个end,表示重链的尾巴。

struct ST
{
    int l , r;
    #define l(x) st[x].l
    #define r(x) st[x].r
    #define lp p * 2
    #define rp p * 2 + 1
}st[manx * 4];
//线段树

struct mat
{
    int a[3][3];
    void init()
    {
        for(int i = 0;i <= 1;i++)
            for(int j = 0;j <= 1;j++)
                a[i][j] = -inf;
    }
    mat operator *(const mat &tmp) const
    {
        mat res; res.init();
        for(int i = 0;i <= 1;i++)
            for(int j = 0;j <= 1;j++)
                for(int k = 0;k <= 1;k++)
                    res.a[i][j] = max(res.a[i][j] , a[i][k] + tmp.a[k][j]);
        return res;
    }
    //重定义的乘法
}ans, tr[manx * 4] , v[manx];
//矩阵
//tr为线段树中的矩阵,v[i]为编号为i的节点的矩阵。

int read()
{
    int asd = 0 , qwe = 1; char zxc;
    while(!isdigit(zxc = getchar())) if(zxc == '-') qwe = -1;
    while(isdigit(zxc)) asd = asd * 10 + zxc - '0' , zxc = getchar();
    return asd * qwe;
}

void add(int x , int y)
{
    e[++cnt] = (edge){y , head[x]} , head[x] = cnt;
    e[++cnt] = (edge){x , head[y]} , head[y] = cnt;
}

//TODO begin

void dfs1(int now , int com)
{
    t[now].siz = 1 , t[now].dep = t[com].dep + 1 , t[now].fa = com;
    int h = 0 , son = 0;
    for(int i = head[now];i;i = e[i].nxt)
    {
        if(e[i].to == com) continue;
        dfs1(e[i].to , now);
        if(t[e[i].to].siz > h) h = t[e[i].to].siz , son = e[i].to;
        t[now].siz += t[e[i].to].siz;
    }
    t[now].son = son;
}
//树链剖分第一次dfs

void dfs2(int now , int gf)
{
    // cout << now << endl;
    t[now].id = ++tot; nv[tot] = now;
    t[gf].end = tot , t[now].top = gf;
    if(t[now].son) dfs2(t[now].son , gf);
    for(int i = head[now];i;i = e[i].nxt)
    {
        if(e[i].to == t[now].fa || e[i].to == t[now].son)
            continue;
        dfs2(e[i].to , e[i].to);
    }
}
//树链剖分第二次dfs,注意要统计end。

void dfs3(int now)
{
    f[now][1] = a[now];
    for(int i = head[now];i;i = e[i].nxt)
    {
        if(e[i].to == t[now].fa || e[i].to == t[now].son) continue;
        dfs3(e[i].to); 
        f[now][0] += max(dp[e[i].to][1] , dp[e[i].to][0]);
        f[now][1] += dp[e[i].to][0];
    }
    dp[now][0] += f[now][0] , dp[now][1] += f[now][1];
    if(!t[now].son) return;
    dfs3(t[now].son);
    dp[now][0] += max(dp[t[now].son][1] , dp[t[now].son][0]);
    dp[now][1] += dp[t[now].son][0];
}
//初始的答案计算。

//TODO seg_Tree

void build(int p , int l , int r)
{
    l(p) = l , r(p) = r;
    if(l == r)
    {
        // cout << p << " " << l << " " << r << " " << nv[l] << endl;
        v[nv[l]].a[0][0] = f[nv[l]][0] , v[nv[l]].a[1][0] = f[nv[l]][1];
        v[nv[l]].a[0][1] = f[nv[l]][0] , v[nv[l]].a[1][1] = -inf;
        //此处为第二个矩阵乘法推导式。
        tr[p] = v[nv[l]];
        return;
    }
    int mid = (l + r) >> 1;
    build(lp , l , mid) , build(rp , mid + 1 , r);
    tr[p] = tr[lp] * tr[rp];
    //此处为第一个推导式。
}
//建树

mat ask(int p , int l , int r)
{
    if(l <= l(p) && r(p) <= r) return tr[p];
    int mid = (l(p) + r(p)) >> 1;
    if(mid >= r) return ask(lp , l , r);
    if(mid < l) return ask(rp , l , r);
    // cout << p << " " << l << " " << r << endl;
    return ask(lp , l , r) * ask(rp , l , r);
    //同样的第一个推导式。
}

void update(int p , int id)
{
    if(l(p) == r(p))
    {
        // cout << p << " " << nv[id] << endl;
        tr[p] = v[nv[id]];
        //直接赋值就可以了。
        return;
    }
    int mid = (l(p) + r(p)) >> 1;
    if(mid >= id) update(lp , id);
    else update(rp , id);
    tr[p] = tr[lp] * tr[rp];
    //同样的。
}

void change(int u , int w)
{
    v[u].a[1][0] += w - a[u] , a[u] = w;
    //修改当前矩阵。
    while(u != 0)
    {
        mat x , y;
        int now = t[u].top;
        x = ask(1 , t[now].id , t[now].end);
        //修改前查询。
        update(1 , t[u].id);
        //将当前矩阵的修改转移到线段树上去。
        y = ask(1 , t[now].id , t[now].end);
        ///修改后查询。
        u = t[now].fa;
        //往上跳。
        // cout << u << " " << v[u].a[1][0] << " " << v[u].a[0][0] << endl;
        v[u].a[0][0] += max(y.a[0][0] , y.a[1][0]) - max(x.a[0][0], x.a[1][0] );
        v[u].a[0][1] = v[u].a[0][0];
        v[u].a[1][0] += y.a[0][0] - x.a[0][0];
        //第三个矩阵推导式,读者自行理解(可怜孩子并不会)。
    }
}
//树链剖分往上跳。

int main()
{
    n = read() , m = read();
    for(int i = 1;i <= n;i++) a[i] = read();
    for(int i = 1;i < n;i++)
    {
        int x = read() , y = read();
        add(x , y);
    }
    dfs1(1 , 0) , dfs2(1 , 1) , dfs3(1) , build(1 , 1 , n);
    for(int i = 1;i <= m;i++)
    {
        int x = read() , y = read();
        change(x , y);
        ans = ask(1 , t[1].id , t[1].end);
        // cout << ans.a[0][0] << " " << ans.a[1][0] << endl;
        cout << max(ans.a[0][0] , ans.a[1][0]) << endl;
        //max中两个分别对应暴力算法中的f[1][1]和f[1][0]。
    }
    return 0;
}

感受到算法的难度了吗,反正我感受到了。

一道例题

P5024 [NOIP2018 提高组] 保卫王国

有没有觉得现在看这道题感觉很简单。

因为最小权覆盖集 = 全集 - 最大权独立集。

所以直接修改查询就可以了。

当城市 \(a\) 不得驻扎军队时。

\(a\) 增加 \(\infty\)

当城市 \(a\) 必须驻扎军队时。

\(a\) 减少 \(\infty\)

如果查询的答案为 \(\infty\)

则为无解。

Code

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int inf = 9999999999;
const int manx = 1e5 + 5;
int n , m , num , cnt , tot , a[manx] , nv[manx];
int dp[manx][3] , f[manx][3] ,  head[manx];
string type;

struct edge
{
    int to , nxt;
}e[manx * 2];

struct tree
{
    int siz , id , fa , top , son , dep , end;
}t[manx];

struct ST
{
    int l , r;
    #define l(x) st[x].l
    #define r(x) st[x].r
    #define lp p * 2
    #define rp p * 2 + 1
}st[manx * 4];

struct mat
{
    int a[3][3];
    void init()
    {
        for(int i = 0;i <= 1;i++)
            for(int j = 0;j <= 1;j++)
                a[i][j] = -inf;
    }
    mat operator *(const mat &tmp) const
    {
        mat res; res.init();
        for(int i = 0;i <= 1;i++)
            for(int j = 0;j <= 1;j++)
                for(int k = 0;k <= 1;k++)
                    res.a[i][j] = max(res.a[i][j] , a[i][k] + tmp.a[k][j]);
        return res;
    }
}ans, tr[manx * 4] , v[manx];

inline int read()
{
    int asd = 0 , qwe = 1; char zxc;
    while(!isdigit(zxc = getchar())) if(zxc == '-') qwe = -1;
    while(isdigit(zxc)) asd = asd * 10 + zxc - '0' , zxc = getchar();
    return asd * qwe;
}

inline void add(int x , int y)
{
    e[++cnt] = (edge){y , head[x]} , head[x] = cnt;
    e[++cnt] = (edge){x , head[y]} , head[y] = cnt;
}

//TODO begin

inline void dfs1(int now , int com)
{
    t[now].siz = 1 , t[now].dep = t[com].dep + 1 , t[now].fa = com;
    int h = 0 , son = 0;
    for(int i = head[now];i;i = e[i].nxt)
    {
        if(e[i].to == com) continue;
        dfs1(e[i].to , now);
        if(t[e[i].to].siz > h) h = t[e[i].to].siz , son = e[i].to;
        t[now].siz += t[e[i].to].siz;
    }
    t[now].son = son;
}

inline void dfs2(int now , int gf)
{
    // cout << now << endl;
    t[now].id = ++tot; nv[tot] = now;
    t[gf].end = tot , t[now].top = gf;
    if(t[now].son) dfs2(t[now].son , gf);
    for(int i = head[now];i;i = e[i].nxt)
    {
        if(e[i].to == t[now].fa || e[i].to == t[now].son)
            continue;
        dfs2(e[i].to , e[i].to);
    }
}

inline void dfs3(int now)
{
    f[now][1] = a[now];
    for(int i = head[now];i;i = e[i].nxt)
    {
        if(e[i].to == t[now].fa || e[i].to == t[now].son) continue;
        dfs3(e[i].to); 
        f[now][0] += max(dp[e[i].to][1] , dp[e[i].to][0]);
        f[now][1] += dp[e[i].to][0];
    }
    dp[now][0] += f[now][0] , dp[now][1] += f[now][1];
    if(!t[now].son) return;
    dfs3(t[now].son);
    dp[now][0] += max(dp[t[now].son][1] , dp[t[now].son][0]);
    dp[now][1] += dp[t[now].son][0];
}

//TODO seg_Tree

inline void build(int p , int l , int r)
{
    l(p) = l , r(p) = r;
    if(l == r)
    {
        // cout << p << " " << l << " " << r << " " << nv[l] << endl;
        v[nv[l]].a[0][0] = f[nv[l]][0] , v[nv[l]].a[1][0] = f[nv[l]][1];
        v[nv[l]].a[0][1] = f[nv[l]][0] , v[nv[l]].a[1][1] = -inf;
        tr[p] = v[nv[l]];
        return;
    }
    int mid = (l + r) >> 1;
    build(lp , l , mid) , build(rp , mid + 1 , r);
    tr[p] = tr[lp] * tr[rp];
}

inline mat ask(int p , int l , int r)
{
    if(l <= l(p) && r(p) <= r) return tr[p];
    int mid = (l(p) + r(p)) >> 1;
    if(mid >= r) return ask(lp , l , r);
    if(mid < l) return ask(rp , l , r);
    // cout << p << " " << l << " " << r << endl;
    return ask(lp , l , r) * ask(rp , l , r);
}

inline void update(int p , int id)
{
    if(l(p) == r(p))
    {
        // cout << p << " " << nv[id] << endl;
        tr[p] = v[nv[id]];
        return;
    }
    int mid = (l(p) + r(p)) >> 1;
    if(mid >= id) update(lp , id);
    else update(rp , id);
    tr[p] = tr[lp] * tr[rp];
}

inline void change(int u , int w)
{
    v[u].a[1][0] += w , a[u] += w;
    while(u != 0)
    {
        mat x , y;
        int now = t[u].top;
        x = ask(1 , t[now].id , t[now].end);
        update(1 , t[u].id);
        y = ask(1 , t[now].id , t[now].end);
        u = t[now].fa;
        // cout << u << " " << v[u].a[1][0] << " " << v[u].a[0][0] << endl;
        v[u].a[0][0] += max(y.a[0][0] , y.a[1][0]) - max(x.a[0][0], x.a[1][0] );
        v[u].a[0][1] = v[u].a[0][0];
        v[u].a[1][0] += y.a[0][0] - x.a[0][0];
    }
}

signed main()
{
    n = read() , m = read(); cin >> type; 
    for(int i = 1;i <= n;i++) a[i] = read() , num += a[i];
    for(int i = 1;i < n;i++)
    {
        int x = read() , y = read();
        add(x , y);
    }
    dfs1(1 , 0) , dfs2(1 , 1) , dfs3(1) , build(1 , 1 , n);
    for(int i = 1;i <= m;i++)
    {
        int x1 = read() , y1 = read() , x2 = read() , y2 = read() , sum = 0;
        change(x1 , (y1 ? -inf : inf));
        change(x2 , (y2 ? -inf : inf));
        sum = ((y1 ^ 1) + (y2 ^ 1)) * inf;
        ans = ask(1 , t[1].id , t[1].end);
        sum = max(ans.a[0][0] , ans.a[1][0]) - sum;
        change(x1 , (y1 ? inf : -inf));
        change(x2 , (y2 ? inf : -inf));
        if(num - sum > inf) cout << -1 << endl;
        else cout << num - sum << endl;
    }
    return 0;
}
posted @ 2021-11-13 21:22  JiaY19  阅读(68)  评论(0编辑  收藏  举报