"动态 DP"&动态树分治 整理

刚刚学完,整理一下。

一般是对于简单的树形dp,加上丧心病狂的修改操作,并且支持在线操作的解决方法。

看洛谷的模板题

给定一棵 \(n\) 个点的树,点带点权,有 \(m\) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)。你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

显然,如果没有修改操作的话是可以直接线性求出答案的。我们设 f[u][1] 为选取点 u 的答案,f[u][0] 为不选取点 u 的答案,可以得到转移方程:\(f[u][0] = \sum \max(f[v][0],f[v][1]),f[u][1] = a[u] + \sum f[v][0]\)

很明显对于所有的修改我们都重新求一遍答案的话复杂度会爆炸。

考虑每一次修改那些位置的 dp 数组发生了变化。明显只有修改位置到根这条链上会有变化。那么假如数据随机,可以 \(mlogn\) 通过此题。可数据不可能随机对吧,当树退化成一条链的时候,这个做法一定会被卡成 \(nm\) 。那么怎么办呢?

考虑进行轻重链剖分。

我们设 g 为所有轻儿子对父亲的贡献,那么就有了 \(f[u][0] = g[u][0] + \max(f[son[u]][0],f[son[u]][1]),f[u][1] = a[u] + g[u][0] + f[son[u]][0]\) 。发现 a[u] ,g[u][0] 都至于 u 有关,那就把 a[u] 放进 g[u][0] 里。

然后得到 \(f[u][0] = g[u][0] + \max(f[son[u]][0],f[son[u]][1]),f[u][1] = g[u][0] + f[son[u]][0]\)

非常显然,对于一条到根的链,最多只有 log 个轻儿子,也就是只有 log 个 g 需要修改,似乎可以加速了,yeah!

但是还有一个问题,那就是对于一条重链,我又如何快速得到答案呢?

我们把前面说的转移方程口胡到矩阵上,定义一种矩阵乘法 '' ,有 ab=c 时,\(c[i][j] = \max(a[i][k]+b[k][j])\)

可以得到转移矩阵 \(\begin{vmatrix}g[i][0]&f[i][0]\\f[i][1]&-inf\end{vmatrix} \times \begin{vmatrix}f[j][0]\\f[j][1]\end{vmatrix} = \begin{vmatrix}f[i][0]\\f[i][1]\end{vmatrix}\) ,在这里 j = son[i]。

口胡后发现它满足结合律,因此可以使用线段树来进行维护。

对于修改,我们往上跳,不停的撤销原先的贡献,再将新的贡献加上去。对于询问,我们在线段树上查询根这条重链的开头和结尾,求得答案。

#include <cstdio>
#include <cstring>
#include <algorithm>

#define mid (l+r>>1)

using namespace std;

int read()
{
	int a = 0,x = 1;char ch = getchar();
	while(ch > '9' || ch < '0') {if(ch == '-') x = -1;ch = getchar();}
	while(ch >= '0' && ch <= '9') {a = a*10 + ch-'0';ch = getchar();}
	return a*x;
}
const int N=1e6+7,inf = 1e9+7;
int n,m;

int head[N],go[N],nxt[N],cnt,a[N];
void add(int u,int v)
{
	go[++cnt] = v;
	nxt[cnt] = head[u];
	head[u] = cnt;
}
#define end End
int dfn[N],pos[N],end[N],str[N],siz[N],son[N],g[N][2],fa[N],f[N][2],dep[N];

void dfs1(int u)
{
	siz[u] = 1;
	for(int e = head[u];e;e = nxt[e]) {
		int v = go[e];if(v == fa[u]) continue;
		fa[v] = u;dfs1(v);
		siz[u] += siz[v];
		if(siz[v] > siz[son[u]]) son[u] = v;
	}
}

struct node{
	int a[2][2];
	node (int x,int y) {a[0][0] = a[0][1] = x,a[1][0] = y,a[1][1] = -inf;}
	node () {}
	friend node operator * (node a,node b)
	{
		node c = (node){-inf,-inf};
		for(int i = 0;i < 2;i ++)
			for(int j = 0;j < 2;j ++)
				for(int k = 0;k < 2;k ++)
					c.a[i][j] = max(c.a[i][j],a.a[i][k] + b.a[k][j]);
		return c;
	}
}val[N],tre[N];

void dfs2(int u,int h)
{
	str[u] = h,end[h] = u;dep[u] = dep[fa[u]] + 1;
	dfn[u] = ++cnt,pos[cnt] = u;
	if(son[u]) dfs2(son[u],h);
	f[u][1] = f[son[u]][0] + a[u],f[u][0] = max(f[son[u]][0],f[son[u]][1]);
	for(int e = head[u];e;e = nxt[e]) {
		int v = go[e];if(v == fa[u] || v == son[u]) continue;
		dfs2(v,v);f[u][1] += f[v][0],f[u][0] += max(f[v][0],f[v][1]);
	}
	g[u][0] = f[u][0] - max(f[son[u]][0],f[son[u]][1]);
	g[u][1] = f[u][1] - f[son[u]][0];
	val[u] = (node){g[u][0],g[u][1]};
}

void build(int root,int l,int r)
{
	if(l == r) {tre[root] = val[pos[l]];return ;}
	build(root<<1,l,mid);build(root<<1|1,mid+1,r);
	tre[root] = tre[root<<1] * tre[root<<1|1];
}

void update(int root,int l,int r,int p)
{
	if(l == r && l == p) {tre[root] = val[pos[p]];return ;}
	if(p <= mid) update(root<<1,l,mid,p);
	else update(root<<1|1,mid+1,r,p);
	tre[root] = tre[root<<1] * tre[root<<1|1];
}

node query(int root,int l,int r,int ql,int qr)
{
	if(l >= ql && r <= qr) return tre[root];
	if(qr <= mid) return query(root<<1,l,mid,ql,qr);
	else if(ql > mid) return query(root<<1|1,mid+1,r,ql,qr);
	else return query(root<<1,l,mid,ql,qr) * query(root<<1|1,mid+1,r,ql,qr);
}

void solve(int p,int x)
{
	val[p].a[1][0] += x-a[p];
	a[p] = x;node tmp1,tmp2;
	while(p) {
		tmp1 = query(1,1,n,dfn[str[p]],dfn[end[str[p]]]);
		update(1,1,n,dfn[p]);
		tmp2 = query(1,1,n,dfn[str[p]],dfn[end[str[p]]]);
		p = fa[str[p]];
		if(!p) break;
		val[p].a[0][0] += max(tmp2.a[1][0],tmp2.a[0][0]) - max(tmp1.a[1][0],tmp1.a[0][0]);
		val[p].a[0][1] = val[p].a[0][0];
		val[p].a[1][0] += tmp2.a[0][0] - tmp1.a[0][0];
	}
}

int main()
{
	// freopen("random.in","r",stdin);
	// freopen("sol.out","w",stdout);
	n = read(),m = read();
	for(int i = 1;i <= n;i ++) a[i] = read();
	for(int i = 1;i < n;i ++) {
		int u = read(),v = read();
		add(u,v);add(v,u);
	}
	cnt = 0;dfs1(1),dfs2(1,1);
	build(1,1,n);
	for(int i = 1;i <= m;i ++) {
		int p = read(),x = read();
		solve(p,x);node tmp = query(1,1,n,dfn[1],dfn[end[1]]);
		printf("%d\n",max(tmp.a[0][0],tmp.a[1][0]));
	}
}
posted @ 2021-06-07 10:37  nao-nao  阅读(42)  评论(0编辑  收藏  举报