动态dp学习笔记
我们经常会遇到一些问题,是一些dp的模型,但是加上了什么待修改强制在线之类的,十分毒瘤,如果能有一个模式化的东西解决这类问题就会非常好。
给定一棵n个点的树,点带点权。
有m次操作,每次操作给定x,y,表示修改点x的权值为y。
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
如果不带修改,那就是一个最简单是树形dp问题。
我们设一个dp[i][0],dp[i][1]表示以i为根的子树
动态dp能够使用的一个前提就是它的转移是线性的,这样我们就可以用矩阵乘法实现快速转移了。
注意:这里的矩阵乘法是广义的,中间运算不一定是乘法,最后也不一定是求和,只要能满足矩阵乘法的性质就可以了。
重链剖分
这也是动态dp比较关键的内容,因为问题在树上,树的每个节点都可能有多个儿子节点,直接算贡献比较麻烦。
所以用重链剖分只保留一个儿子,其他的儿子放在一起统一计算,这样我们就把一个树上问题转化成了序列上的问题。
比如这道题,我们把树轻重链划分完后。
我们把轻子树的答案算完后直接加入状态中,然后答案就变成了一条重链的矩阵连乘积,用线段树维护矩阵的乘积即可。
每次修改时,根据重链剖分,答案包含这个点的位置最多有log个,所以每次就对这些位置修改就好了 。
代码
#include<iostream> #include<cstdio> #include<cstring> #define N 100002 using namespace std; typedef long long ll; int tot,head[N],size[N],deep[N],fa[N],son[N],top[N],dp[N][2],dfn[N],tag[N],ed[N],a[N],cntt,ls[N<<1],rs[N<<1],n,m,root; inline ll rd(){ ll x=0;char c=getchar();bool f=0; while(!isdigit(c)){if(c=='-')f=1;c=getchar();} while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();} return f?-x:x; } struct edge{int n,to;}e[N<<1]; inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;} struct matrix{ int a[2][2]; matrix(){memset(a,-0x3f,sizeof(a));} matrix operator *(const matrix &b)const{ matrix c; for(int i=0;i<2;++i) for(int j=0;j<2;++j) for(int k=0;k<2;++k) c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]); return c; } }data[N],tr[N<<1]; void dfs1(int u){ size[u]=1; for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]){ int v=e[i].to;deep[v]=deep[u]+1;fa[v]=u; dfs1(v); size[u]+=size[v]; if(size[v]>size[son[u]])son[u]=v; } } void dfs2(int u){ dfn[u]=++dfn[0];tag[dfn[0]]=u; if(!top[u])top[u]=u; ed[top[u]]=max(ed[top[u]],dfn[u]); data[u].a[0][0]=data[u].a[0][1]=0; data[u].a[1][0]=a[u]; dp[u][1]=a[u]; if(son[u]){ top[son[u]]=top[u],dfs2(son[u]); dp[u][0]+=max(dp[son[u]][0],dp[son[u]][1]); dp[u][1]+=dp[son[u]][0]; } for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]&&e[i].to!=son[u]){ int v=e[i].to;dfs2(v); dp[u][0]+=max(dp[v][0],dp[v][1]); dp[u][1]+=dp[v][0]; data[u].a[0][1]+=max(dp[v][0],dp[v][1]); data[u].a[0][0]+=max(dp[v][0],dp[v][1]); data[u].a[1][0]+=dp[v][0]; } } void build(int &cnt,int l,int r){ if(!cnt)cnt=++cntt; if(l==r){tr[cnt]=data[tag[l]];return;} int mid=(l+r)>>1; build(ls[cnt],l,mid);build(rs[cnt],mid+1,r); tr[cnt]=tr[ls[cnt]]*tr[rs[cnt]]; } void upd(int cnt,int l,int r,int x){ if(l==r){tr[cnt]=data[tag[x]];return;} int mid=(l+r)>>1; if(mid>=x)upd(ls[cnt],l,mid,x); else upd(rs[cnt],mid+1,r,x); tr[cnt]=tr[ls[cnt]]*tr[rs[cnt]]; } matrix query(int cnt,int l,int r,int L,int R){ if(l>=L&&r<=R)return tr[cnt]; int mid=(l+r)>>1; if(mid>=L&&mid<R)return query(ls[cnt],l,mid,L,R)*query(rs[cnt],mid+1,r,L,R); else if(mid>=L)return query(ls[cnt],l,mid,L,R); else return query(rs[cnt],mid+1,r,L,R); } void _upd(int u,int vall){ data[u].a[1][0]+=vall-a[u]; a[u]=vall; matrix now,pre; while(u){ pre=query(1,1,n,dfn[top[u]],ed[top[u]]); upd(1,1,n,dfn[u]); now=query(1,1,n,dfn[top[u]],ed[top[u]]); u=fa[top[u]]; data[u].a[0][0]+=max(now.a[1][0],now.a[0][0])-max(pre.a[1][0],pre.a[0][0]); data[u].a[0][1]=data[u].a[0][0]; data[u].a[1][0]+=now.a[0][0]-pre.a[0][0]; } } int main(){ n=rd();m=rd(); for(int i=1;i<=n;++i)a[i]=rd(); int u,v; for(int i=1;i<n;++i){ u=rd();v=rd(); add(u,v);add(v,u); } dfs1(1);dfs2(1); build(root,1,n); while(m--){ u=rd();v=rd(); _upd(u,v); matrix nowans=query(1,1,n,dfn[1],ed[1]); printf("%d\n",max(nowans.a[0][1],max(nowans.a[0][0],nowans.a[1][0]))); } return 0; }