动态 DP 学习笔记

前言

新科技永无止境(

当你拥有一个状态转移方程看上去很简单的,画风正常的树形 DP 题。

然后 duliu 出题人又加上了单点修改点权(边权)的操作,那么就可能是比较裸的动态 DP 了。

主要思想

模板题

给定一棵带点权的树,每次操作修改某个点的点权,每次求最大权独立集的权值。

首先,动态 DP 要求转移方程必须足够简单。

例如这里,不带修改时画风正常的 DP 是,设 \(f(u,0/1)\) 表示点 \(u\) 选或不选,以它为根的子树的最大权独立集。

\[f(u,0)=\sum \max(f(v,0),f(v,1)) \]

\[f(u,1)=a(u)+\sum f(v,0) \]

如果每次重做一遍,不难得到 \(O(nq)\) 的暴力做法。

但是发现对于单点修改的点 \(u\),只有它到根的链上的点的 \(f\) 可能发生改变,所以可以只暴力跳链。

这样看上去很厉害但最坏还是 \(O(nq)\) 的,但有一定的启发性,提示我们用可以快速处理链信息的数据结构维护。

可选的有很多,LCT、树剖等都是可选的,甚至还有全局平衡二叉树,这里介绍比较常见的树剖写法。

为了迎合树剖从重儿子转移到父亲节点的特性,定义 \(g(v,0/1)\) 表示除了重儿子 \(v\) 之外的 \(u\) 子树的答案,有:

\[g(u,0)=\sum \max(f(v',0),f(v',1)) \]

\[g(u,1)=a(u)+\sum f(v',0) \]

其中上式中的 \(v'\in subtree(u),v'\neq v\),那么转移方程就变成了:

\[f(u,0)=\max(f(v,0),f(v,1))+g(u,0) \]

\[f(u,1)=f(v,0)+g(u,1) \]

这看上去还是比较难受,所以再形式化一点:

\[f(u,0)=\max(f(v,0)+g(u,0),f(v,1)+g(u,0)) \]

\[f(u,1)=\max(f(v,0)+g(u,1),-\infty) \]

定义广义矩阵乘法,\(a_{i,j}=\max(a_{i,j},b_{i,k}+c_{k,j})\),那么:

\[\begin{vmatrix}f(v,0) & f(v,1)\end{vmatrix}\begin{vmatrix}g(u,0) &g(u,1)\\g(u,0) & -\infty \end{vmatrix}=\begin{vmatrix}f(u,0)&f(u,1)\end{vmatrix} \]

这玩意是满足结合律的,感性证明因为 \(\max\) 和加法运算都是满足结合律的叭。

但是树剖维护 DFS 序时,线段树习惯上是左区间乘右区间,所以应当改为左乘。

乘法顺序是应用矩阵乘法时需要严谨考虑的方面,因为它不满足交换律。所以最终方程:

\[\begin{vmatrix}g(u,0)&g(u,0)\\g(u,1) & -\infty \end{vmatrix}\begin{vmatrix}f(v,0) \\ f(v,1)\end{vmatrix}=\begin{vmatrix}f(u,0)\\f(u,1)\end{vmatrix} \]

发现叶子节点的转移矩阵没用,那么恰好可以作为初值,对于叶子节点 \(u\),初值长这样:

\[\begin{vmatrix}0&-\infty\\ a(u)& -\infty \end{vmatrix} \]

那么一个节点的答案就是 它所在的重链的叶子节点一直乘到它的转移矩阵 的 \((0,0)\)\((1,0)\) 位。

对于修改的点权,它会影响到:

  1. 自己的转移矩阵。
  2. 自己所在重链的比它深度浅的所有节点。
  3. 自己所在重链顶点的父亲的转移矩阵。

于是就先修改自己,然后重新统计整个重链的答案(即查询整个重链的乘),再把偏移量加到重链顶点的父亲上。

最后整棵树的答案就是从 \(1\) 开始的整个重链的乘。

代码实现

核心 tricks:

  1. 利用 \(n\) 个矩阵记录所有点的转移矩阵,线段树初始化和单调修改时直接赋值。(这里的 val 数组)
  2. 因为需要 Query 整个链来得到某个链的答案,所以需要记录链末端。(这里的 las 数组)

提前将矩阵的定义、转移方程等核心部分都写在纸上整理好,再打代码将比较简洁(

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;

const int N = 1e5 + 10;

int n, m, a[N], f[N][2];
int tot, id[N], top[N], dfn[N];
int fa[N], las[N], son[N], siz[N];
int cnt, head[N];

struct Edge{int nxt, to;} ed[N << 1];

struct Matrix{
	int mat[2][2];
	
	Matrix() {memset(mat, 0xcf, sizeof(mat));}
	
	Matrix operator * (const Matrix &b) const{
		Matrix c;
		
		for(int i = 0; i < 2; i ++)
			for(int j = 0; j < 2; j ++)
				for(int k = 0; k < 2; k ++)
					c.mat[i][j] = max(c.mat[i][j], mat[i][k] + b.mat[k][j]);
		
		return c;
	}
} val[N], dat[N << 2];

int read(){
	int x = 0, f = 1; char c = getchar();
	while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
	return x * f;
}

void add(int u, int v){
	ed[++ cnt] = (Edge){head[u], v};
	head[u] = cnt;
}

void dfs1(int u, int Fa){
	fa[u] = Fa, siz[u] = 1;
	
	for(int i = head[u], v; i; i = ed[i].nxt)
		if((v = ed[i].to) != Fa){
			dfs1(v, u);
			siz[u] += siz[v];
			if(!son[u] || siz[v] > siz[son[u]]) son[u] = v;
		}
}

void dfs2(int u, int Top){
	id[dfn[u] = ++ tot] = u;
	top[u] = Top;
	las[Top] = max(las[Top], tot);
	
	f[u][0] = 0, f[u][1] = a[u];
	val[u].mat[0][0] = val[u].mat[0][1] = 0;
	val[u].mat[1][0] = a[u];
	
	if(son[u]) {
		dfs2(son[u], Top);
		f[u][0] += max(f[son[u]][0], f[son[u]][1]);
		f[u][1] += f[son[u]][0];
	} 
	
	for(int i = head[u], v; i; i = ed[i].nxt)
		if((v = ed[i].to) != fa[u] && v != son[u]){
			dfs2(v, v);
			f[u][0] += max(f[v][0], f[v][1]);
			f[u][1] += f[v][0];
			
			val[u].mat[0][0] += max(f[v][0], f[v][1]);
			val[u].mat[0][1] = val[u].mat[0][0];
			val[u].mat[1][0] += f[v][0];
		}
}

void Push_Up(int p) {dat[p] = dat[p << 1] * dat[p << 1 | 1];}

void Build(int p, int l, int r){
	if(l == r) {dat[p] = val[id[l]]; return;}
	int mid = (l + r) >> 1;
	Build(p << 1, l, mid);
	Build(p << 1 | 1, mid + 1, r);
	Push_Up(p);
}

void Update(int p, int l, int r, int k){
	if(l == r) {dat[p] = val[id[k]]; return;}
	int mid = (l + r) >> 1;
	if(k <= mid) Update(p << 1, l, mid, k);
	else Update(p << 1 | 1, mid + 1, r, k);
	Push_Up(p);
}

Matrix Query(int p, int l, int r, int L, int R){
	if(L <= l && r <= R) return dat[p];
	int mid = (l + r) >> 1;
	if(L >  mid)
		return Query(p << 1 | 1, mid + 1, r, L, R);
	if(R <= mid) 
		return Query(p << 1, l, mid, L, R);
	return Query(p << 1, l, mid, L, R) * Query(p << 1 | 1, mid + 1, r, L, R);
}

void Modify(int u, int v) {
	val[u].mat[1][0] += v - a[u];
	a[u] = v;
	
	while(u){
		Matrix pre = Query(1, 1, n, dfn[top[u]], las[top[u]]);
		Update(1, 1, n, dfn[u]);
		Matrix now = Query(1, 1, n, dfn[top[u]], las[top[u]]);
		u = fa[top[u]];
		
		val[u].mat[0][0] += max(now.mat[0][0], now.mat[1][0]) - max(pre.mat[0][0], pre.mat[1][0]);
		val[u].mat[0][1] = val[u].mat[0][0];
		val[u].mat[1][0] += now.mat[0][0] - pre.mat[0][0];
	}
}

int main(){
	n = read(), m = read();
	for(int i = 1; i <= n; i ++) a[i] = read();
	for(int i = 1; i <  n; i ++){
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	dfs1(1, 0), dfs2(1, 1); Build(1, 1, n);
	while(m --){
		int x = read(), y = read();
		Modify(x, y);
		Matrix ans = Query(1, 1, n, dfn[1], las[1]);
		printf("%d\n", max(ans.mat[0][0], ans.mat[1][0]));
	}
	return 0;
}

简单例题

相比较于树剖这种形式化的数据结构维护,个人认为动态 DP 最有价值的是转移矩阵的思想加速递推。

例如保卫王国,可以用简单的 树上倍增 + 矩阵思想 达到很好的解题效果。

这道题求的是树上最小点覆盖集,其实可以用 最小点覆盖 = 全集 - 最大独立集 来直接转化为上面的板子题。

但这道题没有修改点权的操作,只有强制选 / 不选点,那么可以考虑倍增。

整体思想是将强制点不能选的赋值为 \(\infty\),然后倍增到 Lca 下面,简单讨论得到 Lca 的值,最后一路倍增到根节点得到答案。

具体实现的细节还是比较多的,例如倍增的矩阵应当是右乘的,且转移矩阵是挂在子节点自己身上。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<set>
using namespace std;

typedef long long LL;
const int N = 1e5 + 10;
const LL INF = 1e12;

struct Edge{int nxt, to;} ed[N << 1];

struct Matrix{
	LL mat[2][2];
	
	Matrix(){
		for(int i = 0; i < 2; i ++)
		for(int j = 0; j < 2; j ++)	
		mat[i][j] = INF;
	}
	
	Matrix operator * (const Matrix &b) const{
		Matrix c;
		for(int i = 0; i < 2; i ++)
			for(int j = 0; j < 2; j ++)
				for(int k = 0; k < 2; k ++)
					c.mat[i][j] = min(c.mat[i][j], mat[i][k] + b.mat[k][j]);
		return c;
	}
};

int n, m, cnt, p[N], head[N];
int dep[N], fa[N][17];
LL f[N][2];
Matrix g[N][17]; 
set<pair<int, int> > E;

int read(){
	int x = 0, f = 1; char c = getchar();
	while(c < '0' || c > '9') f = (c == '-') ? -1 : 1, c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
	return x * f;
}

void Add_Edge(int u, int v){
	ed[++ cnt] = (Edge){head[u], v};
	head[u] = cnt;
}

void dfs1(int u, int Fa){
	fa[u][0] = Fa, dep[u] = dep[Fa] + 1;
	f[u][0] = 0, f[u][1] = p[u];

	for(int i = head[u], v; i; i = ed[i].nxt)
		if((v = ed[i].to) != Fa){
			dfs1(v, u);
			f[u][0] += f[v][1];
			f[u][1] += min(f[v][0], f[v][1]);
		} 
}

void dfs2(int u, int Fa){
	for(int i = 1; (1 << i) <= dep[u]; i ++){
		fa[u][i] = fa[fa[u][i - 1]][i - 1];
		g[u][i] = g[u][i - 1] * g[fa[u][i - 1]][i - 1];
	}
		
	for(int i = head[u], v; i; i = ed[i].nxt)
		if((v = ed[i].to) != Fa){
			g[v][0].mat[1][0] = f[u][0] - f[v][1];
			g[v][0].mat[0][1] = g[v][0].mat[1][1] = f[u][1] - min(f[v][0], f[v][1]);
			
			dfs2(v, u);
		}
}

LL Query(int x, int o1, int y, int o2){
	if(dep[x] > dep[y]) swap(x, y), swap(o1, o2);
	Matrix t1, t2, t;  int u;
	t1.mat[0][o1] = f[x][o1];
	t2.mat[0][o2] = f[y][o2];
	
	for(int i = 16; i >= 0; i --)
		if(dep[fa[y][i]] >= dep[x]) t2 = t2 * g[y][i], y = fa[y][i];

	if(x == y)
		t.mat[0][o1] = t2.mat[0][o1], u = x;
	else{
		for(int i = 16; i >= 0; i --)
			if(fa[x][i] != fa[y][i])
				t1 = t1 * g[x][i], x = fa[x][i],
				t2 = t2 * g[y][i], y = fa[y][i];
		u = fa[x][0];
		t.mat[0][0] = (f[u][0] - f[x][1] - f[y][1]) + t1.mat[0][1] + t2.mat[0][1];
		t.mat[0][1] = f[u][1] - min(f[x][0], f[x][1]) - min(f[y][0], f[y][1])
					+ min(t1.mat[0][0], t1.mat[0][1]) + min(t2.mat[0][0], t2.mat[0][1]);
	}
	
	for(int i = 16; i >= 0; i --)
		if(fa[u][i]) t = t * g[u][i], u = fa[u][i];

	return min(t.mat[0][0], t.mat[0][1]);
}

int main(){
	n = read(), m = read();
	char str[5]; scanf("%s", str);
	for(int i = 1; i <= n; i ++) p[i] = read();
	for(int i = 1; i <  n; i ++){
		int u = read(), v = read();
		Add_Edge(u, v), Add_Edge(v, u);
		E.insert(make_pair(u, v));
		E.insert(make_pair(v, u));
	}
	dfs1(1, 0), dfs2(1, 0);
	while(m --){
		int u = read(), x = read();
		int v = read(), y = read();
		if(!x && !y && E.find(make_pair(u, v)) != E.end())
			puts("-1");
		else
			printf("%lld\n", Query(u, x, v, y));
	}
	return 0;
}
posted @ 2021-08-20 12:32  LPF'sBlog  阅读(54)  评论(0编辑  收藏  举报