[模板] 动态DP
一、题目
二、解法
动态 \(dp\) 的思路主要是用矩阵乘法加速 \(dp\),所以首先要知道矩阵乘法的扩展版:
\[c(i,k)=\max\{a(i,j)+b(j,k)\}
\]
令人震惊的是上面这东西也满足结合律,现在我们来证明一下,假设有三个矩阵 \(a,b,c\) 相乘,大小分别是 \(n\times m,m\times p,p\times q\),我们把最终某一个位置上的值暴力展开:
\[(i,j)=\max_{k=1}^ma(i,k)+\Big(\max_{l=1}^p b(k,l)+c(l,j)\Big)
\]
\[=\max_{k=1}^m\max_{l=1}^p a(i,k)+b(k,l)+c(l,j)
\]
\[=\max_{l=1}^p\Big(\max_{k=1}^m a(i,k)+b(k,l)\Big)+c(l,j)
\]
所以先乘 \(a,b\) 还是先乘 \(b,c\) 对答案没有影响,结合律得证。
首先写出暴力的 \(dp\) 柿子,设 \(f(u,0/1)\) 表示 \(u\) 这个点不选\(/\)选的最大权值,转移:
\[f(u,0)=\sum\max(f(v,0),f(v,1))
\]
\[f(u,1)=a(u)+\sum f(v,0)
\]
先来考虑一下链怎么做,我们构造一个像这样的转移矩阵:
\[\left(\begin{matrix}0&0\\a(u)&-\infty\end{matrix}\right)\times\left(\begin{matrix}f(v,0)\\f(v,1)\end{matrix}\right)=\left(\begin{matrix}f(u,0)\\f(u,1)\end{matrix}\right)
\]
然后要求根的 \(dp\) 值就直接把所有矩阵乘起来就行了,时间复杂度 \(O(n\log n)\)
那么我们能不能把上面的做法搬到树上呢?考虑把树剖分成链然后套上面的做法,也就是用树链剖分。每个点的转移矩阵就针对他的重儿子来定义,但同时我们要考虑轻儿子对他 \(dp\) 值的贡献,所以再定义 \(f'(u,0/1)\) 表示 \(u\) 不选\(/\)选,考虑 \(u\) 和 \(u\) 的所有轻儿子的最大值,那么有如下转移:
\[f(u,0)=f'(u,0)+\max\{f(son,0),f(son,1)\}
\]
\[f(u,1)=f'(u,1)+f(son,0)
\]
写成矩阵就是这个样子的:
\[\left(\begin{matrix}f'(u,0)&f’(u,0)\\f'(u,1)&-\infty\end{matrix}\right)\times\left(\begin{matrix}f(son,0)\\f(son,1)\end{matrix}\right)=\left(\begin{matrix}f(u,0)\\f(u,1)\end{matrix}\right)
\]
先考虑怎么统计答案,我们找到根所在的那条重链,把所有转移矩阵乘起来就行了。
再考虑如何修改,修改一个点的点权只会对它的祖先产生影响。而且由于路径上只有 \(O(\log n)\) 条轻边,所以一共只需要改 \(O(\log n)\) 个矩阵,这部分可以看看代码:
void modify(int u,int w)//把u点权改成w
{
val[u].a[1][0]+=w-a[u];
a[u]=w;
while(u)
{
matrix t1=ask(1,1,n,num[top[u]],num[bot[u]]);//算出f(u,0/1)
upd(1,1,n,num[u]);//在线段树上更新那个位置的矩阵
matrix t2=ask(1,1,n,num[top[u]],num[bot[u]]);//算出新的f(u,0/1)
u=fa[top[u]];//要更新重链顶端父亲的转移矩阵
val[u].a[0][0]+=max(t2.a[0][0],t2.a[1][0])-max(t1.a[0][0],t1.a[1][0]);
val[u].a[0][1]=val[u].a[0][0];
val[u].a[1][0]+=t2.a[0][0]-t1.a[0][0];
}
}
用一个线段树维护矩阵套上树链剖分:\(O(2^3\cdot n\log^2 n)\)
#include <cstdio>
#include <iostream>
using namespace std;
const int M = 100005;
const int inf = 1e9;
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,tot,cnt,f[M],a[M],id[M],fa[M];
int siz[M],son[M],num[M],top[M],bot[M],dp[M][2];
//top表示重链头
//bot表示重链尾
//num表示这个点在线段树上的位置
//id表示线段树上位置所对应的点
//dp表示初始的dp数组
struct edge
{
int v,next;
edge(int V=0,int N=0) : v(V) , next(N) {}
}e[2*M];
struct matrix
{
int a[2][2];
matrix() {a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
matrix operator * (const matrix &b) const
{
matrix r;
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
for(int k=0;k<2;k++)
r.a[i][k]=max(r.a[i][k],a[i][j]+b.a[j][k]);
return r;
}
void print()
{
puts("---------");
for(int i=0;i<2;i++,puts(""))
for(int j=0;j<2;j++)
printf("%d ",a[i][j]);
}
}val[M],tr[4*M];
//线段树部分
void up(int i)
{
tr[i]=tr[i<<1]*tr[i<<1|1];
}
void build(int i,int l,int r)
{
if(l==r)
{
tr[i]=val[id[l]];
return ;
}
int mid=(l+r)>>1;
build(i<<1,l,mid);
build(i<<1|1,mid+1,r);
up(i);
}
void upd(int i,int l,int r,int x)//修改x这个位置的矩阵
{
if(l==r)
{
tr[i]=val[id[x]];
return ;
}
int mid=(l+r)>>1;
if(mid>=x) upd(i<<1,l,mid,x);
else upd(i<<1|1,mid+1,r,x);
up(i);
}
matrix ask(int i,int l,int r,int L,int R)
{
if(L<=l && r<=R) return tr[i];
int mid=(l+r)>>1;
if(R<=mid) return ask(i<<1,l,mid,L,R);
if(L>mid) return ask(i<<1|1,mid+1,r,L,R);
return ask(i<<1,l,mid,L,R)*ask(i<<1|1,mid+1,r,L,R);
}
//树链剖分部分
void dfs1(int u,int p)
{
siz[u]=1;fa[u]=p;
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==p) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
num[u]=++cnt;
id[cnt]=u;
val[u].a[0][0]=val[u].a[0][1]=0;
val[u].a[1][0]=dp[u][1]=a[u];
if(son[u])
{
dfs2(son[u],tp),bot[u]=bot[son[u]];
dp[u][0]+=max(dp[son[u]][0],dp[son[u]][1]);
dp[u][1]+=dp[son[u]][0];
}
else bot[u]=u;//如果没有重儿子底部就是自己
for(int i=f[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa[u] || v==son[u]) continue;
dfs2(v,v);
dp[u][0]+=max(dp[v][0],dp[v][1]);
dp[u][1]+=dp[v][0];
val[u].a[0][0]+=max(dp[v][0],dp[v][1]);
val[u].a[0][1]=val[u].a[0][0];
val[u].a[1][0]+=dp[v][0];
//(0,0)/(0,1)表示这个点不选,(1,0)表示这个点要选
}
}
void modify(int u,int w)//把u点权改成w
{
val[u].a[1][0]+=w-a[u];
a[u]=w;
while(u)
{
matrix t1=ask(1,1,n,num[top[u]],num[bot[u]]);
upd(1,1,n,num[u]);
matrix t2=ask(1,1,n,num[top[u]],num[bot[u]]);
u=fa[top[u]];
val[u].a[0][0]+=max(t2.a[0][0],t2.a[1][0])-max(t1.a[0][0],t1.a[1][0]);
val[u].a[0][1]=val[u].a[0][0];
val[u].a[1][0]+=t2.a[0][0]-t1.a[0][0];
}
}
signed main()
{
n=read();m=read();
for(int i=1;i<=n;i++)
a[i]=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
e[++tot]=edge(v,f[u]),f[u]=tot;
e[++tot]=edge(u,f[v]),f[v]=tot;
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
while(m--)
{
int x=read(),y=read();
modify(x,y);
matrix t1=ask(1,1,n,num[1],num[bot[1]]);
printf("%d\n",max(t1.a[0][0],t1.a[1][0]));
}
}