动态DP
应用
动态\(DP\)主要是解决:在树上或链上\(dp\)后,后期对树上链上的点进行修改,然后询问修改后的答案。
其经典例题:
给\(n\)个点的树,给出每个点的点权,求最大权独立集。中途给出\(m\)个修改,每次修改后输出修改后的最优答案。
前置算法
我们主要考虑树上,解决这类问题,需要用到三个算法,树形\(dp\),树链剖分,矩阵乘法。
树形\(dp\)
我们先简单考虑不进行修改,那么这是一道非常简单的树形\(dp\)题,设\(f[x][1]\)为此点必选,\(f[x][0]\)为此点必不选的最大权值。
转移是,设\(v\)为\(x\)子节点,\(f[x][1]=\sum f[v][0] +a[x]\),\(f[x][0]=\sum max(f[v][0],f[v][1])\)。
设\(1\)为根节点。答案就是\(max(f[1][0],f[1][1])\)。
树链剖分
往往在树上进行多点修改,就需要用到\(dfs\)序或者树链剖分,前者主要是处理子树,后者是处理链,树链剖分里面分为重链和轻链,因为轻链连接的子树大小是小于等于其父亲的子树的大小的一半,所以轻链只有\(log\)个,那么在上跳的时候,也就只会跳\(log\)次。
矩阵乘法
矩阵乘法是把一个转移式变成矩阵,然后用矩阵乘法来计算,就不用一个一个地转移,而是很快的乘出总的转移式,在修改某处的\(dp\)值后,普通做法是暴力更新一遍,而矩阵乘法可以结合线段树,单点更新后按照线段树操作合并即可,大大优化了复杂度。
动态\(dp\)
讲到动态\(dp\)了,动态\(dp\)的大致思路是 : 把重链的转移与轻链分离,重链上的每个点都有个转移矩阵,转移矩阵的值由与这个点相连的轻链上的值决定,每次修改一个点,就依次沿着重链上跳,每次跳到一个新的重链上时,就根据刚刚跳过来的轻链上的值,修改这个点的转移矩阵,继续上跳。
具体做法:我们设\(g[x][0]\)表示\(x\)点必不选,排除重儿子的情况下,的最大收益,\(g[x][1]\)则表示\(x\)点必选的最大收益。
转移和前面树形\(dp\)类似,我们原来的\(f\)转移就变成了:(设\(v\)为重儿子)
\(f[x][0]=g[x][0]+max(f[v][0],f[v][1]).\) \(f[x][1]=g[x][1]+f[v][0]\)
我们对这个式子变一下:
\(f[x][0]=max(g[x][0]+f[v][0],g[x][0]+f[v][1])\) \(f[x][1]=(-inf+f[v][1],g[x][1]+f[v][0])\)
我们发现,这个和矩阵转移式很像,但是\(+\)法变成了取\(max\)!
感性手推了一下,发现也满足结合律......感性理解感性理解
所以我们定义一个新的矩阵乘法,由\(\sum a_{ik}+b_{kj}\)改为\(max( a_{ik}+b_{kj})\)。
那么可得:
这个就是转移矩阵了:
每次修改一个点的值后,相应修改其转移矩阵,然后跳到重链顶端,求出重链顶端的\(f\)值,由这个\(f\)去更新其父亲点的\(g\),然后同时更新其夫妻点的转移矩阵。然后继续上跳即可。
贴一下代码:
#include <bits/stdc++.h>
using namespace std;
int n,m;
const int MAXN=1e5+5;
struct mat{
int n,m;
int w[3][3];
};
mat operator * (const mat &aa,const mat &bb){
mat nw;
nw.n=aa.n,nw.m=bb.m;
for(int i=1;i<=nw.n;i++){
for(int j=1;j<=nw.m;j++){
nw.w[i][j]=-100000000;
for(int k=1;k<=aa.m;k++){
nw.w[i][j]=max(nw.w[i][j],aa.w[i][k]+bb.w[k][j]);
}
}
}
return nw;
}
int a[MAXN];
int g[MAXN][2];
mat f[MAXN];
int cnt;
int dep[MAXN];
int idx[MAXN];
int fa[MAXN];
int top[MAXN];
int maxn[MAXN];
int son[MAXN];
int siz[MAXN];
mat t[MAXN*4];
vector<int> q[MAXN];
void update(int l,int r,int x,mat v,int id){
if(l==r){
t[id]=v;
return ;
}
int mid=(l+r)/2;
if(mid>=x)update(l,mid,x,v,id*2);
else update(mid+1,r,x,v,id*2+1);
t[id]=t[id*2]*t[id*2+1];
}
mat query(int l,int r,int z,int y,int id){
if(l==z&&r==y)return t[id];
int mid=(l+r)/2;
if(mid>=y)return query(l,mid,z,y,id*2);
else if(mid<z)return query(mid+1,r,z,y,id*2+1);
else return query(l,mid,z,mid,id*2)*query(mid+1,r,mid+1,y,id*2+1);
}
void dfs_init(int x,int pr){
fa[x]=pr,dep[x]=dep[pr]+1;
for(int i=0;i<q[x].size();i++){
int nx=q[x][i];
if(nx==pr)continue;
dfs_init(nx,x);
siz[x]+=siz[nx];
if(siz[nx]>=siz[son[x]])son[x]=nx;
f[x].w[1][1]+=f[nx].w[2][1];
f[x].w[2][1]+=max(f[nx].w[1][1],f[nx].w[2][1]);
}
f[x].w[1][1]+=a[x];
siz[x]++;
}
mat G(int x){
mat nw;
nw.m=nw.n=2;
nw.w[1][1]=-100000000,nw.w[1][2]=g[x][1];
nw.w[2][1]=nw.w[2][2]=g[x][0];
return nw;
}
void calg(int x){
g[x][0]=g[x][1]=0;
for(int i=0;i<q[x].size();i++){
int nx=q[x][i];
if(nx==fa[x]||nx==son[x])continue;
g[x][0]+=max(f[nx].w[1][1],f[nx].w[2][1]);
g[x][1]+=f[nx].w[2][1];
}
g[x][1]+=a[x];
}
void dfs_link(int x,int pr){
cnt++,idx[x]=cnt;
maxn[top[x]]=x;
if(son[x]){
top[son[x]]=top[x];
dfs_link(son[x],x);
}
for(int i=0;i<q[x].size();i++){
int nx=q[x][i];
if(nx==pr||nx==son[x])continue;
top[nx]=nx;
dfs_link(nx,x);
}
calg(x);
update(1,n,idx[x],G(x),1);
}
void up(int x,int v){
a[x]=v;
while(x)
{
calg(x);
update(1,cnt,idx[x],G(x),1);
f[top[x]]=query(1,n,idx[top[x]],idx[maxn[top[x]]],1)*f[son[maxn[top[x]]]];
x=fa[top[x]];
}
}
int main()
{
scanf("%d%d",&n,&m);
f[0].n=2,f[0].m=1;f[0].w[1][1]=f[0].w[2][1]=0;
for(int i=1;i<=n;i++){
f[i].n=2;f[i].m=1;f[i].w[1][1]=f[i].w[2][1]=0;
scanf("%d",&a[i]);
}
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
q[x].push_back(y);
q[y].push_back(x);
}
dfs_init(1,0);
top[1]=1;
dfs_link(1,0);
for(int op=1;op<=m;op++)
{
int x,v;
scanf("%d%d",&x,&v);
up(x,v);
printf("%d\n",max(f[1].w[1][1],f[1].w[2][1]));
}
return 0;
}