[模板] 动态DP

一、题目

点此看题

二、解法

动态 \(dp\) 的思路主要是用矩阵乘法加速 \(dp\),所以首先要知道矩阵乘法的扩展版:

\[c(i,k)=\max\{a(i,j)+b(j,k)\} \]

令人震惊的是上面这东西也满足结合律,现在我们来证明一下,假设有三个矩阵 \(a,b,c\) 相乘,大小分别是 \(n\times m,m\times p,p\times q\),我们把最终某一个位置上的值暴力展开:

\[(i,j)=\max_{k=1}^ma(i,k)+\Big(\max_{l=1}^p b(k,l)+c(l,j)\Big) \]

\[=\max_{k=1}^m\max_{l=1}^p a(i,k)+b(k,l)+c(l,j) \]

\[=\max_{l=1}^p\Big(\max_{k=1}^m a(i,k)+b(k,l)\Big)+c(l,j) \]

所以先乘 \(a,b\) 还是先乘 \(b,c\) 对答案没有影响,结合律得证。


首先写出暴力的 \(dp\) 柿子,设 \(f(u,0/1)\) 表示 \(u\) 这个点不选\(/\)选的最大权值,转移:

\[f(u,0)=\sum\max(f(v,0),f(v,1)) \]

\[f(u,1)=a(u)+\sum f(v,0) \]

先来考虑一下链怎么做,我们构造一个像这样的转移矩阵:

\[\left(\begin{matrix}0&0\\a(u)&-\infty\end{matrix}\right)\times\left(\begin{matrix}f(v,0)\\f(v,1)\end{matrix}\right)=\left(\begin{matrix}f(u,0)\\f(u,1)\end{matrix}\right) \]

然后要求根的 \(dp\) 值就直接把所有矩阵乘起来就行了,时间复杂度 \(O(n\log n)\)

那么我们能不能把上面的做法搬到树上呢?考虑把树剖分成链然后套上面的做法,也就是用树链剖分。每个点的转移矩阵就针对他的重儿子来定义,但同时我们要考虑轻儿子对他 \(dp\) 值的贡献,所以再定义 \(f'(u,0/1)\) 表示 \(u\) 不选\(/\)选,考虑 \(u\)\(u\) 的所有轻儿子的最大值,那么有如下转移:

\[f(u,0)=f'(u,0)+\max\{f(son,0),f(son,1)\} \]

\[f(u,1)=f'(u,1)+f(son,0) \]

写成矩阵就是这个样子的:

\[\left(\begin{matrix}f'(u,0)&f’(u,0)\\f'(u,1)&-\infty\end{matrix}\right)\times\left(\begin{matrix}f(son,0)\\f(son,1)\end{matrix}\right)=\left(\begin{matrix}f(u,0)\\f(u,1)\end{matrix}\right) \]

先考虑怎么统计答案,我们找到根所在的那条重链,把所有转移矩阵乘起来就行了。

再考虑如何修改,修改一个点的点权只会对它的祖先产生影响。而且由于路径上只有 \(O(\log n)\) 条轻边,所以一共只需要改 \(O(\log n)\) 个矩阵,这部分可以看看代码:

void modify(int u,int w)//把u点权改成w 
{
	val[u].a[1][0]+=w-a[u];
	a[u]=w;
	while(u)
	{
		matrix t1=ask(1,1,n,num[top[u]],num[bot[u]]);//算出f(u,0/1)
		upd(1,1,n,num[u]);//在线段树上更新那个位置的矩阵
		matrix t2=ask(1,1,n,num[top[u]],num[bot[u]]);//算出新的f(u,0/1)
		u=fa[top[u]];//要更新重链顶端父亲的转移矩阵
		val[u].a[0][0]+=max(t2.a[0][0],t2.a[1][0])-max(t1.a[0][0],t1.a[1][0]);
		val[u].a[0][1]=val[u].a[0][0];
		val[u].a[1][0]+=t2.a[0][0]-t1.a[0][0];
	}
}

用一个线段树维护矩阵套上树链剖分:\(O(2^3\cdot n\log^2 n)\)

#include <cstdio>
#include <iostream>
using namespace std;
const int M = 100005;
const int inf = 1e9;
int read()
{
	int x=0,f=1;char c;
	while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
	while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
	return x*f;
}
int n,m,tot,cnt,f[M],a[M],id[M],fa[M];
int siz[M],son[M],num[M],top[M],bot[M],dp[M][2];
//top表示重链头
//bot表示重链尾
//num表示这个点在线段树上的位置
//id表示线段树上位置所对应的点 
//dp表示初始的dp数组 
struct edge
{
	int v,next;
	edge(int V=0,int N=0) : v(V) , next(N) {}
}e[2*M];
struct matrix
{
	int a[2][2];
	matrix() {a[0][0]=a[0][1]=a[1][0]=a[1][1]=-inf;}
	matrix operator * (const matrix &b) const
	{
		matrix r;
		for(int i=0;i<2;i++)
			for(int j=0;j<2;j++)
				for(int k=0;k<2;k++)
					r.a[i][k]=max(r.a[i][k],a[i][j]+b.a[j][k]);
		return r;
	}
	void print()
	{
		puts("---------");
		for(int i=0;i<2;i++,puts(""))
			for(int j=0;j<2;j++)
				printf("%d ",a[i][j]);
	}
}val[M],tr[4*M];
//线段树部分
void up(int i)
{
	tr[i]=tr[i<<1]*tr[i<<1|1];
}
void build(int i,int l,int r)
{
	if(l==r)
	{
		tr[i]=val[id[l]];
		return ;
	}
	int mid=(l+r)>>1;
	build(i<<1,l,mid);
	build(i<<1|1,mid+1,r);
	up(i);
}
void upd(int i,int l,int r,int x)//修改x这个位置的矩阵
{
	if(l==r)
	{
		tr[i]=val[id[x]];
		return ;
	}
	int mid=(l+r)>>1;
	if(mid>=x) upd(i<<1,l,mid,x);
	else upd(i<<1|1,mid+1,r,x);
	up(i);
}
matrix ask(int i,int l,int r,int L,int R)
{
	if(L<=l && r<=R) return tr[i];
	int mid=(l+r)>>1;
	if(R<=mid) return ask(i<<1,l,mid,L,R);
	if(L>mid) return ask(i<<1|1,mid+1,r,L,R);
	return ask(i<<1,l,mid,L,R)*ask(i<<1|1,mid+1,r,L,R);
}
//树链剖分部分 
void dfs1(int u,int p)
{
	siz[u]=1;fa[u]=p;
	for(int i=f[u];i;i=e[i].next)
	{
		int v=e[i].v;
		if(v==p) continue;
		dfs1(v,u);
		siz[u]+=siz[v];
		if(siz[son[u]]<siz[v]) son[u]=v;
	}
}
void dfs2(int u,int tp)
{
	top[u]=tp;
	num[u]=++cnt;
	id[cnt]=u;
	val[u].a[0][0]=val[u].a[0][1]=0;
	val[u].a[1][0]=dp[u][1]=a[u];
	if(son[u])
	{
		dfs2(son[u],tp),bot[u]=bot[son[u]];
		dp[u][0]+=max(dp[son[u]][0],dp[son[u]][1]);
		dp[u][1]+=dp[son[u]][0];
	}
	else bot[u]=u;//如果没有重儿子底部就是自己
	for(int i=f[u];i;i=e[i].next)
	{
		int v=e[i].v;
		if(v==fa[u] || v==son[u]) continue;
		dfs2(v,v);
		dp[u][0]+=max(dp[v][0],dp[v][1]);
		dp[u][1]+=dp[v][0];
		val[u].a[0][0]+=max(dp[v][0],dp[v][1]);
		val[u].a[0][1]=val[u].a[0][0];
		val[u].a[1][0]+=dp[v][0];
		//(0,0)/(0,1)表示这个点不选,(1,0)表示这个点要选 
	}
}
void modify(int u,int w)//把u点权改成w 
{
	val[u].a[1][0]+=w-a[u];
	a[u]=w;
	while(u)
	{
		matrix t1=ask(1,1,n,num[top[u]],num[bot[u]]);
		upd(1,1,n,num[u]);
		matrix t2=ask(1,1,n,num[top[u]],num[bot[u]]);
		u=fa[top[u]];
		val[u].a[0][0]+=max(t2.a[0][0],t2.a[1][0])-max(t1.a[0][0],t1.a[1][0]);
		val[u].a[0][1]=val[u].a[0][0];
		val[u].a[1][0]+=t2.a[0][0]-t1.a[0][0];
	}
}
signed main()
{
	n=read();m=read();
	for(int i=1;i<=n;i++)
		a[i]=read();
	for(int i=1;i<n;i++)
	{
		int u=read(),v=read();
		e[++tot]=edge(v,f[u]),f[u]=tot;
		e[++tot]=edge(u,f[v]),f[v]=tot;
	}
	dfs1(1,0);
	dfs2(1,1);
	build(1,1,n);
	while(m--)
	{
		int x=read(),y=read();
		modify(x,y);
		matrix t1=ask(1,1,n,num[1],num[bot[1]]);
		printf("%d\n",max(t1.a[0][0],t1.a[1][0]));
	}
}
posted @ 2021-03-13 15:14  C202044zxy  阅读(118)  评论(0编辑  收藏  举报