初探动态DP

最近学了动态\(DP\),非常精妙的方法啊
写一篇博客记录一下


例题1

这道题\(Luogu1352\),相信大家都做过.
可以很快地写出转移方程.
\(f[i][1]\)表示第i个节点选择的最大快乐度,\(f[i][0]\)表示第i个节点不选的快乐度,\(f[u][0]=\sum_{v\in sonu}max(f[u][0],f[u][1]),f[u][1]=\sum_{v\in sonu}f[u][0]\)


现在,增加一个操作,就是修改点权,然后依然查询最大独立集.
于是,难度瞬间变成了黑题.
现在考虑如何维护.

首先,因为要修改,常用的是树剖.
我们可以在前文所述的\(f\)数组上增加一个\(g\)数组
然后一条一条重链\(dp\).
这里我们在处理一条重链的时候,先处理与它相连的所有重链,最后处理它.\(g\)数组维护所有轻儿子的信息,而\(f\)数组维护\(g\)和重儿子的信息(也就是所有儿子).因此可以很快地写出转移方程
\(g[u][0]=\sum_{v\in lightsonu}max(f[u][0],f[u][1]),g[u][1]=\sum_{v\in lightsonu}f[u][0]\)
\(v\)\(u\)的重儿子,则
\(f[u][0]=g[u][0]+min(f[v][1],f[v][0]),f[u][1]=f[v][0]+g[u][1]\)
于是可以支持用线段树修改


一个常见的黑科技是把转移方程写成矩阵的形式,用线段树(树链剖分)维护矩阵乘积即可.
但是这个矩阵不是通常意义下的矩阵.
我们常写的矩阵是这样的
\(a[i][j]=\sum_{k=1}^na[i][k]*a[k][j]\)
而现在的矩阵是这样的
\(a[i][j]=max\{a[i][k]*a[k][j]\}\)
矩阵的所有性质都可行(结合律,交换律,分配律...)
因此我们就做完了这个问题,时间复杂度\(O(2^3nlog^2n)\)


代码

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#define N (100010)
#define M (200010)
#define inf (0x7f7f7f7f)
#define rg register int
#define Label puts("NAIVE")
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
	static const int IN_LEN=1000000;
	static char buf[IN_LEN],*s,*t;
	return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
	static bool iosig;
	static char c;
	for(iosig=false,c=read();!isdigit(c);c=read()){
		if(c=='-')iosig=true;
		if(c==-1)return;
	}
	for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
	if(iosig)x=-x;
}
inline char readchar(){
	static char c;
	for(c=read();!isalpha(c);c=read())
	if(c==-1)return 0;
	return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
	if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
	*ooh++=c;
}
template<class T>
inline void print(T x){
	static int buf[30],cnt;
	if(x==0)print('0');
	else{
		if(x<0)print('-'),x=-x;
		for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
		while(cnt)print((char)buf[cnt--]);
	}
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
struct Matrix{
	LL a[2][2];
	Matrix(){memset(a,0,sizeof(a));}
	Matrix operator *(Matrix x){
		Matrix res;
		for(int i=0;i<2;i++)
		for(int j=0;j<2;j++)
		for(int k=0;k<2;k++)
		res.a[i][j]=max(res.a[i][j],a[i][k]+x.a[k][j]);
		return res;
	}
}a[N<<4],val[N],ans;
int n,m,w[N];
int fi[M],ne[M],b[M],E,ind;
int top[N],fa[N],siz[N],son[N],dfn[N],rdfn[N],ed[N];
LL f[N][2];
void add(int x,int y){
	ne[++E]=fi[x],fi[x]=E,b[E]=y;
}
void dfs1(int u,int pre){
	int maxsiz=-1,ma=0; fa[u]=pre;
	for(int i=fi[u];i;i=ne[i]){
		int v=b[i];
		if(v==pre)continue;
		dfs1(v,u);
		if(siz[v]>maxsiz)maxsiz=siz[v],ma=v;
		siz[u]+=siz[v];
	}
	son[u]=ma,siz[u]++;
}
void dfs2(int u){
	dfn[u]=++ind,rdfn[ind]=u;
	if(!son[u]){ed[u]=u;return;}
	top[son[u]]=top[u],dfs2(son[u]),ed[u]=ed[son[u]];
	for(int i=fi[u];i;i=ne[i]){
		int v=b[i];
		if(v==son[u]||v==fa[u])continue;
		top[v]=v,dfs2(v);
	}
}
void dp(int u){
	for(int i=fi[u];i;i=ne[i]){
		int v=b[i];
		if(v==fa[u])continue;
		dp(v),f[u][0]+=max(f[v][0],f[v][1]);
		f[u][1]+=f[v][0];
	}
	f[u][1]+=1ll*w[u];
}
void build(int l,int r,int x){
	if(l==r){
		int u=rdfn[l],g0=0,g1=w[u];
		for(int i=fi[u];i;i=ne[i]){
			int v=b[i];
			if(v==fa[u]||v==son[u])continue;
			g0+=max(f[v][0],f[v][1]),g1+=f[v][0];
		}
		a[x].a[0][0]=a[x].a[0][1]=g0;
		a[x].a[1][0]=g1,a[x].a[1][1]=-inf;
		val[l]=a[x];
		return;
	}
	int mid=(l+r)>>1;
	build(l,mid,x*2),build(mid+1,r,x*2+1);
	a[x]=a[x*2]*a[x*2+1];
}
void change(int k,int l,int r,int x){
	if(l==r){
		a[x]=val[l];
		return;
	}
	int mid=(l+r)>>1;
	if(k<=mid)change(k,l,mid,x*2);
	else change(k,mid+1,r,x*2+1);
	a[x]=a[x*2]*a[x*2+1];
}
Matrix query(int l,int r,int L,int R,int x){
	if(l==L&&r==R)return a[x];
	int mid=(L+R)>>1;
	if(r<=mid)return query(l,r,L,mid,x*2);
	else if(l>mid)return query(l,r,mid+1,R,x*2+1);
	else return query(l,mid,L,mid,x*2)*query(mid+1,r,mid+1,R,x*2+1);
}
void update(int u,int t){
	int pos=dfn[u];
	val[pos].a[1][0]+=t-w[u],w[u]=t;
	Matrix pre,now;
	while(u){
		pre=query(dfn[top[u]],dfn[ed[u]],1,n,1);
        change(pos,1,n,1);
        now=query(dfn[top[u]],dfn[ed[u]],1,n,1);
        u=fa[top[u]],pos=dfn[u];
        val[pos].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(pre.a[0][0],pre.a[1][0]);
        val[pos].a[0][1]=val[pos].a[0][0];
		val[pos].a[1][0]+=now.a[0][0]-pre.a[0][0];
	}
}
int main(){
	read(n),read(m);
	for(int i=1;i<=n;i++)read(w[i]);
	for(int i=1,x,y;i<n;i++){
		read(x),read(y);
		add(x,y),add(y,x);
	}
	dfs1(1,0),top[1]=1,dfs2(1),fa[1]=0;
	dp(1),build(1,n,1);
	while(m--){
		int x,y;
		read(x),read(y);
		update(x,y);
		ans=query(1,dfn[ed[1]],1,n,1);
		printf("%lld\n",max(ans.a[0][0],ans.a[1][0]));
	}
}

例题2

这里是题目

题意

给定一棵树,每个点有点权,求某棵子树内的一个连通块,使它的权值和最大.带修改.

分析

还是动态\(dp\).令\(f[u]\)表示以\(u\)为根的子树中,选\(u\)时的最大权值和.
那么,\(f[u]=max(0,w[u]+\sum f[v])\).令\(s[u]\)表示以u为根的子树中最大权值和.那么,\(s[u]=max(s[v],f[u])\).答案就是\(s[x]\),\(x\)为询问的节点.
由于要支持修改,因此使用常见的套路,将轻重链分开.令\(g[u]=f[u]-f[heavyson[u]]\),则\(f[u]=f[heavyson[u]]+g[u]\)
废话,这样除了大常数还有什么用
于是我们发现它变成了区间最大子段和的形式,可以用线段树维护这个东西.
现在我们还需要维护\(s\).考虑对于每个节点建一个堆维护\(s\).由于要资瓷删除旧版本,因此用两个\(priority\_queue\)维护就好了.
代码如下

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<queue>
#define N (200010)
#define M (N<<1)
#define inf (0x7f7f7f7f)
#define rg register int
typedef long double ld;
typedef long long LL;
typedef unsigned long long ull;
using namespace std;
inline char read(){
	static const int IN_LEN=1000000;
	static char buf[IN_LEN],*s,*t;
	return (s==t?t=(s=buf)+fread(buf,1,IN_LEN,stdin),(s==t?-1:*s++):*s++);
}
template<class T>
inline void read(T &x){
	static bool iosig;
	static char c;
	for(iosig=false,c=read();!isdigit(c);c=read()){
		if(c=='-')iosig=true;
		if(c==-1)return;
	}
	for(x=0;isdigit(c);c=read())x=((x+(x<<2))<<1)+(c^'0');
	if(iosig)x=-x;
}
inline char readchar(){
	static char c;
	for(c=read();!isalpha(c);c=read())
	if(c==-1)return 0;
	return c;
}
const int OUT_LEN = 10000000;
char obuf[OUT_LEN],*ooh=obuf;
inline void print(char c) {
	if(ooh==obuf+OUT_LEN)fwrite(obuf,1,OUT_LEN,stdout),ooh=obuf;
	*ooh++=c;
}
template<class T>
inline void print(T x){
	static int buf[30],cnt;
	if(x==0)print('0');
	else{
		if(x<0)print('-'),x=-x;
		for(cnt=0;x;x/=10)buf[++cnt]=x%10+48;
		while(cnt)print((char)buf[cnt--]);
	}
}
inline void flush(){fwrite(obuf,1,ooh-obuf,stdout);}
struct heap{
    priority_queue<LL>p,q;
    void push(LL x){p.push(x);}
    void erase(LL x){q.push(x);}
    LL top(){
        while(!q.empty()&&p.top()==q.top())p.pop(),q.pop();
        return(p.empty()?0:p.top());
    }
}mx[N];
struct seg{
	LL lm,rm,mx,sum;
	seg(){lm=rm=mx=sum=0;}
	seg operator +(seg x){
		seg res;
		res.sum=sum+x.sum;
		res.lm=max(lm,sum+x.lm);
		res.rm=max(x.rm,rm+x.sum);
		res.mx=max(max(x.mx,mx),x.lm+rm);
		return res;
	}
};
struct xds{int l,r;seg v;}a[N<<3];
int n,m,fi[N],ne[M],b[M],w[N],E;
int dep[N],dfn[N],rdfn[N],siz[N],son[N],ed[N],top[N],fa[N],ind;
LL g[N],f[N],s[N];
void dfs1(int u,int pre){
	siz[u]=1,fa[u]=pre;
	for(int i=fi[u];i;i=ne[i]){
		int v=b[i];
		if(v==pre)continue;
		dfs1(v,u),siz[u]+=siz[v];
		if(siz[v]>siz[son[u]])son[u]=v;
	}
}
void dfs2(int u){
	dfn[u]=++ind,rdfn[ind]=u;
	if(!son[u]){ed[u]=u;return;}
	top[son[u]]=top[u],dfs2(son[u]),ed[u]=ed[son[u]];
	for(int i=fi[u];i;i=ne[i]){
		int v=b[i];
		if(v==son[u]||v==fa[u])continue;
		top[v]=v,dfs2(v);
	}
}
void dp(int u){
	for(int i=fi[u];i;i=ne[i]){
		int v=b[i];
		if(v!=fa[u])dp(v); 
		if(v!=son[u])g[u]+=f[v],mx[u].push(s[v]);
	}
	g[u]+=w[u],f[u]=max(g[u]+f[son[u]],(LL)0);
	s[u]=max(s[son[u]],max(f[u],mx[u].top()));
}
void build(int l,int r,int x){
	a[x].l=l,a[x].r=r;
	if(l==r){
		int t=rdfn[l];
		a[x].v.lm=a[x].v.rm=max(g[t],0ll);
		a[x].v.mx=max(g[t],mx[t].top());
		a[x].v.sum=g[t]; return;
	}
	int mid=(l+r)>>1;
	build(l,mid,x*2),build(mid+1,r,x*2+1);
	a[x].v=a[x*2].v+a[x*2+1].v;
}
void change(int k,int x){
	if(a[x].l==a[x].r){
		int t=rdfn[a[x].l];
		a[x].v.lm=a[x].v.rm=max(g[t],0ll);
		a[x].v.mx=max(g[t],mx[t].top());
		a[x].v.sum=g[t]; return;
	}
	int mid=(a[x].l+a[x].r)>>1;
	if(k<=mid)change(k,x*2);else change(k,x*2+1);
	a[x].v=a[x*2].v+a[x*2+1].v;
}
seg query(int l,int r,int x){
	if(a[x].l==l&&a[x].r==r)return a[x].v;
	int mid=(a[x].l+a[x].r)>>1;
	if(r<=mid)return query(l,r,x*2);
	else if(l>mid)return query(l,r,x*2+1);
	else return query(l,mid,x*2)+query(mid+1,r,x*2+1);
}
void modify(int u,int st,LL val){
    seg pre,now;
    while(u){
    	if(u!=st)mx[u].erase(pre.mx),mx[u].push(now.mx);
        pre=query(dfn[top[u]],dfn[ed[u]],1);
        g[u]+=val,change(dfn[u],1);
        now=query(dfn[top[u]],dfn[ed[u]],1);
        val=now.lm-f[top[u]],f[top[u]]=now.lm,u=fa[top[u]];
    }
}
void add(int x,int y){ne[++E]=fi[x],fi[x]=E,b[E]=y;}
int main(){
	read(n),read(m);
	for(int i=1;i<=n;i++)read(w[i]);
	for(int i=1,x,y;i<n;i++){
		read(x),read(y);
		add(x,y),add(y,x);
	}
	dfs1(1,0),top[1]=1,dfs2(1),dp(1),build(1,n,1);
	while(m--){
		char ch=readchar(); int x; read(x);
		if(ch=='M'){LL val;read(val),modify(x,x,val-w[x]),w[x]=val;}
		else print(query(dfn[x],dfn[ed[x]],1).mx),print('\n');
	}
	return flush(),0;
}
posted @ 2018-11-30 15:06  Romeolong  阅读(163)  评论(0编辑  收藏  举报