【luogu P4719】【模板】“动态 DP“&动态树分治(DDP)(重链剖分)(线段树)

【模板】"动态 DP"&动态树分治

题目链接:luogu P4719

题目大意

给你一棵树,点带权,每次操作修改点权,要你求最大权独立集的权值大小。

思路

这道题就是 DDP 的模板题啦。
(其实还多了一个树剖)

DDP 是什么呢,就是动态 DP,就是一些简单的 DP 在加上了一个带修,然后我们用矩阵乘法来乘模拟它转移的过程,然后用数据结构来维护矩阵的乘,从而快速处理修改。

那我们考虑先一开始的 DP。
\(f_{i,0/1}\) 表示 \(i\) 的子树,\(i\) 这个点不选 / 选的最大权独立集。
然后转移就是枚举儿子 \(j\)\(f_{i,0}=\sum\limits_{j}\max(f_{j,0},f_{j,1}),f_{i,1}=a_i+\sum\limits_{j}f_{j,0}\)
然后答案是 \(\max(f_{1,0},f_{1,1})\)(以 \(1\) 为根做树形 DP)

然后考虑修改之后有哪些部分要改,那就是修改的点往上都根的链。
那修改链嘛,不难想到一个东西叫做树链剖分,那我们就直接上重链剖分。
那既然要这样了,我们转移肯定不能带个 \(\sum\),不然你怎么优化,所以我们考虑根据轻重儿子弄一个 \(g_{i,0/1}\) 表示 \(i\) 的子树,\(i\) 的轻儿子可选可不选 / 都不能选的最大权独立集。
那我们设 \(j\)\(i\) 的重儿子:\(f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1}),f_{i,1}=a_i+g_{i,1}+f_{j,1}\)
然后发现 \(f_{i,1}\) 里面 \(a_i,g_{i,1}\) 两个下标都是 \(i\) 相关的,考虑把 \(a_i\) 放进 \(g_{i,1}\) 里面。
\(g_{i,1}\) 就表示 \(i\) 的子树,\(i\) 的轻儿子都不选,\(i\) 选的最大权独立集。

然后重新写一次:
\(f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1})\)
\(f_{i,1}=g_{i,1}+f_{j,1}\)

那轻边的转移就好了,接着考虑重链要怎么线段树维护。

看到这种简洁的东西我们考虑怎么用矩阵乘法来优化。
发现问题出在 \(\max\) 上,普通的矩阵乘法搞不了这玩意儿。
那我们就大胆重新定义矩阵乘法咯。
定义一个 \(*\)\(A*B=\max\limits_{k}\{A_{i,k},B_{k,j}\}\)
记得验证结合律:
\(\max(\max(x,y),z)=\max(x,y,z)\)\(\max(x,\max(y,z))=\max(x,y,z)\),那肯定一样啦。
(其实就是因为 \(\max\) 是满足结合律的)

然后你想想能不能用这个来搞。
首先第一个的样子肯定得改:\(f_{i,0}=g_{i,0}+\max(f_{j,0},f_{j,1})=\max(g_{i,0}+f_{j,0},g_{i,0}+f_{j,1})\)
那第二个我们就可以看做是 \(f_{i,1}=\max(g_{i,1}+f_{j,1},-\infty)\)
那矩阵是不是可以是:
\(\begin{vmatrix}g_{i,0}&g_{i,1}\\ g_{i,0}&\infty\end{vmatrix}\)
好像可以?

那我们就有:\(\begin{vmatrix}f_{j,0}&f_{j,1}\end{vmatrix}*\begin{vmatrix}g_{i,0}&g_{i,1}\\ g_{i,0}&\infty\end{vmatrix}=\begin{vmatrix}f_{i,0}&f_{i,1}\end{vmatrix}\)

那就可以啦!
然后你会发现直接上就错了,因为你这个树形 DP 是自下而上的,那线段树来说你是从左到右,也就是从上到下的。

那接下来就是两个解决方法:

  1. 直接改乘起来的顺序,线段树里面改,就合并的时候是 \(val_{2x+1}*val_{2x}\),然后查询答案的时候也是右边的答案乘上左边的。
  2. 直接改矩阵的样子,让它满足:\(Z*\begin{vmatrix}f_{j,0}&f_{j,1}\end{vmatrix}=\begin{vmatrix}f_{i,0}&f_{i,1}\end{vmatrix}\)\(Z\) 就是要改成的矩阵)
    然后再这里其实不难,就原来的矩阵重心对称一下就好了:\(\begin{vmatrix}g_{i,0}&g_{i,0}\\g_{i,1}&\infty\end{vmatrix}\)

然后搞就可以啦。

代码

改乘起来顺序形

#include<cstdio>
#include<vector>
#include<cstring>

using namespace std;

const int N = 1e5 + 100;
const int M = 2;
int n, m, a[N], x, y;
vector <int> G[N];

struct matrix {
	int a[M][M];
	
	matrix() {
		memset(a, -0x3f, sizeof(a));
	}
	
	matrix operator *(matrix y) {
		matrix re;
		for (int k = 0; k < 2; k++)
			for (int i = 0; i < 2; i++)
				for (int j = 0; j < 2; j++)
					re.a[i][j] = max(re.a[i][j], a[i][k] + y.a[k][j]);
		return re;
	}
};

int fa[N], sz[N], son[N], top[N], dfn[N], id[N], End[N], f[N][2];
matrix val[N];
void dfs0(int now, int father) {
	sz[now] = 1; fa[now] = father;
	for (int i = 0; i < G[now].size(); i++) {
		int x = G[now][i]; if (x == father) continue;
		dfs0(x, now); sz[now] += sz[x];
		if (sz[x] > sz[son[now]]) son[now] = x;
	}
}
void dfs1(int now, int father) {
	f[now][0] = 0; f[now][1] = a[now];
	val[now].a[0][0] = val[now].a[1][0] = 0;
	val[now].a[0][1] = a[now];
	
	dfn[++dfn[0]] = now; id[now] = dfn[0];
	if (son[now]) {
		top[son[now]] = top[now]; dfs1(son[now], now);
		f[now][0] += max(f[son[now]][0], f[son[now]][1]);
		f[now][1] += f[son[now]][0];
	}
	else End[top[now]] = now;
	for (int i = 0; i < G[now].size(); i++) {
		int x = G[now][i]; if (x == father || x == son[now]) continue;
		top[x] = x; dfs1(x, now);
		f[now][0] += max(f[x][0], f[x][1]);
		f[now][1] += f[x][0];
		val[now].a[0][0] += max(f[x][0], f[x][1]); val[now].a[1][0] += max(f[x][0], f[x][1]);
		val[now].a[0][1] += f[x][0];
	}
}

struct XD_tree {
	matrix v[N << 2];
	
	void up(int now) {
		v[now] = v[now << 1 | 1] * v[now << 1];
	}
	
	void build(int now, int l, int r) {
		if (l == r) {
			v[now] = val[dfn[l]]; return ;
		}
		int mid = (l + r) >> 1; build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
		up(now);
	}
	
	void change(int now, int l, int r, int pl) {
		if (l == r) {
			v[now] = val[dfn[l]]; return ;
		}
		int mid = (l + r) >> 1;
		if (pl <= mid) change(now << 1, l, mid, pl);
			else change(now << 1 | 1, mid + 1, r, pl);
		up(now);
	}
	
	matrix query(int now, int l, int r, int L, int R) {
		if (L <= l && r <= R) return v[now];
		int mid = (l + r) >> 1;
		if (L <= mid && mid < R) return query(now << 1 | 1, mid + 1, r, L, R) * query(now << 1, l, mid, L, R);
		if (L <= mid) return query(now << 1, l, mid, L, R);
		if (mid < R) return query(now << 1 | 1, mid + 1, r, L, R);
	}
	
	void update(int now, int va) {
		val[now].a[0][1] -= a[now]; val[now].a[0][1] += va; a[now] = va;//直接记每个数组当前的值,这样线段树就不用下传矩阵了
		matrix bef, aft;
		while (now) {
			bef = query(1, 1, n, id[top[now]], id[End[top[now]]]);
			change(1, 1, n, id[now]);
			aft = query(1, 1, n, id[top[now]], id[End[top[now]]]);
			now = fa[top[now]];
			
			if (!now) break;
			val[now].a[0][0] -= max(bef.a[0][0], bef.a[0][1]); val[now].a[0][0] += max(aft.a[0][0], aft.a[0][1]);
			val[now].a[1][0] -= max(bef.a[0][0], bef.a[0][1]); val[now].a[1][0] += max(aft.a[0][0], aft.a[0][1]);
			val[now].a[0][1] -= bef.a[0][0]; val[now].a[0][1] += aft.a[0][0];
		}
	}
}T;

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
	for (int i = 1; i < n; i++) {
		scanf("%d %d", &x, &y); G[x].push_back(y); G[y].push_back(x);
	}
	dfs0(1, 0); top[1] = 1; dfs1(1, 0);
	
	T.build(1, 1, n);
	for (int i = 1; i <= m; i++) {
		scanf("%d %d", &x, &y);
		T.update(x, y);
		matrix ans = T.query(1, 1, n, id[1], id[End[1]]);
		//记得是从下向上搞,所以你要记录一个end表示这条重链的下面,然后一个点的答案是从它链下面到它
		printf("%d\n", max(ans.a[0][0], ans.a[0][1]));
	}
	
	return 0;
}

改矩阵形

#include<cstdio>
#include<vector>
#include<cstring>

using namespace std;

const int N = 1e5 + 100;
const int M = 2;
int n, m, a[N], x, y;
vector <int> G[N];

struct matrix {
	int a[M][M];
	
	matrix() {
		memset(a, -0x3f, sizeof(a));
	}
	
	matrix operator *(matrix y) {
		matrix re;
		for (int k = 0; k < 2; k++)
			for (int i = 0; i < 2; i++)
				for (int j = 0; j < 2; j++)
					re.a[i][j] = max(re.a[i][j], a[i][k] + y.a[k][j]);
		return re;
	}
};

int fa[N], sz[N], son[N], top[N], dfn[N], id[N], End[N], f[N][2];
matrix val[N];
void dfs0(int now, int father) {
	sz[now] = 1; fa[now] = father;
	for (int i = 0; i < G[now].size(); i++) {
		int x = G[now][i]; if (x == father) continue;
		dfs0(x, now); sz[now] += sz[x];
		if (sz[x] > sz[son[now]]) son[now] = x;
	}
}
void dfs1(int now, int father) {
	f[now][0] = 0; f[now][1] = a[now];
	val[now].a[0][0] = val[now].a[0][1] = 0;
	val[now].a[1][0] = a[now];
	
	dfn[++dfn[0]] = now; id[now] = dfn[0];
	if (son[now]) {
		top[son[now]] = top[now]; dfs1(son[now], now);
		f[now][0] += max(f[son[now]][0], f[son[now]][1]);
		f[now][1] += f[son[now]][0];
	}
	else End[top[now]] = now;
	for (int i = 0; i < G[now].size(); i++) {
		int x = G[now][i]; if (x == father || x == son[now]) continue;
		top[x] = x; dfs1(x, now);
		f[now][0] += max(f[x][0], f[x][1]);
		f[now][1] += f[x][0];
		val[now].a[0][0] += max(f[x][0], f[x][1]); val[now].a[0][1] += max(f[x][0], f[x][1]);
		val[now].a[1][0] += f[x][0];
	}
}

struct XD_tree {
	matrix v[N << 2];
	
	void up(int now) {
		v[now] = v[now << 1] * v[now << 1 | 1];
	}
	
	void build(int now, int l, int r) {
		if (l == r) {
			v[now] = val[dfn[l]]; return ;
		}
		int mid = (l + r) >> 1; build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
		up(now);
	}
	
	void change(int now, int l, int r, int pl) {
		if (l == r) {
			v[now] = val[dfn[l]]; return ;
		}
		int mid = (l + r) >> 1;
		if (pl <= mid) change(now << 1, l, mid, pl);
			else change(now << 1 | 1, mid + 1, r, pl);
		up(now);
	}
	
	matrix query(int now, int l, int r, int L, int R) {
		if (L <= l && r <= R) return v[now];
		int mid = (l + r) >> 1;
		if (L <= mid && mid < R) return query(now << 1, l, mid, L, R) * query(now << 1 | 1, mid + 1, r, L, R);
		if (L <= mid) return query(now << 1, l, mid, L, R);
		if (mid < R) return query(now << 1 | 1, mid + 1, r, L, R);
	}
	
	void update(int now, int va) {
		val[now].a[1][0] -= a[now]; val[now].a[1][0] += va; a[now] = va;//直接记每个数组当前的值,这样线段树就不用下传矩阵了
		matrix bef, aft;
		while (now) {
			bef = query(1, 1, n, id[top[now]], id[End[top[now]]]);
			change(1, 1, n, id[now]);
			aft = query(1, 1, n, id[top[now]], id[End[top[now]]]);
			now = fa[top[now]];
			
			if (!now) break;
			val[now].a[0][0] -= max(bef.a[0][0], bef.a[1][0]); val[now].a[0][0] += max(aft.a[0][0], aft.a[1][0]);
			val[now].a[0][1] -= max(bef.a[0][0], bef.a[1][0]); val[now].a[0][1] += max(aft.a[0][0], aft.a[1][0]);
			val[now].a[1][0] -= bef.a[0][0]; val[now].a[1][0] += aft.a[0][0];
		}
	}
}T;

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
	for (int i = 1; i < n; i++) {
		scanf("%d %d", &x, &y); G[x].push_back(y); G[y].push_back(x);
	}
	dfs0(1, 0); top[1] = 1; dfs1(1, 0);
	
	T.build(1, 1, n);
	for (int i = 1; i <= m; i++) {
		scanf("%d %d", &x, &y);
		T.update(x, y);
		matrix ans = T.query(1, 1, n, id[1], id[End[1]]);
		//记得是从下向上搞,所以你要记录一个end表示这条重链的下面,然后一个点的答案是从它链下面到它
		printf("%d\n", max(ans.a[0][0], ans.a[1][0]));
	}
	
	return 0;
}
posted @ 2022-04-10 12:58  あおいSakura  阅读(46)  评论(0编辑  收藏  举报