【知识总结】动态 DP

勾起了我悲伤的回忆 —— NOIP2018 316pts ……

主要思想:将 DP 过程分解为方便单点修改和一个区间合并的操作(通常类似矩阵乘法),然后用数据结构(通常为线段树)维护。

例:给定一个长为 \(n\) 的整数序列,相邻两个数最多选一个,有 \(m\) 次修改序列中的一个数,求每次修改后选出数之和的最大值。
\(n,m\leq 10^5\)

如果不会做不带修改的情况,请默默摁 Ctrl + w 然后去学 DP 入门

如果不带修改,明显设 \(f_{i,0/1}\) 表示当第 \(i\) 个点选 (0) / 不选 (1) 时,前 \(i\) 个点的和的最大值。于是有如下转移方程:

\[f_{i,0}=f_{i-1,1} \]

\[f_{i,1}=\max(f_{i-1,0},f_{i-1,1})+a_i \]

如果加入修改操作呢?只有这两个 DP 方程比较难办,因为修改一个值就要重新计算后面的所有答案。GG

接下来是「动态 DP 」中最巧妙的部分:考虑用一个矩阵来表示从 \(i-1\) 点向 \(i\) 点转移,用某个表示「初始状态」的矩阵依次乘上每个点的转移就是答案。因为矩阵乘法有结合律,所以可以把答案表示成「初始状态」乘上「修改点前面的矩阵乘积」乘上「当前位置修改后的矩阵」乘上「修改点后面的矩阵乘积」。这样只需要用线段树单点修改和查询区间乘积(事实上这道题只需要查全局乘积)即可。

然而,这道题中转移的运算并不是加和乘,尤其是其中还有一个碍眼的求最大值。但我们可以把矩阵乘法的定义稍加修改,把原来两个整数的「乘法」改为两个整数的加法,「加法」改为对两个整数取最大值。这样我们就构造如下转移矩阵:

\[\begin{bmatrix} f_{i-1,0}&f_{i-1,1} \end{bmatrix} \begin{bmatrix} 0&a_i\\ 0&-\infty\\ \end{bmatrix}= \begin{bmatrix} f_{i,0}&f_{i,1}\\ \end{bmatrix}\]

还有一个很多人没考虑过的细节 (可能是大佬们认为这个问题太显然不需要考虑) :这个「初始状态」是什么呢?对于这道题,前一个数如果不选是不影响当前决策的,而如果选了的话就会造成一个当前点不能选的「约束」。而第一个点无论如何都不会受到这种「约束」,所以第一个点的「前一个点」应该被看作「没有选」,即初始状态为 \(\begin{bmatrix}0&-\infty\end{bmatrix}\)

我们把这个问题扩展到树上,即每条边的两端点中至少选一个点(洛谷 4719【模板】动态 DP )。考虑树链剖分来转化成序列问题。设 \(f_{i,0/1}\) 表示 \(i\) 点选 / 不选时 \(i\) 点子树中的最大权值和,\(g_{i,0/1}\) 表示 \(i\) 点选 / 不选时 \(i\) 点子树除 \(s_i\) 的子树以外的部分中的最大权值和,其中 \(s_i\)\(i\) 的重儿子。对于一条重链有如下方程:

\[\begin{bmatrix} f_{s_i,0}&f_{s_i,1} \end{bmatrix} \begin{bmatrix} g_{i,0}&g_{i,1}\\ g_{i,0}&-\infty\\ \end{bmatrix}= \begin{bmatrix} f_{i,0}&f_{i,1}\\ \end{bmatrix}\]

这样,每个点的答案是「初始状态」乘上它到所在重链末尾的矩阵乘积。

至于具体实现,可以开始先一遍 DP 算出所有的 \(f\)\(g\) 。每次修改时沿着重链向上爬,暴力修改链首父亲的 \(g\) 值。链首到链首父亲的边是一条轻边,所以这样每次修改一个点时要更新 \(g\) 值的点的数量约等于当前点到根的路径上的轻边数量(可能有加一减一之类的细节),是 \(O(\log n)\) 。因此总复杂度 \(O(mlog^2n)\)

和上面类似的分析,初始状态(叶子节点那个不存在的重儿子的 \(f\) 值)是 \(\begin{bmatrix}0&-\infty\end{bmatrix}\) 。用这个东西去乘相当于取原矩阵的第一行,所以不需要「显式」地乘。

代码:

很抱歉我代码里的矩阵行列和上文是反的,所有矩阵乘法的顺序也是反的我也不知道怎么回事 QAQ 。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cctype>
using namespace std;

namespace zyt
{
	template<typename T>
	inline bool read(T &x)
	{
		char c;
		bool f = false;
		x = 0;
		do
			c = getchar();
		while (c != EOF && c != '-' && !isdigit(c));
		if (c == EOF)
			return false;
		if (c == '-')
			f = true, c = getchar();
		do
			x = x * 10 + c - '0', c = getchar();
		while (isdigit(c));
		if (f)
			x = -x;
		return true;
	}
	template<typename T>
	inline void write(T x)
	{
		static char buf[20];
		char *pos = buf;
		if (x < 0)
			putchar('-'), x = -x;
		do
			*pos++ = x % 10 + '0';
		while (x /= 10);
		while (pos > buf)
			putchar(*--pos);
	}
	const int N = 1e5 + 10, INF = 0x3f3f3f3f;
	int n, m, head[N], ecnt, w[N], size[N], son[N], fa[N], dfn[N], dfncnt, top[N], f[N][2], g[N][2], end[N], pos[N];
	struct edge
	{
		int to, next;
	}e[N << 1];
	void add(const int a, const int b)
	{
		e[ecnt] = (edge){b, head[a]}, head[a] = ecnt++;
	}
	void dfs(const int u, const int f)
	{
		fa[u] = f, size[u] = 1;
		for (int i = head[u]; ~i; i = e[i].next)
		{
			int v = e[i].to;
			if (v == f)
				continue;
			dfs(v, u);
			size[u] += size[v];
			if (size[v] > size[son[u]])
				son[u] = v;
		}
	}
	void dfs2(const int u, const int t)
	{
		top[u] = t, dfn[u] = ++dfncnt, pos[dfncnt] = u, end[t] = u;
		if (son[u])
			dfs2(son[u], t);
		for (int i = head[u]; ~i; i = e[i].next)
		{
			int v = e[i].to;
			if (v == fa[u] || v == son[u])
				continue;
			dfs2(v, v);
		}
	}
	void dfs3(const int u)
	{
		g[u][0] = 0, g[u][1] = w[u];
		for (int i = head[u]; ~i; i = e[i].next)
		{
			int v = e[i].to;
			if (v == fa[u] || v == son[u])
				continue;
			dfs3(v);
			g[u][0] += max(f[v][0], f[v][1]);
			g[u][1] += f[v][0];
		}
		f[u][0] = g[u][0], f[u][1] = g[u][1];
		if (son[u])
		{
			dfs3(son[u]);
			f[u][0] += max(f[son[u]][0], f[son[u]][1]);
			f[u][1] += f[son[u]][0];
		}
	}
	struct Matrix
	{
		int data[2][2], n, m;
		Matrix(const int _n = 0, const int _m = 0)
			: n(_n), m(_m)
		{
			for (int i = 0; i < n; i++)
				for (int j = 0; j < m; j++)
					data[i][j] = -INF;
		}
		Matrix operator * (const Matrix &b) const
		{
			Matrix ans(n, b.m);
			for (int i = 0; i < n; i++)
				for (int k = 0; k < m; k++)
					for (int j = 0; j < b.m; j++)
						ans.data[i][j] = max(ans.data[i][j], data[i][k] + b.data[k][j]);
			return ans;
		}
	}val[N];
	namespace Segment_Tree
	{
		struct node
		{
			Matrix m;
		}tree[N << 2];
		void update(const int rot)
		{
			tree[rot].m = tree[rot << 1].m * tree[rot << 1 | 1].m;
		}
		void build(const int rot, const int lt, const int rt)
		{
			tree[rot].m = Matrix(2, 2);
			if (lt == rt)
				return void(tree[rot].m = val[pos[lt]]);
			int mid = (lt + rt) >> 1;
			build(rot << 1, lt, mid), build(rot << 1 | 1, mid + 1, rt);
			update(rot);
		}
		void change(const int rot, const int lt, const int rt, const int p)
		{
			if (lt == rt)
				return void(tree[rot].m = val[pos[p]]);
			int mid = (lt + rt) >> 1;
			if (p <= mid)
				change(rot << 1, lt, mid, p);
			else
				change(rot << 1 | 1, mid + 1, rt, p);
			update(rot);
		}
		Matrix query(const int rot, const int lt, const int rt, const int ls, const int rs)
		{
			if (ls <= lt && rt <= rs)
				return tree[rot].m;
			int mid = (lt + rt) >> 1;
			if (rs <= mid)
				return query(rot << 1, lt, mid, ls, rs);
			else if (ls > mid)
				return query(rot << 1 | 1, mid + 1, rt, ls, rs);
			else
				return query(rot << 1, lt, mid, ls, rs) * query(rot << 1 | 1, mid + 1, rt, ls, rs);
		}
	}
	int work()
	{
		using namespace Segment_Tree;
		read(n), read(m);
		memset(head, -1, sizeof(int[n + 1]));
		for (int i = 1; i <= n; i++)
			read(w[i]), val[i] = Matrix(2, 2);
		for (int i = 1; i < n; i++)
		{
			int a, b;
			read(a), read(b);
			add(a, b), add(b, a);
		}
		dfs(1, 0), dfs2(1, 1), dfs3(1);
		for (int i = 1; i <= n; i++)
			val[i].data[0][0] = val[i].data[0][1] = g[i][0], val[i].data[1][0] = g[i][1], val[i].data[1][1] = -INF;
		build(1, 1, n);
		while (m--)
		{
			int u, x;
			read(u), read(x);
			val[u].data[1][0] += x - w[u];
			w[u] = x;
			Matrix a, b;
			while (u)
			{
				a = query(1, 1, n, dfn[top[u]], dfn[end[top[u]]]);
				change(1, 1, n, dfn[u]);
				b = query(1, 1, n, dfn[top[u]], dfn[end[top[u]]]);
				u = fa[top[u]];
				val[u].data[0][0] += max(b.data[0][0], b.data[1][0]) - max(a.data[0][0], a.data[1][0]);
				val[u].data[0][1] = val[u].data[0][0];
				val[u].data[1][0] += b.data[0][0] - a.data[0][0];
			}
			Matrix ans = query(1, 1, n, dfn[1], dfn[end[1]]);
			write(max(ans.data[0][0], ans.data[1][0])), putchar('\n');
		}
		return 0;
	}
}
int main()
{
#ifdef BlueSpirit
	freopen("4719.in", "r", stdin);
#endif
	return zyt::work();
}
posted @ 2019-07-13 23:06  Inspector_Javert  阅读(236)  评论(0编辑  收藏  举报