【DP】动态 DP
准备退役 whk 了,最后学点东西。
不得不承认,CSPS2022 T4 对动态 DP 起到了良好的普及效果。
P4719 【模板】"动态 DP"&动态树分治
设 \(f_{u,0/1}\) 表示不选/选 \(u\),\(u\) 的子树内的最大权独立集。
不带修改的情况,有
\[f_{u,0}=\sum\max(f_{v,0},f_{v,1})
\]
\[f_{u,1}=val_u+\sum f_{v,0}
\]
答案即为 \(\max(f_{1,0},f_{1,1})\)。
如果带修改呢?修改的结点影响的仅是其到根节点的路径。可以用树剖(重链剖分),为啥?
树剖有很好的性质:重儿子的 \(dfs\) 序连续,任意节点到根节点最多经过 \(\log n\) 条轻边。
这启示我们可以把轻儿子的信息合并起来,对于重儿子在跳重链的时候合并。
设 \(g_{u,0/1}\) 表示只考虑轻儿子,且 \(u\) 不选/选的最大权独立集。则有
\[f_{u,0}=g_{u,0}+\max(f_{son_u,0},f_{son_u,1})
\]
\[f_{u,1}=g_{u,1}+f_{son_u,0}
\]
\(son_u\) 表示 \(u\) 的重儿子。这样我们还把 \(\sum\) 去掉了。
现在的问题就是如何快速合并和修改了,注意到可以使用广义矩阵乘法实现,有
\[f_{u,0}=\max(g_{u,0}+f_{son_u,0},g_{u,0}+f_{son_u,1})
\]
\[f_{u,1}=\max(g_{u,1}+f_{son_u,0},-\infty)
\]
即
\[\begin{bmatrix}g_{i,0}&g_{i,0}\\g_{i,1}&-\infty\end{bmatrix}\begin{bmatrix}f_{son_u,0}\\f_{son_u,1}\end{bmatrix}=\begin{bmatrix}f_{u,0}\\f_{u,1}\end{bmatrix}
\]
注意这里矩阵一定得这么写,因为链头在区间左端,链尾在区间右端,维护的初始信息在叶子结点,所以只能把它放在右边。
怎么合并?废话,重链直接合并,向上跳轻边的时候合并到上一级重链,修改的时候用增量的方式改变父亲代表的矩阵。
点击查看代码
#define pb push_back
const int N=1e5+10,inf=0x3f3f3f3f;
int n,m;
vector<int> e[N];
int siz[N],fa[N],son[N],in[N],out[N],top[N],rnk[N],tim=0;
int f[N][2],a[N];
struct matrix{
int m[2][2];
inline matrix operator *(const matrix& b)const{
matrix res;
res.m[0][0]=res.m[0][1]=res.m[1][0]=res.m[1][1]=-inf;
for(int k=0;k<2;++k){
for(int i=0;i<2;++i){
for(int j=0;j<2;++j){
res.m[i][j]=max(res.m[i][j],m[i][k]+b.m[k][j]);
}
}
}return res;
}
}val[N],mat[N<<2],bef,aft;
void dfs1(int u,int f){
fa[u]=f,siz[u]=1;
for(int v:e[u])if(v!=f){
dfs1(v,u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t){//预处理出初始矩阵
top[u]=t,rnk[in[u]=++tim]=u,out[t]=max(out[t],tim);
//out 表示链顶所在链的链底
f[u][0]=0,f[u][1]=a[u];
val[u].m[0][0]=val[u].m[0][1]=0;
val[u].m[1][0]=a[u];val[u].m[1][1]=-inf;
if(son[u]){
dfs2(son[u],t);//重儿子不用算到 g 里面
f[u][0]+=max(f[son[u]][0],f[son[u]][1]);
f[u][1]+=f[son[u]][0];
}
for(int v:e[u])if(v!=fa[u]&&v!=son[u]){
dfs2(v,v);
f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
val[u].m[0][0]+=max(f[v][0],f[v][1]);
val[u].m[0][1]=val[u].m[0][0];
val[u].m[1][0]+=f[v][0];
}
}
#define ls p<<1
#define rs p<<1|1
void build(int p,int l,int r){
if(l==r)return mat[p]=val[rnk[l]],void();
int mid=(l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
mat[p]=mat[ls]*mat[rs];
}
void modify(int p,int l,int r,int pos){
if(l==r)return mat[p]=val[rnk[pos]],void();
int mid=(l+r)>>1;
if(pos<=mid)modify(ls,l,mid,pos);
else modify(rs,mid+1,r,pos);
mat[p]=mat[ls]*mat[rs];
}
matrix query(int p,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr)return mat[p];
int mid=(l+r)>>1;
if(qr<=mid)return query(ls,l,mid,ql,qr);
if(ql>mid)return query(rs,mid+1,r,ql,qr);
return query(ls,l,mid,ql,mid)*query(rs,mid+1,r,mid+1,qr);
}
void change(int u,int k){
val[u].m[1][0]+=k-a[u];a[u]=k;
while(u){
bef=query(1,1,n,in[top[u]],out[top[u]]);
modify(1,1,n,in[u]);
aft=query(1,1,n,in[top[u]],out[top[u]]);
u=fa[top[u]];
//向上跳轻边用增量法修改,注意矩阵中的元素代表的意义以及转移式
val[u].m[0][0]+=max(aft.m[0][0],aft.m[1][0])-max(bef.m[0][0],bef.m[1][0]);
val[u].m[0][1]=val[u].m[0][0];
val[u].m[1][0]+=aft.m[0][0]-bef.m[0][0];
}
}
int main(){
read(n),read(m);
for(int i=1;i<=n;++i)read(a[i]);
for(int i=1,u,v;i<n;++i){
read(u),read(v);
e[u].pb(v),e[v].pb(u);
}dfs1(1,0),dfs2(1,1);build(1,1,n);
for(int i=1,x,y;i<=m;++i){
read(x),read(y);
change(x,y);
matrix ans=query(1,1,n,in[1],out[1]);
printf("%d\n",max(ans.m[0][0],ans.m[1][0]));
}
return 0;
}
P5024 [NOIP2018 提高组] 保卫王国
最小权覆盖 = 全集 - 最大权独立集
对于强制选择,则给他的权值加上 \(-\infty\);若强制不选,则加上 \(+\infty\)。再给权值加上对应的值就行了。之后的内容就和上面一样了。记得询问完要改回来。
点击查看代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define getchar()(p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
template <typename T>
inline void read(T& r) {
r=0;bool w=0; char ch=getchar();
while(ch<'0'||ch>'9') w=ch=='-'?1:0,ch=getchar();
while(ch>='0'&&ch<='9') r=(r<<3)+(r<<1)+(ch^48), ch=getchar();
r=w?-r:r;
}
#define pb push_back
#define int long long
const int N=1e5+10,inf=1e10;
int n,m;
vector<int> e[N];
int siz[N],fa[N],son[N],in[N],out[N],top[N],rnk[N],tim=0;
int f[N][2],a[N],sum=0;
struct matrix{
int m[2][2];
inline matrix operator *(const matrix& b)const{
matrix res;
res.m[0][0]=res.m[0][1]=res.m[1][0]=res.m[1][1]=-inf;
for(int k=0;k<2;++k){
for(int i=0;i<2;++i){
for(int j=0;j<2;++j){
res.m[i][j]=max(res.m[i][j],m[i][k]+b.m[k][j]);
}
}
}return res;
}
}val[N],mat[N<<2],bef,aft;
void dfs1(int u,int f){
fa[u]=f,siz[u]=1;
for(int v:e[u])if(v!=f){
dfs1(v,u);siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int t){
top[u]=t,rnk[in[u]=++tim]=u,out[t]=max(out[t],tim);
f[u][0]=0,f[u][1]=a[u];
val[u].m[0][0]=val[u].m[0][1]=0;
val[u].m[1][0]=a[u];val[u].m[1][1]=-inf;
if(son[u]){
dfs2(son[u],t);
f[u][0]+=max(f[son[u]][0],f[son[u]][1]);
f[u][1]+=f[son[u]][0];
}
for(int v:e[u])if(v!=fa[u]&&v!=son[u]){
dfs2(v,v);
f[u][0]+=max(f[v][0],f[v][1]);
f[u][1]+=f[v][0];
val[u].m[0][0]+=max(f[v][0],f[v][1]);
val[u].m[0][1]=val[u].m[0][0];
val[u].m[1][0]+=f[v][0];
}
}
#define ls p<<1
#define rs p<<1|1
void build(int p,int l,int r){
if(l==r)return mat[p]=val[rnk[l]],void();
int mid=(l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
mat[p]=mat[ls]*mat[rs];
}
void modify(int p,int l,int r,int pos){
if(l==r)return mat[p]=val[rnk[pos]],void();
int mid=(l+r)>>1;
if(pos<=mid)modify(ls,l,mid,pos);
else modify(rs,mid+1,r,pos);
mat[p]=mat[ls]*mat[rs];
}
matrix query(int p,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr)return mat[p];
int mid=(l+r)>>1;
if(qr<=mid)return query(ls,l,mid,ql,qr);
if(ql>mid)return query(rs,mid+1,r,ql,qr);
return query(ls,l,mid,ql,mid)*query(rs,mid+1,r,mid+1,qr);
}
void change(int u,int k){
val[u].m[1][0]+=k;a[u]+=k;
while(u){
bef=query(1,1,n,in[top[u]],out[top[u]]);
modify(1,1,n,in[u]);
aft=query(1,1,n,in[top[u]],out[top[u]]);
u=fa[top[u]];
val[u].m[0][0]+=max(aft.m[0][0],aft.m[1][0])-max(bef.m[0][0],bef.m[1][0]);
val[u].m[0][1]=val[u].m[0][0];
val[u].m[1][0]+=aft.m[0][0]-bef.m[0][0];
}
}
char skip[5];
signed main(){
scanf("%lld%lld%s",&n,&m,skip);
for(int i=1;i<=n;++i)read(a[i]),sum+=a[i];
for(int i=1,u,v;i<n;++i){
read(u),read(v);
e[u].pb(v),e[v].pb(u);
}dfs1(1,0),dfs2(1,1);build(1,1,n);
for(int i=1,a,x,b,y;i<=m;++i){
read(a),read(x),read(b),read(y);
if((fa[a]==b||fa[b]==a)&&!x&&!y){printf("-1\n");continue;}
change(a,!x?inf:-inf),change(b,!y?inf:-inf);
matrix ans=query(1,1,n,in[1],out[1]);
printf("%lld\n",sum-max(ans.m[0][0],ans.m[1][0])+(!x?inf:0)+(!y?inf:0));
change(a,!x?-inf:inf),change(b,!y?-inf:inf);
}
return 0;
}
P7359 「JZOI-1」旅行
由于没有修改操作,故直接倍增查询即可,和 P8820 [CSP-S 2022] 数据传输 类似。
需要维护向上和向下的倍增矩阵,因此乘的顺序有区别,查询的时候就是按顺序 \(s->lca 向上,lca->t 向下\)。(这样转移矩阵写成右乘比较好写?)
点击查看代码
#define pb push_back
const int N=2e5+10;
#define int long long
int n,L,q;
struct node{int v,a,z;};
vector<node>e[N];
int dep[N],anc[N][21];
struct matrix{
int m[2][2];
void init(){memset(m,0x3f,sizeof m);}
matrix operator*(const matrix& a)const{
matrix res;res.init();
for(int k=0;k<2;++k)for(int i=0;i<2;++i)for(int j=0;j<2;++j){
res.m[i][j]=min(res.m[i][j],m[i][k]+a.m[k][j]);
}return res;
}
}up[N][21],down[N][21],tmp[N],ans;int tot=0;
void dfs(int u,int fa){
for(node t:e[u]){
int v=t.v,a=t.a,z=t.z;
if(v==fa)continue;
anc[v][0]=u;dep[v]=dep[u]+1;
for(int i=1;(1<<i)<=dep[v];++i)anc[v][i]=anc[anc[v][i-1]][i-1];
matrix &qwq=up[v][0],&qwq2=down[v][0];
qwq.m[0][0]=qwq.m[1][0]=a,
qwq.m[0][1]=a-z+L,qwq.m[1][1]=a-z;
qwq2.m[0][0]=qwq2.m[1][0]=a,
qwq2.m[0][1]=a+z+L,qwq2.m[1][1]=a+z;
for(int i=1;(1<<i)<=dep[v];++i){
up[v][i]=up[v][i-1]*up[anc[v][i-1]][i-1];
down[v][i]=down[anc[v][i-1]][i-1]*down[v][i-1];
}
dfs(v,u);
}
}
int query(int u,int v){
ans.init();ans.m[0][0]=0;
if(dep[u]>dep[v]){
for(int i=20;i>=0;i--)
if(dep[u]-(1<<i)>=dep[v])ans=ans*up[u][i],u=anc[u][i];
}else if(dep[v]>dep[u]){
for (int i=20;i>=0;i--)
if(dep[v]-(1<<i)>=dep[u])tmp[++tot]=down[v][i],v=anc[v][i];
}
if(u==v){
while(tot)ans=ans*tmp[tot--];
return min(ans.m[0][0],ans.m[0][1]);
}
for(int i=20;i>=0;i--)
if(anc[u][i]!=anc[v][i]){
ans=ans*up[u][i],tmp[++tot]=down[v][i];
u=anc[u][i],v=anc[v][i];
}
ans=ans*up[u][0],tmp[++tot]=down[v][0];
while(tot)ans=ans*tmp[tot--];
return min(ans.m[0][0],ans.m[0][1]);
}
signed main(){
read(n),read(L),read(q);
for(int i=1,u,v,a,z,t;i<n;++i){
read(u),read(v),read(a),read(z),read(t);
e[u].pb((node){v,a,t?-z:z}),e[v].pb((node){u,a,t?z:-z});
}dfs(1,0);
for(int i=1,s,t;i<=q;++i){
read(s),read(t);
printf("%lld\n",query(s,t));
}
return 0;
}