动态 dp 学习笔记

一、矩阵乘法

普通矩阵乘法

相信大家对矩阵乘法都不陌生,普通的矩阵乘法定义如下:

对于 \(n\times m\) 的矩阵 \(A\)\(m\times q\) 的矩阵 \(B\) ,定义 \(C=A\cdot B\) ,其中:

\[c_{i,j}=\sum_{k=1}^ma_{i,k}\cdot b_{k,j}\\ \]

矩阵 \(C\) 的大小为 \(n\times k\) ,单次矩阵乘法的时间复杂度为 \(O(nmk)\)

可以简记为:\(C\) 中第 \(i\) 行第 \(j\) 列的元素,等于 \(A\) 的第 \(i\) 行与 \(B\) 的第 \(j\) 列对应相乘再相加。

mat operator*(const mat &a,const mat &b)///传参O(1),传矩阵O(n^2)
{
    static mat c;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
        {
            c.v[i][j]=0;
            for(int k=1;k<=n;k++) c.v[i][j]=(c.v[i][j]+1ll*a.v[i][k]*b.v[k][j])%mod;
        }
    return c;
}

广义矩阵乘法

不过在动态 \(\texttt{DP}\) 中,更常见的是 \((\min,+),(\max,+)\) 广义矩阵乘法

\[c_{i,j}=\min_{1\le k\le m}\big(a_{i,k}+b_{k,j}\big)\\ c_{i,j}=\max_{1\le k\le m}\big(a_{i,k}+b_{k,j}\big)\\ \]

常用卡常方法

设矩阵阶数为 \(w\) ,由于**矩阵乘法时间复杂度 \(O(w^3)\) **,而题目中 \(w\) 一般为常数,所以配上数据结构后很容易成为卡常重灾区。

矩阵乘法有两种常见卡常方法:

  • 循环展开。

    原理是手动展开从而避免使用 for 循环,仅适用于 \(n\) 很小的情形(一般\(n=2\))。

    struct mat
    {
        int a,b,c,d;
    };
    mat operator*(const mat &x,const mat &y)
    {
        return {min(x.a+y.a,x.b+y.b),min(x.a+y.b,x.b+y.d),min(x.c+y.a,x.d+y.c),min(x.c+y.b,x.d+y.d)};
    }
    
  • 减少取模。

    如果模数 \(p\approx 10^9\) ,那么 long long 可以承受 \(9\cdot p^2\) 的数据量。

    mat operator*(const mat &a,const mat &b)
    {
        static mat c;
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
            {
                ll res=0;
                for(int k=1;k<=n;k++) res+=1ll*a.v[i][k]*b.v[k][j];
                c.v[i][j]=res%mod;
            }
        return c;
    }
    

二、动态 \(\texttt{DP}\) 概述

动态 \(\texttt{DP}\) 用于解决带修树形 \(\texttt{DP}\) 问题

话不多说,先上模板题。

约定 \(\sum\limits_{v\in son(u)}\) 表示对 \(u\) 的所有子节点 \(v\) 求和, \(\sum\limits_{v\neq wson_u}\) 表示对 \(u\) 的所有轻儿子 \(v\) 求和。


例1、\(\texttt{P4719 【模板】"动态 DP"\&动态树分治}\)

题目描述

给定一棵 \(n\) 个节点的树,点有点权 \(w_i\)

\(m\) 次单点修改点权的操作,每次操作后询问最大权独立集。

数据范围

  • \(1\le n,m\le 10^5,0\le |w_i|\le 10^2\)

时间限制 \(\texttt{1s}\) ,空间限制 \(\texttt{250MB}\)

分析

先考虑链怎么做。

\(f_{i,0/1}\) 表示仅考虑前 \(i\) 个数,不选/选第 \(i\) 个节点的最大收益。

\[f_{i,0}=\max(f_{i-1,0},f_{i-1,1})\\ f_{i,1}=f_{i-1,0}+w_i\\ \]

写成 \((\max,+)\) 广义矩阵乘法:

\[\begin{bmatrix} f_{i,0}&f_{i,1}\\ \end{bmatrix} = \begin{bmatrix} f_{i-1,0}&f_{i-1,1}\\ \end{bmatrix} \times \begin{bmatrix} 0&w_i\\ 0&-\infty\\ \end{bmatrix} \]

修改点权等价于修改单个矩阵,线段树维护单点修改区间乘积即可。


再考虑树的情况。

\(f_{u,0/1}\) 表示仅考虑 \(u\) 子树,不选\(/\)\(u\) 的最大收益。

\[f_{u,0}=\sum_{v\in son(u)}\max(f_{v,0},f_{v,1})\\ f_{u,1}=w_u+\sum_{v\in son(u)}f_{v,0}\\ \]

树链剖分转化为链上的问题,发现重儿子比较特殊,我们把轻儿子放在一起考虑。

\(g_{u,0/1}\) 表示仅考虑 \(u\) 自身及其轻子树,不选\(/\)\(u\) 的最大收益。

\[g_{u,0}=\sum_{v\neq wson_u}\max(f_{v,0},f_{v,1})\\ g_{u,1}=w_u+\sum_{v\neq wson_u}f_{v,0}\\ \]

看起来好像变复杂了,但我们可以反过来用 \(g\) 化简 \(f\)

\[f_{u,0}=g_{u,0}+\max(f_{wson_u,0},f_{wson_u,1})\\ f_{u,1}=g_{u,1}+f_{wson_u,0}\\ \]

终于转化成熟悉的 \((\max,+)\) 广义矩阵乘法!

\[\begin{bmatrix} f_{u,0}&f_{u,1}\\ \end{bmatrix} = \begin{bmatrix} f_{wson_u,0}&f_{wson_u,1}\\ \end{bmatrix} \times \begin{bmatrix} g_{u,0}&g_{u,1}\\ g_{u,0}&-\infty\\ \end{bmatrix} \]

然后考虑怎么带修。

有一个显然的性质:只有 \(u\) 到根的路径上的点的 \(f\) 值会改变。

\(f\) 是简单的,线段树查询重链( \(dfn\) 区间)的矩阵乘积即可。

然而我们更关心哪些节点的 \(g\) 值会改变,因为 \(g\) 的变化会引起矩阵的变化。

结论是,只有 \(u\) 自身及所有 fa[top[u]]\(g\) 值会改变!

top[u] 跳到 fa[top[u]] 的过程中,我们要先减掉原本对 fa[top[u]] 的贡献,更新信息后再加入新的对 fa[top[u]] 的贡献

时间复杂度 \(\mathcal O(w^3n\log^2n)\) ,本题 \(w=2\)


至此本题的大致思路就讲完了,还有一些代码实现上的细节问题。

  • 初始的 \(f,g\) 需要预处理。

    考虑到 \(f\) 需要遍历所有儿子, \(g\) 需要遍历所有重儿子,因此,在树剖第一遍 \(dfs\) 中预处理 \(f\) ,第二遍 \(dfs\) 中预处理 \(g\) ,写起来会比较简洁。

  • 线段树的 pushup 要用右边乘左边。

    为什么呢?矩阵乘法不满足交换律,又因为一条重链是在从底向上更新 \(\texttt{dp}\) 值,所以体现在 \(dfs\) 序上就是从右往左乘。

  • 树剖时预处理每条重链的链底,查询时需要 query 整条链。

    修改时 query 的区间为 [dfn[top[u]],ed[u]]ed[u]ed[top[u]] 本质相同),注意不要把右端点写成 u !

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5,inf=1e9;
int m,n,u,v,cnt;
int a[maxn];
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int ed[maxn],id[maxn],dfn[maxn],top[maxn];
int f[maxn][2],g[maxn][2];
vector<int> h[maxn];
struct mat
{
    int v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
    static mat c;
    for(int i=0;i<=1;i++)
        for(int j=0;j<=1;j++)
        {
            c.v[i][j]=-inf;
            for(int k=0;k<=1;k++) c.v[i][j]=max(c.v[i][j],a.v[i][k]+b.v[k][j]);
        }
    return c;
}
namespace sgmt
{
    #define ls p<<1
    #define rs p<<1|1
    struct node
    {
        int l,r;
        mat x;
    }f[4*maxn];
    void pushup(int p)
    {
        f[p].x=f[rs].x*f[ls].x;
    }
    void build(int p,int l,int r)
    {
        f[p].l=l,f[p].r=r;
        if(l==r)
        {
            int x=id[l];
            return f[p].x={g[x][0],g[x][1],g[x][0],-inf},void();
        }
        int mid=(l+r)/2;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(p);
    }
    void modify(int p,int pos,mat x)
    {
        if(f[p].l==f[p].r) return f[p].x=x,void();
        int mid=(f[p].l+f[p].r)/2;
        if(pos<=mid) modify(ls,pos,x);
        else modify(rs,pos,x);
        pushup(p);
    }
    mat query(int p,int l,int r)
    {
        if(l<=f[p].l&&f[p].r<=r) return f[p].x;
        int mid=(f[p].l+f[p].r)/2;
        if(r<=mid) return query(ls,l,r);
        if(l>=mid+1) return query(rs,l,r);
        return query(rs,l,r)*query(ls,l,r);///注意这里是从右往左乘
    }
}
void dfs1(int u,int father)
{
    sz[u]=1,f[u][1]=a[u];
    for(auto v:h[u])
    {
        if(v==father) continue;
        d[v]=d[u]+1,fa[v]=u;
        dfs1(v,u),sz[u]+=sz[v];
        if(sz[v]>=sz[son[u]]) son[u]=v;
        f[u][0]+=max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
    }
}
void dfs2(int u,int topf)
{
    dfn[u]=++cnt,id[cnt]=u,top[u]=topf,ed[u]=dfn[u],g[u][1]=a[u];
    if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
    for(auto v:h[u])
    {
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
        g[u][0]+=max(f[v][0],f[v][1]),g[u][1]+=f[v][0];
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d%d",&u,&v);
        h[u].push_back(v),h[v].push_back(u);
    }
    d[1]=1,dfs1(1,0),dfs2(1,1);
    sgmt::build(1,1,n);
    while(m--)
    {
        scanf("%d%d",&u,&v);
        g[u][1]+=v-a[u],a[u]=v;///更新u自身的信息
        while(u)
        {
            mat lst=sgmt::query(1,dfn[top[u]],ed[u]);
            sgmt::modify(1,dfn[u],{g[u][0],g[u][1],g[u][0],-inf});
            mat now=sgmt::query(1,dfn[top[u]],ed[u]);
            u=fa[top[u]];
            g[u][0]-=max(lst.v[0][0],lst.v[0][1]),g[u][1]-=lst.v[0][0];///减掉旧的贡献
            g[u][0]+=max(now.v[0][0],now.v[0][1]),g[u][1]+=now.v[0][0];///加入新的贡献
        }
        mat res=sgmt::query(1,dfn[1],ed[1]);
        printf("%d\n",max(res.v[0][0],res.v[0][1]));
    }
    return 0;
}

总结

动态 \(\texttt{DP}\) 的操作流程:

  • 在树剖过程中预处理 \(f,g\)

    这里 \(f\) 表示考虑 \(u\) 子树的贡献, \(g\) 表示考虑 \(u\) 及其所有轻子树的贡献。

    注意 \(g\) 的转移依赖于 \(f\) ,而不是 \(g\) 自身封闭转移。

    一般来说,\(f\) 仅在预处理时会用到,修改后可以通过线段树查询得到真实的 \(f\) ;而 \(g\) 会在转移矩阵有所体现,修改需要实时维护

  • 线段树维护转移矩阵。

    注意 pushup 是从右往左乘。

    void pushup(int p)
    {
        f[p].x=f[rs].x*f[ls].x;
    }
    
  • 修改时先减掉原来的贡献,再加入更新之后的贡献。

    下面是修改操作的伪代码:

    ///单点修改g[u],注意不需要上线段树
    while(u)
    {
        mat lst=/**初始矩阵**/ * sgmt::query(1,dfn[top[u]],ed[u]);
        sgmt::modify(1,dfn[u],/**u的转移矩阵**/);
        mat now=/**初始矩阵**/ * sgmt::query(1,dfn[top[u]],ed[u]);
        u=fa[top[u]];
        ///在g[u]中减掉lst的贡献
        ///在g[u]中加入now的贡献
    }
    
  • 查询时 \(u\) 的答案为初始矩阵乘上自底向上整条链的转移矩阵。

    如果初始矩阵刚好为矩阵乘法的单位元,那么可以偷懒不乘初始矩阵。比如非负整数域上 \((\max,+)\) 广义矩阵乘法的单位元为全零矩阵,模板题就是典型例子。

    但是多数情况下初始矩阵不是单位元,也不可省略。

    ans[u]=/**初始矩阵**/ * sgmt::query(1,dfn[u],ed[u]);
    
  • 如果没有修改操作,或者所有查询操作在修改操作之后,然后询问多个点(多条路径)的答案,可以用倍增代替树剖,时间复杂度少一只 \(\log\)

三、相关例题

例2、\(\texttt{CF750E New Year and Old Subsequence}\)

题目描述

给定一个长为 \(n\) 的数字串 \(s\)

\(q\) 次询问,对于 \(s\) 的某个子串 \(s[l:r]\) ,至少要删去几个字符,才能使其包含序列 "2017" ,但不包含序列 "2016" ,无解输出 -1

数据范围

  • \(4\le n\le 2\cdot 10^5,1\le q\le 2\cdot 10^5,1\le l\le r\le n\)

时间限制 \(\texttt{3s}\) ,空间限制 \(\texttt{256MB}\)

分析

严格来说,本题不算动态 \(\texttt{DP}\) ,只能算线段树维护矩阵乘法。

注意认真审题,要求包含的是子序列而不是子串。

考虑在子序列自动机上 \(\texttt{dp}\) ,状态设计如下:

  • \(f_{i,0}\) 表示走到状态 \(\varnothing\) ,至少需要删几个字符。
  • \(f_{i,1}\) 表示走到状态 "2" ,至少需要删几个字符。
  • \(f_{i,2}\) 表示走到状态 "20" ,至少需要删几个字符。
  • \(f_{i,3}\) 表示走到状态 "201" ,至少需要删几个字符。
  • \(f_{i,4}\) 表示走到状态 "2017" ,至少需要删几个字符。

\(c=s_i\) ,可以写出转移方程:

\[\begin{aligned} &f_{i,0}=f_{i-1,0}+[c=2]\\ &f_{i,1}=\min(f_{i-1,0}[c=2],f_{i-1,1}+[c=0])\\ &f_{i,2}=\min(f_{i-1,1}[c=0],f_{i-1,2}+[c=1])\\ &f_{i,3}=\min(f_{i-1,2}[c=1],f_{i-1,3}+[c=6\vee c=7])\\ &f_{i,4}=\min(f_{i-1,3}[c=7],f_{i-1,4}+[c=6])\\ \end{aligned} \]

注: \(f[c=x]\) 要求 \(c=x\) 时才能转移, \(f+[c=x]\)\(c=x\) 视为一个布尔函数。

至此已经做到 \(\mathcal O(nq)\) ,考虑继续优化。

把转移方程看成 \((\min,+)\) 广义矩阵乘法,那么每个 \(c\) 都可以预处理出转移矩阵。

每次询问用初始矩阵 \(\begin{bmatrix}0&\infty&\infty&\infty&\infty\\\end{bmatrix}\) 乘上区间 \([l,r]\) 的所有矩阵,线段树维护即可。

时间复杂度 \(\mathcal O(w^3(n+q)\log n)\) ,本题 \(w=5\)

可以用向量乘矩阵的 \(\texttt{trick}\) 优化到\(O(w^3n\log n+w^2q\log n)\),但本题 \(n,q\) 同阶,所以意义不大。

#include<bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=2e5+5,inf=1e9;
int l,n,q,r;
char s[maxn];
struct mat
{
    int v[5][5];
    mat()
    {
        for(int i=0;i<=4;i++)
            for(int j=0;j<=4;j++)
                v[i][j]=inf;
    }
}c[10];
mat operator*(const mat &a,const mat &b)
{
    static mat c;
    for(int i=0;i<=4;i++)
        for(int j=0;j<=4;j++)
        {
            c.v[i][j]=inf;
            for(int k=0;k<=4;k++) c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
        }
    return c;
}
struct node
{
    int l,r;
    mat x;
}f[4*maxn];
void build(int p,int l,int r)
{
    f[p].l=l,f[p].r=r;
    if(l==r) return f[p].x=c[s[l]-'0'],void();
    int mid=(l+r)/2;
    build(ls,l,mid);
    build(rs,mid+1,r);
    f[p].x=f[ls].x*f[rs].x;
}
mat query(int p,int l,int r)
{
    if(l<=f[p].l&&f[p].r<=r) return f[p].x;
    int mid=(f[p].l+f[p].r)/2;
    if(r<=mid) return query(ls,l,r);
    if(l>=mid+1) return query(rs,l,r);
    return query(ls,l,r)*query(rs,l,r);
}
int main()
{
    scanf("%d%d%s",&n,&q,s+1);
    for(int i=0;i<=9;i++)
        for(int j=0;j<=4;j++)
            c[i].v[j][j]=0;
    c[2].v[0][0]=1,c[2].v[0][1]=0;
    c[0].v[1][1]=1,c[0].v[1][2]=0;
    c[1].v[2][2]=1,c[1].v[2][3]=0;
    c[7].v[3][3]=1,c[7].v[3][4]=0;
    c[6].v[3][3]=1,c[6].v[4][4]=1;
    build(1,1,n);
    while(q--)
    {
        scanf("%d%d",&l,&r);
        mat res;
        res.v[0][0]=0,res=res*query(1,l,r);
        printf("%d\n",res.v[0][4]!=inf?res.v[0][4]:-1);
    }
    return 0;
}

例3、\(\texttt{P6021 洪水}\)

题目描述

给定一棵 \(n\) 个节点的树,你可以花费 \(w_i\) 的代价删掉第 \(i\) 个点。

接下来 \(m\) 次操作:

  • Q x:询问如果要使 \(x\) 与其子树中所有叶节点不连通,花费代价的最小值。
  • C x y:将 \(w_x\) 加上 \(y\)

数据范围

  • \(1\le n,m\le 2\cdot 10^5\)
  • 保证任意时刻 \(0\le w_i\lt 2^{31}\)

时间限制 \(\texttt{1s}\) ,空间限制 \(\texttt{125MB}\)

分析

\(f_u\) 为让 \(u\) 与其子树所有叶子不连通的最小代价,容易写出转移方程:

\[f_u=\min(w_u,\sum_{v\in son_u}f_v)\\ \]

\(g_u=\sum\limits_{v\neq wson_u}f_v\) ,则 \(f_u=\min(w_u,f_{wson_u}+g_u)\)

然后写成 \((\min,+)\) 广义矩阵乘法:

\[\begin{bmatrix} f_u&0\\ \end{bmatrix} = \begin{bmatrix} f_{wson_u}&0\\ \end{bmatrix} \times \begin{bmatrix} g_u&\infty\\ w_u&0\\ \end{bmatrix} \]

动态 \(\texttt{DP}\) 维护即可,时间复杂度 \(\mathcal O(w^3n\log^2n)\) ,本题 \(w=2\)

注意叶子节点 \(wson_u\) 不存在,因此 query 时需要用初始矩阵 \(\begin{bmatrix}\infty&0\\\end{bmatrix}\) 乘上整条链的贡献。

#include<bits/stdc++.h>
#define ll long long
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=2e5+5;
const ll inf=1e18;
int m,n,u,v,cnt;
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int ed[maxn],id[maxn],dfn[maxn],top[maxn];
ll f[maxn],g[maxn],w[maxn];
char ch[2];
vector<int> h[maxn];
struct mat
{
    ll v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
    static mat c;
    for(int i=0;i<=1;i++)
        for(int j=0;j<=1;j++)
        {
            c.v[i][j]=inf;
            for(int k=0;k<=1;k++) c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
        }
    return c;
}
namespace sgmt
{
    struct node
    {
        int l,r;
        mat x;
    }f[4*maxn];
    void pushup(int p)
    {
        f[p].x=f[rs].x*f[ls].x;
    }
    void build(int p,int l,int r)
    {
        f[p].l=l,f[p].r=r;
        if(l==r)
        {
            int u=id[l];
            return f[p].x={g[u],inf,w[u],0},void();
        }
        int mid=(l+r)/2;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(p);
    }
    void modify(int p,int pos,mat x)
    {
        if(f[p].l==f[p].r) return f[p].x=x,void();
        int mid=(f[p].l+f[p].r)/2;
        if(pos<=mid) modify(ls,pos,x);
        else modify(rs,pos,x);
        pushup(p);
    }
    mat query(int p,int l,int r)
    {
        if(l<=f[p].l&&f[p].r<=r) return f[p].x;
        int mid=(f[p].l+f[p].r)/2;
        if(r<=mid) return query(ls,l,r);
        if(l>=mid+1) return query(rs,l,r);
        return query(rs,l,r)*query(ls,l,r);
    }
}
using sgmt::modify;
using sgmt::query;
void dfs1(int u,int father)
{
    sz[u]=1;
    for(auto v:h[u])
    {
        if(v==father) continue;
        d[v]=d[u]+1,fa[v]=u;
        dfs1(v,u),sz[u]+=sz[v];
        if(sz[v]>=sz[son[u]]) son[u]=v;
        f[u]+=f[v];
    }
    f[u]=sz[u]==1?w[u]:min(f[u],w[u]);
}
void dfs2(int u,int topf)
{
    dfn[u]=++cnt,id[cnt]=u,top[u]=topf,ed[u]=cnt;
    if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
    for(auto v:h[u])
    {
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
        g[u]+=f[v];
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d%d",&u,&v);
        h[u].push_back(v),h[v].push_back(u);
    }
    dfs1(1,0),dfs2(1,1);
    sgmt::build(1,1,n);
    scanf("%d",&m);
    while(m--)
    {
        scanf("%s",ch);
        if(ch[0]=='Q')
        {
            scanf("%d",&u);
            mat res=(mat){inf,0,0,0}*query(1,dfn[u],ed[u]);
            printf("%lld\n",res.v[0][0]);
        }
        else
        {
            scanf("%d%d",&u,&v),w[u]+=v;
            while(u)
            {
                mat lst=(mat){inf,0,0,0}*query(1,dfn[top[u]],ed[u]);
                modify(1,dfn[u],{g[u],inf,w[u],0});
                mat now=(mat){inf,0,0,0}*query(1,dfn[top[u]],ed[u]);
                u=fa[top[u]];
                g[u]+=now.v[0][0]-lst.v[0][0];
            }
        }
    }
    return 0;
}

例4、\(\texttt{P5024 [NOIP2018 提高组] 保卫王国}\)

题目描述

给定一棵 \(n\) 个节点的树,点有点权 \(w_i\)

\(m\) 次询问,每次分别钦定 \(a,b\) 必须选/不选,求最小权点覆盖,无解输出 -1

数据范围

  • \(1\le n,m,w_i\le10^5,1\le a\neq b\le n\)

时间限制 \(\texttt{2s}\) ,空间限制 \(\texttt{512MB}\)

分析

强制选点/不选点可以通过给权值加上/减去 \(\inf\) 实现。

熟知结论:最小权点覆盖等于总权值减去最大权独立集。

带修的最大权独立集就是模板题干的事情。

时间复杂度 \(\mathcal O(w^3m\log^2n)\) ,本题 \(w=2\)

#include<bits/stdc++.h>
#define ll long long
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=1e5+5;
const ll inf=1e18;
int a,b,m,n,u,v,cnt;
ll sum;
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int ed[maxn],id[maxn],dfn[maxn],top[maxn];
ll f[maxn][2],g[maxn][2],w[maxn];
vector<int> h[maxn];
struct mat
{
    ll v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
    static mat c;
    for(int i=0;i<=1;i++)
        for(int j=0;j<=1;j++)
        {
            c.v[i][j]=-inf;
            for(int k=0;k<=1;k++) c.v[i][j]=max(c.v[i][j],a.v[i][k]+b.v[k][j]);
        }
    return c;
}
namespace sgmt
{
    struct node
    {
        int l,r;
        mat x;
    }f[4*maxn];
    void pushup(int p)
    {
        f[p].x=f[rs].x*f[ls].x;
    }
    void build(int p,int l,int r)
    {
        f[p].l=l,f[p].r=r;
        if(l==r)
        {
            int u=id[l];
            return f[p].x={g[u][0],g[u][1],g[u][0],-inf},void();
        }
        int mid=(l+r)/2;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(p);
    }
    void modify(int p,int pos,mat x)
    {
        if(f[p].l==f[p].r) return f[p].x=x,void();
        int mid=(f[p].l+f[p].r)/2;
        if(pos<=mid) modify(ls,pos,x);
        else modify(rs,pos,x);
        pushup(p);
    }
    mat query(int p,int l,int r)
    {
        if(l<=f[p].l&&f[p].r<=r) return f[p].x;
        int mid=(f[p].l+f[p].r)/2;
        if(r<=mid) return query(ls,l,r);
        if(l>=mid+1) return query(rs,l,r);
        return query(rs,l,r)*query(ls,l,r);
    }
}
void dfs1(int u,int father)
{
    sz[u]=1,f[u][1]=w[u];
    for(auto v:h[u])
    {
        if(v==father) continue;
        d[v]=d[u]+1,fa[v]=u;
        dfs1(v,u),sz[u]+=sz[v];
        if(sz[v]>=sz[son[u]]) son[u]=v;
        f[u][0]+=max(f[v][0],f[v][1]),f[u][1]+=f[v][0];
    }
}
void dfs2(int u,int topf)
{
    dfn[u]=++cnt,id[cnt]=u,ed[u]=cnt,top[u]=topf,g[u][1]=w[u];
    if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
    for(auto v:h[u])
    {
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
        g[u][0]+=max(f[v][0],f[v][1]),g[u][1]+=f[v][0];
    }
}
void modify(int u,ll tag)
{
    g[u][1]+=tag,sum+=tag;
    while(u)
    {
        mat lst=sgmt::query(1,dfn[top[u]],ed[u]);
        sgmt::modify(1,dfn[u],{g[u][0],g[u][1],g[u][0],-inf});
        mat now=sgmt::query(1,dfn[top[u]],ed[u]);
        u=fa[top[u]];
        g[u][0]-=max(lst.v[0][0],lst.v[0][1]),g[u][1]-=lst.v[0][0];
        g[u][0]+=max(now.v[0][0],now.v[0][1]),g[u][1]+=now.v[0][0];
    }
}
int main()
{
    scanf("%d%d%*s",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&w[i]),sum+=w[i];
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d%d",&u,&v);
        h[u].push_back(v),h[v].push_back(u);
    }
    dfs1(1,0),dfs2(1,1);
    sgmt::build(1,1,n);
    while(m--)
    {
        scanf("%d%d%d%d",&u,&a,&v,&b);
        modify(u,a?-inf:inf),modify(v,b?-inf:inf);
        mat cur=sgmt::query(1,1,ed[1]);
        ll res=sum-max(cur.v[0][0],cur.v[0][1])+(a+b)*inf;
        printf("%lld\n",res<inf?res:-1);
        modify(u,a?inf:-inf),modify(v,b?inf:-inf);
    }
    return 0;
}

例5、\(\texttt{LOJ3539 「JOI Open 2018」猫或狗}\)

题目描述

给定一棵 \(n\) 个节点的树,点有点权 \(w_i\) ,满足 \(w_i\in\{0,1,2\}\) ,初始全为 \(2\)

接下来 \(q\) 次操作:先单点修改点权,再询问为使所有权值为 \(0\) 的点不与权值为 \(1\) 的点连通,至少要删几条边。

数据范围

  • \(1\le n,q\le 10^5\)

时间限制 \(\texttt{3s}\) ,空间限制 \(\texttt{512MB}\)

分析

目标等价于 \(\forall u\)\(u\) 不能既与 \(0\) 连通又不与 \(1\) 连通。

\(f_{u,0/1}\)表示使点 \(u\) 不和子树内的 \(0/1\) 连通,至少要删几条边。

\[f_{u,0}=[w_u\neq0]\sum_{v\in son(u)}\min(f_{v,0},f_{v,1}+1)\\ f_{u,1}=[w_u\neq1]\sum_{v\in son(u)}\min(f_{v,1},f_{v,0}+1) \]

注:对于本题,中括号不满足时该项视为\(\infty\)

然后分离轻重子树的贡献。定义:

\[g_{u,0}=\sum_{v\neq wson_u}\min(f_{v,0},f_{v,1}+1)\\ g_{u,1}=\sum_{v\neq wson_u}\min(f_{v,1},f_{v,0}+1) \]

再用 \(g\) 化简 \(f\)

\[f_{u,0}=[w_u\neq0](g_{u,0}+\min(f_{wson_u,0},f_{wson_u,1}+1))\\ f_{u,1}=[w_u\neq1](g_{u,1}+\min(f_{wson_u,1},f_{wson_u,0}+1)) \]

写成 \((\min,+)\) 转移矩阵:

\[\begin{bmatrix} f_{u,0}&f_{u,1}\\ \end{bmatrix} = \begin{bmatrix} f_{wson_u,0}&f_{wson_u,1}\\ \end{bmatrix} \times \begin{bmatrix} [w_u\neq0]g_{u,0}&[w_u\neq1](g_{u,1}+1)\\ [w_u\neq0](g_{u,0}+1)&[w_u\neq1]g_{u,1}\\ \end{bmatrix} \]

然后动态 \(\texttt{DP}\) 就可以了。

时间复杂度 \(\mathcal O(w^3(n+q)\log^2n)\) ,本题 \(w=2\)

#include<bits/stdc++.h>
#include"catdog.h"
#define ls p<<1
#define rs p<<1|1
using namespace std;
const int maxn=1e5+5,inf=1e9;
int n,cnt;
int d[maxn],fa[maxn],sz[maxn],son[maxn];
int w[maxn],ed[maxn],id[maxn],dfn[maxn],top[maxn];
int f[maxn][2],g[maxn][2];
vector<int> e[maxn];
struct mat
{
    int v[2][2];
};
mat operator*(const mat &a,const mat &b)
{
    static mat c;
    for(int i=0;i<=1;i++)
        for(int j=0;j<=1;j++)
        {
            c.v[i][j]=inf;
            for(int k=0;k<=1;k++) c.v[i][j]=min(c.v[i][j],a.v[i][k]+b.v[k][j]);
        }
    return c;
}
namespace sgmt
{
    struct node
    {
        int l,r;
        mat x;
        void init(int u)
        {
            x={w[u]!=0?g[u][0]:inf,w[u]!=1?g[u][1]+1:inf,w[u]!=0?g[u][0]+1:inf,w[u]!=1?g[u][1]:inf};
        }
    }f[maxn<<2];
    void pushup(int p)
    {
        f[p].x=f[rs].x*f[ls].x;
    }
    void build(int p,int l,int r)
    {
        f[p].l=l,f[p].r=r;
        if(l==r) return f[p].init(id[l]);
        int mid=(l+r)/2;
        build(ls,l,mid);
        build(rs,mid+1,r);
        pushup(p);
    }
    void modify(int p,int pos)
    {
        if(f[p].l==f[p].r) return f[p].init(id[pos]);
        int mid=(f[p].l+f[p].r)/2;
        if(pos<=mid) modify(ls,pos);
        else modify(rs,pos);
        pushup(p);
    }
    mat query(int p,int l,int r)
    {
        if(l<=f[p].l&&f[p].r<=r) return f[p].x;
        int mid=(f[p].l+f[p].r)/2;
        if(r<=mid) return query(ls,l,r);
        if(l>=mid+1) return query(rs,l,r);
        return query(rs,l,r)*query(ls,l,r);
    }
}
void dfs1(int u,int father)
{
    sz[u]=1;
    for(auto v:e[u])
    {
        if(v==father) continue;
        d[v]=d[u]+1,fa[v]=u;
        dfs1(v,u),sz[u]+=sz[v];
        if(sz[v]>=sz[son[u]]) son[u]=v;
        f[u][0]+=min(f[v][0],f[v][1]+1),f[u][1]+=min(f[v][1],f[v][0]+1);
    }
    if(w[u]!=2) f[u][w[u]]=inf;
}
void dfs2(int u,int topf)
{
    dfn[u]=++cnt,id[cnt]=u,ed[u]=cnt,top[u]=topf;
    if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
    for(auto v:e[u])
    {
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
        g[u][0]+=min(f[v][0],f[v][1]+1),g[u][1]+=min(f[v][1],f[v][0]+1);
    }
}
void initialize(int _n,vector<int> a,vector<int> b)
{
    n=_n;
    for(int i=0;i<n-1;i++)
    {
        int u=a[i],v=b[i];
        e[u].push_back(v),e[v].push_back(u);
    }
    for(int i=1;i<=n;i++) w[i]=2;
    d[1]=1,dfs1(1,0),dfs2(1,1);
    sgmt::build(1,1,n);
}
int modify(int u,int x)
{
    w[u]=x;
    while(u)
    {
        mat lst=(mat){0,0,0,0}*sgmt::query(1,dfn[top[u]],ed[u]);
        sgmt::modify(1,dfn[u]);
        mat now=(mat){0,0,0,0}*sgmt::query(1,dfn[top[u]],ed[u]);
        u=fa[top[u]];
        g[u][0]-=min(lst.v[0][0],lst.v[0][1]+1),g[u][1]-=min(lst.v[0][1],lst.v[0][0]+1);
        g[u][0]+=min(now.v[0][0],now.v[0][1]+1),g[u][1]+=min(now.v[0][1],now.v[0][0]+1);
    }
    mat res=(mat){0,0,0,0}*sgmt::query(1,1,ed[1]);
    return min(res.v[0][0],res.v[0][1]);
}
int cat(int u)
{
    return modify(u,0);
}
int dog(int u)
{
    return modify(u,1);
}
int neighbor(int u)
{
    return modify(u,2);
}

例6、\(\texttt{LOJ2269 「SDOI2017」切树游戏}\)

题目描述

给定一棵 \(n\) 个节点的树,点有点权 \(w_i\)

接下来 \(q\) 次操作:

  • Change x y :将 \(w_x\) 改为 \(y\)
  • Query k :询问有多少个非空连通块,权值异或和为 \(k\)

数据范围

  • \(1\le n,q\le 3\cdot 10^4\) ,保证 Change 操作个数 \(\le 10^4\)
  • \(4\le m\le 128\) ,保证 \(m\)\(2\) 的方幂。
  • \(0\le w_i,y\lt m\)

时间限制 \(\texttt{3s}\) ,空间限制 \(\texttt{512MB}\)

分析

\(f_{u,i}\) 表示以 \(u\) 为根的所有连通块中,权值异或和为 \(i\) 的连通块个数。

每次合并一棵 \(u\) 的子树 \(v\) ,转移方程为\(f'_{u,i}=f_{u,i}+\sum\limits_{x\oplus y=i}f_{u,x}\cdot f_{v,y}\)

这里 \(f_u\)表示转移前的 \(\texttt{dp}\) 值, \(f'_u\) 表示转移后的 \(\texttt{dp}\) 值。

为节约篇幅,记 \(\hat f_u=\text{FWT}[f_u]\) ,注意 \(\hat{f_u+1}=\hat f_u+\hat 1\)

两边同时取 \(\text{FWT}\)

\[\hat{f'_u}=\hat f_u\cdot(\hat{f_v+1})\\ \]

\(f_u\) 的初始值 \(p_u=x^{w_u}\) ,最终的转移方程为:

\[\hat f_u=\hat p_u\cdot\prod_{v\in son(u)}(\hat{f_v+1})\\ \]

最后输出 \(\sum_{u=1}^nf_{u,k}\) 即可,于是我们得到了一个 \(\mathcal O(nmq)\) 的算法。


由于询问要对所有的 \(u\) 求和,因此再记一个\(h_{u,i}=\sum\limits_{v\in subtree(u)}f_{v,i}\)

转移方程 \(h_{u,i}=f_{u,i}+\sum\limits_{v\in son(u)}h_{v,i}\)

两边同时取 \(\text{FWT}\)

\[\hat h_u=\hat f_u+\sum_{v\in son(u)}\hat h_v\\ \]

接下来是动态 \(\texttt{DP}\) 的基本操作,分离轻重子树的贡献。

定义:

\[\hat g_u=\prod\limits_{v\neq wson_u}(\hat{f_v+1})\\ \hat l_u=\sum\limits_{v\neq wson_u}\hat h_v\\ \]

转移方程为:

\[\hat f_u=\hat p_u\cdot\hat g_u\cdot(\hat{f_{wson_u}+1})\\ \hat h_u=\hat f_u+\hat l_u+\hat h_{wson_u}\\ \]

然后写成矩阵乘法:

\[\begin{bmatrix} \hat f_u&\hat h_u&\hat 1\\ \end{bmatrix} = \begin{bmatrix} \hat f_{wson_u}&\hat h_{wson_u}&\hat 1\\ \end{bmatrix} \times \begin{bmatrix} \hat p_u\cdot\hat g_u&\hat p_u\cdot\hat g_u&0\\ 0&\hat 1&0\\ \hat p_u\cdot\hat g_u&\hat l_u+\hat p_u\cdot\hat g_u&\hat 1\\ \end{bmatrix}\\ \]


矩阵乘法自带 \(27\) 倍常数,显然无法承受的。

注意到 \(\begin{bmatrix}a&b&0\\0&\hat 1&0\\c&d&\hat 1\\\end{bmatrix}\) 做矩阵乘法时关于这个形式封闭:

\[\begin{bmatrix} a_1&b_1&0\\ 0&\hat 1&0\\ c_1&d_1&\hat 1\\ \end{bmatrix} \times \begin{bmatrix} a_2&b_2&0\\ 0&\hat 1&0\\ c_2&d_2&\hat 1\\ \end{bmatrix} = \begin{bmatrix} a_1\cdot a_2&a_1\cdot b_2+b_1&0\\ 0&\hat 1&0\\ c_1\cdot a_2+c_2&c_1\cdot b_2+d_1+d_2&\hat 1\\ \end{bmatrix}\\ \]

我们只需维护 \(a,b,c,d\) 四个值,常数从 \(27\) 降为 \(4\)

询问时初始矩阵为 \(\begin{bmatrix}0&0&\hat 1\\\end{bmatrix}\) ,但代码实现没这么麻烦,因为 \(\hat f_u,\hat h_u\) 分别对应转移矩阵中的 \(c,d\) 位置。

具体可以看这篇blog中“基于变换合并的算法”一栏。


但本题并没有结束,动态 \(\texttt{DP}\) 中修改需要删除原来的贡献。

其中 \(\hat l_u\) 可以直接减,但 \(\hat g_u\) 需要做除法,而零没有乘法逆元。

\(\forall 1\le u\le n\) ,我们额外用一个数组 z[u] 维护 g[u] 中乘零的次数,这样乘零和除零可以直接在 z[u] 上修改。

时间复杂度 \(\mathcal O(4\cdot 10^4\cdot m\log^2n)\),前面的 \(4\) 为矩阵乘法的常数。

#include<bits/stdc++.h>
#define poly array<int,128>
using namespace std;
const int maxn=3e4+5,mod=1e4+7,inv2=(mod+1)>>1;
int m,n,q,u,v,cnt;
int ed[maxn],fa[maxn],sz[maxn],dfn[maxn],son[maxn],top[maxn];
int w[maxn],inv[mod];
char ch[10];
poly f[maxn],g[maxn],h[maxn],l[maxn],z[maxn],p[128];
vector<int> e[maxn];
int qpow(int a,int k)
{
    int res=1;
    for(;k;a=a*a%mod,k>>=1) if(k&1) res=res*a%mod;
    return res;
}
int add(int x,int y)
{
    if((x+=y)>=mod) x-=mod;
    return x;
}
int dec(int x,int y)
{
    if((x-=y)<0) x+=mod;
    return x;
}
void fwt(poly &f,int n,int op)
{
    for(int k=2,m=1;k<=n;k<<=1,m<<=1)
        for(int i=0;i<n;i+=k)
            for(int j=i;j<i+m;j++)
            {
                int x=f[j],y=f[j+m];
                f[j]=add(x,y),f[j+m]=dec(x,y);
                if(op==-1) f[j]=f[j]*inv2%mod,f[j+m]=f[j+m]*inv2%mod;
            }
}
poly operator+(poly a,poly b)
{
    for(int i=0;i<m;i++) a[i]=add(a[i],b[i]);
    return a;
}
poly operator-(poly a,poly b)
{
    for(int i=0;i<m;i++) a[i]=dec(a[i],b[i]);
    return a;
}
poly operator*(poly a,poly b)
{
    for(int i=0;i<m;i++) a[i]=a[i]*b[i]%mod;
    return a;
}
void operator+=(poly &a,poly b)
{
    a=a+b;
}
void operator-=(poly &a,poly b)
{
    a=a-b;
}
void operator*=(poly &a,poly b)
{
    a=a*b;
}
void dfs1(int u,int _f)
{
    sz[u]=1,f[u]=p[w[u]];
    for(auto v:e[u])
    {
        if(v==_f) continue;
        dfs1(v,u),fa[v]=u,sz[u]+=sz[v];
        if(sz[v]>=sz[son[u]]) son[u]=v;
        f[u]*=f[v]+p[0],h[u]+=h[v];
    }
    h[u]+=f[u];
}
void dfs2(int u,int topf)
{
    dfn[u]=++cnt,top[u]=topf,ed[u]=cnt;
    for(int i=0;i<m;i++) g[u][i]=1;
    if(son[u]) dfs2(son[u],topf),ed[u]=ed[son[u]];
    for(auto v:e[u])
    {
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
        l[u]+=h[v];
        poly tmp=f[v]+p[0];
        for(int i=0;i<m;i++) tmp[i]?g[u][i]=g[u][i]*tmp[i]%mod:z[u][i]++;
    }
}
struct mat
{
    poly a,b,c,d;
    mat operator*(const mat &x)
    {
        return {a*x.a,a*x.b+b,c*x.a+x.c,c*x.b+d+x.d};
    }
};
mat get_mat(int u)
{
    static poly tmp;
    for(int i=0;i<m;i++) tmp[i]=z[u][i]?0:p[w[u]][i]*g[u][i]%mod;
    return {tmp,tmp,tmp,l[u]+tmp};
}
namespace sgmt
{///博主为了偷懒写了zkw线段树,但是直接导致用时翻倍,大家不要学我qwq
    int p;
    mat val[3*maxn];
    void build()
    {
        p=1<<(__lg(n+1)+1);
        for(int i=1;i<=n;i++) val[p+dfn[i]]=get_mat(i);
        for(int i=(p+n)>>1;i;i--) val[i]=val[i<<1|1]*val[i<<1];
    }
    void modify(int u)
    {
        val[p+dfn[u]]=get_mat(u);
        for(int i=(p+dfn[u])>>1;i;i>>=1) val[i]=val[i<<1|1]*val[i<<1];
    }
    mat query(int l,int r)
    {
        static int st[20];
        int top=0;
        mat res={::p[0],f[0],f[0],f[0]};///f[0]指代常多项式0,p[0]指代常多项式1经过FWT后的结果
        for(l=p+l-1,r=p+r+1;l^r^1;l>>=1,r>>=1)
        {
            if(~l&1) st[++top]=l^1;
            if(r&1) res=res*val[r^1];
        }
        for(int i=top;i>=1;i--) res=res*val[st[i]];
        return res;
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%d",&w[i]);
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d%d",&u,&v);
        e[u].push_back(v),e[v].push_back(u);
    }
    for(int i=1;i<mod;i++) inv[i]=qpow(i,mod-2);
    for(int i=0;i<m;i++) p[i][i]=1,fwt(p[i],m,1);
    dfs1(1,0),dfs2(1,1);
    sgmt::build();
    for(scanf("%d",&q);q--;)
    {
        scanf("%s",ch);
        if(ch[0]=='C')
        {
            scanf("%d%d",&u,&v),w[u]=v;
            while(u)
            {
                mat lst=sgmt::query(dfn[top[u]],ed[u]);
                sgmt::modify(u);
                mat now=sgmt::query(dfn[top[u]],ed[u]);
                u=fa[top[u]];
                l[u]=l[u]-lst.d+now.d;
                for(int i=0;i<m;i++)
                {
                    int x=add(lst.c[i],p[0][i]),y=add(now.c[i],p[0][i]);
                    x?g[u][i]=g[u][i]*inv[x]%mod:z[u][i]--;
                    y?g[u][i]=g[u][i]*y%mod:z[u][i]++;
                }
            }
        }
        else
        {
            scanf("%d",&u);
            poly tmp=sgmt::query(dfn[1],ed[1]).d;
            fwt(tmp,m,-1),printf("%d\n",tmp[u]);
        }
    }
    return 0;
}
posted @ 2023-01-28 17:01  peiwenjun  阅读(2)  评论(0编辑  收藏  举报