【DP】动态 DP

准备退役 whk 了,最后学点东西。

不得不承认,CSPS2022 T4 对动态 DP 起到了良好的普及效果。

P4719 【模板】"动态 DP"&动态树分治

\(f_{u,0/1}\) 表示不选/选 \(u\)\(u\) 的子树内的最大权独立集。

不带修改的情况,有

\[f_{u,0}=\sum\max(f_{v,0},f_{v,1}) \]

\[f_{u,1}=val_u+\sum f_{v,0} \]

答案即为 \(\max(f_{1,0},f_{1,1})\)

如果带修改呢?修改的结点影响的仅是其到根节点的路径。可以用树剖(重链剖分),为啥?

树剖有很好的性质:重儿子的 \(dfs\) 序连续,任意节点到根节点最多经过 \(\log n\) 条轻边。

这启示我们可以把轻儿子的信息合并起来,对于重儿子在跳重链的时候合并。

\(g_{u,0/1}\) 表示只考虑轻儿子,且 \(u\) 不选/选的最大权独立集。则有

\[f_{u,0}=g_{u,0}+\max(f_{son_u,0},f_{son_u,1}) \]

\[f_{u,1}=g_{u,1}+f_{son_u,0} \]

\(son_u\) 表示 \(u\) 的重儿子。这样我们还把 \(\sum\) 去掉了。

现在的问题就是如何快速合并和修改了,注意到可以使用广义矩阵乘法实现,有

\[f_{u,0}=\max(g_{u,0}+f_{son_u,0},g_{u,0}+f_{son_u,1}) \]

\[f_{u,1}=\max(g_{u,1}+f_{son_u,0},-\infty) \]

\[\begin{bmatrix}g_{i,0}&g_{i,0}\\g_{i,1}&-\infty\end{bmatrix}\begin{bmatrix}f_{son_u,0}\\f_{son_u,1}\end{bmatrix}=\begin{bmatrix}f_{u,0}\\f_{u,1}\end{bmatrix} \]

注意这里矩阵一定得这么写,因为链头在区间左端,链尾在区间右端,维护的初始信息在叶子结点,所以只能把它放在右边。

怎么合并?废话,重链直接合并,向上跳轻边的时候合并到上一级重链,修改的时候用增量的方式改变父亲代表的矩阵。

点击查看代码
#define pb push_back
const int N=1e5+10,inf=0x3f3f3f3f;

int n,m;
vector<int> e[N];
int siz[N],fa[N],son[N],in[N],out[N],top[N],rnk[N],tim=0;
int f[N][2],a[N];

struct matrix{
    int m[2][2];
    inline matrix operator *(const matrix& b)const{
        matrix res;
        res.m[0][0]=res.m[0][1]=res.m[1][0]=res.m[1][1]=-inf;
        for(int k=0;k<2;++k){
            for(int i=0;i<2;++i){
                for(int j=0;j<2;++j){
                    res.m[i][j]=max(res.m[i][j],m[i][k]+b.m[k][j]);
                }
            }
        }return res;
    }
}val[N],mat[N<<2],bef,aft;

void dfs1(int u,int f){
    fa[u]=f,siz[u]=1;
    for(int v:e[u])if(v!=f){
        dfs1(v,u);siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
void dfs2(int u,int t){//预处理出初始矩阵
    top[u]=t,rnk[in[u]=++tim]=u,out[t]=max(out[t],tim);
    //out 表示链顶所在链的链底
    f[u][0]=0,f[u][1]=a[u];
    val[u].m[0][0]=val[u].m[0][1]=0;
    val[u].m[1][0]=a[u];val[u].m[1][1]=-inf;
    if(son[u]){
        dfs2(son[u],t);//重儿子不用算到 g 里面
        f[u][0]+=max(f[son[u]][0],f[son[u]][1]);
        f[u][1]+=f[son[u]][0];
    }
    for(int v:e[u])if(v!=fa[u]&&v!=son[u]){
        dfs2(v,v);
        f[u][0]+=max(f[v][0],f[v][1]);
        f[u][1]+=f[v][0];
        val[u].m[0][0]+=max(f[v][0],f[v][1]);
        val[u].m[0][1]=val[u].m[0][0];
        val[u].m[1][0]+=f[v][0];
    }
}

#define ls p<<1
#define rs p<<1|1
void build(int p,int l,int r){
    if(l==r)return mat[p]=val[rnk[l]],void();
    int mid=(l+r)>>1;
    build(ls,l,mid),build(rs,mid+1,r);
    mat[p]=mat[ls]*mat[rs];
}
void modify(int p,int l,int r,int pos){
    if(l==r)return mat[p]=val[rnk[pos]],void();
    int mid=(l+r)>>1;
    if(pos<=mid)modify(ls,l,mid,pos);
    else modify(rs,mid+1,r,pos);
    mat[p]=mat[ls]*mat[rs];
}
matrix query(int p,int l,int r,int ql,int qr){
    if(ql<=l&&r<=qr)return mat[p];
    int mid=(l+r)>>1;
    if(qr<=mid)return query(ls,l,mid,ql,qr);
    if(ql>mid)return query(rs,mid+1,r,ql,qr);
    return query(ls,l,mid,ql,mid)*query(rs,mid+1,r,mid+1,qr);
}

void change(int u,int k){
    val[u].m[1][0]+=k-a[u];a[u]=k;
    while(u){
        bef=query(1,1,n,in[top[u]],out[top[u]]);
        modify(1,1,n,in[u]);
        aft=query(1,1,n,in[top[u]],out[top[u]]);
        u=fa[top[u]];
	//向上跳轻边用增量法修改,注意矩阵中的元素代表的意义以及转移式
        val[u].m[0][0]+=max(aft.m[0][0],aft.m[1][0])-max(bef.m[0][0],bef.m[1][0]);
        val[u].m[0][1]=val[u].m[0][0];
        val[u].m[1][0]+=aft.m[0][0]-bef.m[0][0];
    }
}

int main(){
    read(n),read(m);
    for(int i=1;i<=n;++i)read(a[i]);
    for(int i=1,u,v;i<n;++i){
        read(u),read(v);
        e[u].pb(v),e[v].pb(u);
    }dfs1(1,0),dfs2(1,1);build(1,1,n);
    for(int i=1,x,y;i<=m;++i){
        read(x),read(y);
        change(x,y);
        matrix ans=query(1,1,n,in[1],out[1]);
        printf("%d\n",max(ans.m[0][0],ans.m[1][0]));
    }
    return 0;
}

P5024 [NOIP2018 提高组] 保卫王国

最小权覆盖 = 全集 - 最大权独立集

对于强制选择,则给他的权值加上 \(-\infty\);若强制不选,则加上 \(+\infty\)。再给权值加上对应的值就行了。之后的内容就和上面一样了。记得询问完要改回来。

点击查看代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define getchar()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
template <typename T>
inline void read(T& r) {
	r=0;bool w=0; char ch=getchar();
	while(ch<'0'||ch>'9') w=ch=='-'?1:0,ch=getchar();
	while(ch>='0'&&ch<='9') r=(r<<3)+(r<<1)+(ch^48), ch=getchar();
	r=w?-r:r;
}
#define pb push_back
#define int long long
const int N=1e5+10,inf=1e10;

int n,m;
vector<int> e[N];
int siz[N],fa[N],son[N],in[N],out[N],top[N],rnk[N],tim=0;
int f[N][2],a[N],sum=0;

struct matrix{
    int m[2][2];
    inline matrix operator *(const matrix& b)const{
        matrix res;
        res.m[0][0]=res.m[0][1]=res.m[1][0]=res.m[1][1]=-inf;
        for(int k=0;k<2;++k){
            for(int i=0;i<2;++i){
                for(int j=0;j<2;++j){
                    res.m[i][j]=max(res.m[i][j],m[i][k]+b.m[k][j]);
                }
            }
        }return res;
    }
}val[N],mat[N<<2],bef,aft;

void dfs1(int u,int f){
    fa[u]=f,siz[u]=1;
    for(int v:e[u])if(v!=f){
        dfs1(v,u);siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
void dfs2(int u,int t){
    top[u]=t,rnk[in[u]=++tim]=u,out[t]=max(out[t],tim);
    f[u][0]=0,f[u][1]=a[u];
    val[u].m[0][0]=val[u].m[0][1]=0;
    val[u].m[1][0]=a[u];val[u].m[1][1]=-inf;
    if(son[u]){
        dfs2(son[u],t);
        f[u][0]+=max(f[son[u]][0],f[son[u]][1]);
        f[u][1]+=f[son[u]][0];
    }
    
    for(int v:e[u])if(v!=fa[u]&&v!=son[u]){
        dfs2(v,v);
        f[u][0]+=max(f[v][0],f[v][1]);
        f[u][1]+=f[v][0];
        val[u].m[0][0]+=max(f[v][0],f[v][1]);
        val[u].m[0][1]=val[u].m[0][0];
        val[u].m[1][0]+=f[v][0];
    }
}

#define ls p<<1
#define rs p<<1|1
void build(int p,int l,int r){
    if(l==r)return mat[p]=val[rnk[l]],void();
    int mid=(l+r)>>1;
    build(ls,l,mid),build(rs,mid+1,r);
    mat[p]=mat[ls]*mat[rs];
}
void modify(int p,int l,int r,int pos){
    if(l==r)return mat[p]=val[rnk[pos]],void();
    int mid=(l+r)>>1;
    if(pos<=mid)modify(ls,l,mid,pos);
    else modify(rs,mid+1,r,pos);
    mat[p]=mat[ls]*mat[rs];
}
matrix query(int p,int l,int r,int ql,int qr){
    if(ql<=l&&r<=qr)return mat[p];
    int mid=(l+r)>>1;
    if(qr<=mid)return query(ls,l,mid,ql,qr);
    if(ql>mid)return query(rs,mid+1,r,ql,qr);
    return query(ls,l,mid,ql,mid)*query(rs,mid+1,r,mid+1,qr);
}

void change(int u,int k){
    val[u].m[1][0]+=k;a[u]+=k;
    while(u){
        bef=query(1,1,n,in[top[u]],out[top[u]]);
        modify(1,1,n,in[u]);
        aft=query(1,1,n,in[top[u]],out[top[u]]);
        u=fa[top[u]];
        val[u].m[0][0]+=max(aft.m[0][0],aft.m[1][0])-max(bef.m[0][0],bef.m[1][0]);
        val[u].m[0][1]=val[u].m[0][0];
        val[u].m[1][0]+=aft.m[0][0]-bef.m[0][0];
    }
}

char skip[5];

signed main(){
    scanf("%lld%lld%s",&n,&m,skip);
    for(int i=1;i<=n;++i)read(a[i]),sum+=a[i];
    for(int i=1,u,v;i<n;++i){
        read(u),read(v);
        e[u].pb(v),e[v].pb(u);
    }dfs1(1,0),dfs2(1,1);build(1,1,n);
    for(int i=1,a,x,b,y;i<=m;++i){
        read(a),read(x),read(b),read(y);
        if((fa[a]==b||fa[b]==a)&&!x&&!y){printf("-1\n");continue;}
        change(a,!x?inf:-inf),change(b,!y?inf:-inf);
        matrix ans=query(1,1,n,in[1],out[1]);
        printf("%lld\n",sum-max(ans.m[0][0],ans.m[1][0])+(!x?inf:0)+(!y?inf:0));
        change(a,!x?-inf:inf),change(b,!y?-inf:inf);
    }
    return 0;
}

P7359 「JZOI-1」旅行

由于没有修改操作,故直接倍增查询即可,和 P8820 [CSP-S 2022] 数据传输 类似。

需要维护向上和向下的倍增矩阵,因此乘的顺序有区别,查询的时候就是按顺序 \(s->lca 向上,lca->t 向下\)。(这样转移矩阵写成右乘比较好写?)

点击查看代码
#define pb push_back
const int N=2e5+10;
#define int long long

int n,L,q;
struct node{int v,a,z;};
vector<node>e[N];
int dep[N],anc[N][21];

struct matrix{
    int m[2][2];
    void init(){memset(m,0x3f,sizeof m);}
    matrix operator*(const matrix& a)const{
        matrix res;res.init();
        for(int k=0;k<2;++k)for(int i=0;i<2;++i)for(int j=0;j<2;++j){
            res.m[i][j]=min(res.m[i][j],m[i][k]+a.m[k][j]);
        }return res;
    }
}up[N][21],down[N][21],tmp[N],ans;int tot=0;

void dfs(int u,int fa){
    for(node t:e[u]){
        int v=t.v,a=t.a,z=t.z;
        if(v==fa)continue;
        anc[v][0]=u;dep[v]=dep[u]+1;
        for(int i=1;(1<<i)<=dep[v];++i)anc[v][i]=anc[anc[v][i-1]][i-1];
        matrix &qwq=up[v][0],&qwq2=down[v][0];
        qwq.m[0][0]=qwq.m[1][0]=a,
        qwq.m[0][1]=a-z+L,qwq.m[1][1]=a-z;
        qwq2.m[0][0]=qwq2.m[1][0]=a,
        qwq2.m[0][1]=a+z+L,qwq2.m[1][1]=a+z;
        for(int i=1;(1<<i)<=dep[v];++i){
            up[v][i]=up[v][i-1]*up[anc[v][i-1]][i-1];
            down[v][i]=down[anc[v][i-1]][i-1]*down[v][i-1];
        }
        dfs(v,u);
    }
}

int query(int u,int v){
	ans.init();ans.m[0][0]=0;
	if(dep[u]>dep[v]){
		for(int i=20;i>=0;i--)
			if(dep[u]-(1<<i)>=dep[v])ans=ans*up[u][i],u=anc[u][i];
	}else if(dep[v]>dep[u]){
		for (int i=20;i>=0;i--)
			if(dep[v]-(1<<i)>=dep[u])tmp[++tot]=down[v][i],v=anc[v][i];
	}
	if(u==v){
		while(tot)ans=ans*tmp[tot--];
		return min(ans.m[0][0],ans.m[0][1]);
	}
	for(int i=20;i>=0;i--)
		if(anc[u][i]!=anc[v][i]){
			ans=ans*up[u][i],tmp[++tot]=down[v][i];
			u=anc[u][i],v=anc[v][i];
		}
	ans=ans*up[u][0],tmp[++tot]=down[v][0];
	while(tot)ans=ans*tmp[tot--];
	return min(ans.m[0][0],ans.m[0][1]);
}

signed main(){
    read(n),read(L),read(q);
    for(int i=1,u,v,a,z,t;i<n;++i){
        read(u),read(v),read(a),read(z),read(t);
        e[u].pb((node){v,a,t?-z:z}),e[v].pb((node){u,a,t?z:-z});
    }dfs(1,0);
    for(int i=1,s,t;i<=q;++i){
        read(s),read(t);
        printf("%lld\n",query(s,t));
    }
    return 0;
}
posted @ 2022-11-04 11:07  RuntimeErr  阅读(56)  评论(2编辑  收藏  举报