概念
DDP,可以理解为转移会发生改变的动态规划。
当然这个改变是题目中给的,包括系数,转移位置的改变。显然暴力枚举这些改变是不现实的,我们要把改变体现到其他地方。
最经典的,体现到矩阵上。
我们把转移写成矩阵,那么改变转移就是改变转移矩阵。
具体的改变会落实到具体的题目上。
广义矩阵乘法
因为转移的多样性,矩阵乘法不一定需要用一般乘法的乘完相加。在满足结合律的情况下,可以是乘完取 \(\min\),加完取 \(\max\) 等。
如 CF750E,要删除最少,转移中需要取 \(\min\),所以写成矩阵时,重载乘法就用到了加完取 \(\min\),同时因为其有结合律,其仍旧可以像一般矩阵乘法进行上树等操作。
线段树维护
矩阵满足结合律,可以用线段树维护。
面对每一位转移不同的题目或者只需统计区间答案的题目时,使用线段树维护区间转移矩阵的积是很必要的。
主要是代码实现的难度。
struct mat
{
int mat[6][6];
}a,c;
mat operator *(mat a,mat b)
{
mat c;
memset(c.mat,63,sizeof(c.mat));
for(int k=0;k<5;k++)
{
for(int i=0;i<5;i++)
{
for(int z=0;z<5;z++)
{
c.mat[i][z]=min(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
mat mul(mat a,mat b)
{
mat c;
memset(c.mat,63,sizeof(c.mat));
for(int k=0;k<5;k++)
{
for(int i=0;i<1;i++)
{
for(int z=0;z<5;z++)
{
c.mat[i][z]=min(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
int n,m,q,rt,w[200001];
mat sum[800001],inn;
void add(int o,int l,int r,int x,mat y)
{
if(l==r)
{
sum[o]=y;
return;
}
int mid=r+l>>1;
if(x<=mid) add((o<<1),l,mid,x,y);
else add((o<<1)+1,mid+1,r,x,y);
sum[o]=sum[(o<<1)]*sum[(o<<1)+1];
}
mat get(int o,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return sum[o];
int mid=l+r>>1;
if(mid>=y)
{
return get(o<<1,l,mid,x,y);
}
if(x>mid)
{
return get((o<<1)+1,mid+1,r,x,y);
}
return get(o<<1,l,mid,x,y)*get((o<<1)+1,mid+1,r,x,y);
}
解决树上DDP问题
使用树链剖分把树断为链,重链内是序列问题可以自己解决。而重链之间的转移成为难点。
我们称一个重链顶与他的父亲组成一个卡口。改变一个点的值后,所有他到父亲的卡口值会改变。体现轻重链,我们设 \(g_u\) 为只与 \(u\) 亲儿子有关的转移,\(f_{uw}\) 为 \(u\) 的重儿子的 \(DP\) 值,我们必须把 \(f_u\) 转移写成只与 \(g_u\) 和 \(f_{uw}\) 有关的式子。
为什么呢?
保证时间复杂度,因为每个重链内是序列问题,它是不用改变的,而到了卡口,\(g\) 值会变。若和其他 \(f\) 有关,那么改变一个点的值将导致他到根的所有 \(f\) 值改变,因为他们的转移都依赖于此。
模板题
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
using namespace std;
struct mat
{
int mat[2][2];
}gg[100001];
mat operator *(mat a,mat b)
{
mat c;
for(int i=0;i<2;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=-100000000;
}
}
for(int k=0;k<2;k++)
{
for(int i=0;i<2;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=max(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
mat mul(mat a,mat b)
{
mat c;
for(int i=0;i<2;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=-100000000;
}
}
for(int k=0;k<2;k++)
{
for(int i=0;i<1;i++)
{
for(int z=0;z<2;z++)
{
c.mat[i][z]=max(c.mat[i][z],a.mat[i][k]+b.mat[k][z]);
}
}
}
return c;
}
int n,m,q,rt,w[200001];
mat sum[800001];
int fat[100001],siz[100001],dep[100001],hson[100001],top[100001],cnt,dfn[100001],dis[100001],f[100001][2],downd[100001];
vector<int> g[1000001];
void add(int o,int l,int r,int x,mat y)
{
if(l==r)
{
sum[o]=y;
return;
}
int mid=r+l>>1;
if(x<=mid) add((o<<1),l,mid,x,y);
else add((o<<1)+1,mid+1,r,x,y);
sum[o]=sum[(o<<1)]*sum[(o<<1)+1];
}
mat get(int o,int l,int r,int x,int y)
{
if(x<=l&&y>=r) return sum[o];
int mid=l+r>>1;
if(mid>=y)
{
return get(o<<1,l,mid,x,y);
}
if(x>mid)
{
return get((o<<1)+1,mid+1,r,x,y);
}
return get(o<<1,l,mid,x,y)*get((o<<1)+1,mid+1,r,x,y);
}
void getdfsh(int u,int fa)
{
fat[u]=fa;
dep[u]=dep[fa]+1;
int lll=0;
f[u][1]=w[u];
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa) continue;
getdfsh(v,u);
if(siz[v]>lll)
{
hson[u]=v;
lll=siz[v];
}
siz[u]+=siz[v];
f[u][1]+=f[v][0];
f[u][0]+=max(f[v][0],f[v][1]);
}
siz[u]++;
}
void gettd(int u,int fa)
{
gg[u].mat[1][0]=w[u];
gg[u].mat[1][1]=-100000000;
dfn[u]=++cnt;
dis[u]=cnt;
if(hson[fat[u]]==u)
{
top[u]=top[fa];
downd[top[u]]=dfn[u];
}
else
{
top[u]=u;
downd[top[u]]=dfn[u];
}
if(hson[u]!=0) gettd(hson[u],u);
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa||v==hson[u]) continue;
gettd(v,u);
gg[u].mat[0][0]+=max(f[v][0],f[v][1]);
gg[u].mat[1][0]+=f[v][0];
}
gg[u].mat[0][1]=gg[u].mat[0][0];
}
void getdis(int u, int fa) {
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if (v==fa) continue;
getdis(v,u);
dis[u]=max(dis[u],dis[v]);
}
}
void update(int x,int val)
{
gg[x].mat[1][0]+=val-w[x];
w[x]=val;
while(x)
{
mat las=get(1,1,n,dfn[top[x]],downd[top[x]]);
add(1,1,n,dfn[x],gg[x]);
mat now=get(1,1,n,dfn[top[x]],downd[top[x]]);
x=fat[top[x]];
gg[x].mat[0][0]+=max(now.mat[0][0],now.mat[1][0])-max(las.mat[0][0],las.mat[1][0]);
gg[x].mat[0][1]=gg[x].mat[0][0];
gg[x].mat[1][0]+=now.mat[0][0]-las.mat[0][0];
}
}
signed main()
{
scanf("%d",&n);
scanf("%d",&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&w[i]);
}
for(int i=1,u,v;i<n;i++)
{
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
getdfsh(1,0);
gettd(1,0);
getdis(1,0);
for(int i=1;i<=n;i++)
{
add(1,1,n,dfn[i],gg[i]);
}
for(int i=1;i<=m;i++)
{
int x,val;
scanf("%d%d",&x,&val);
update(x,val);
mat ans=get(1,1,n,1,downd[1]);
printf("%d\n",max(ans.mat[0][0],ans.mat[1][0]));
}
}