【DP优化技巧】2. (广义)矩阵加快

例题

来看一道例题。P5024 [NOIP2018 提高组] 保卫王国

对于这道题,首先如果没有国王的询问,可以设定状态:\(f_{i,0/1}\) 代表以 \(i\) 为根的子树里面,自己选/不选的最小花费。

易得状态转移方程:

\[f_{u,0}=\sum_{v\in son_u} f_{v,1}\\ f_{u,1}=p_u+\sum_{v\in son_u} \min(f_{v,0},f_{v,1}) \]

这时候多了一个询问。这是有两种方法。一种是离线下来使用线段树合并优化 DP。另外一种就是这篇文章要讲解的方法。

我们多开一个状态(怎么想到的,我至今还没有明白):\(g_{i,0/1}\) 代表 \(i\) 的父亲节点必须选/不选,除了 \(i\) 这颗子树,其他子树的花费。

易得转移状态:

\[g_{u,0}=\sum_{v\in brother_u} f_{v,1}\\ g_{u,1}=p_{fa}+\sum_{v\in brother_u} \min(f_{v,0},f_{v,1}) \]

这样转移完全要 TLE,然而我们可以直接从大的里面减去,就能求出来了。

接着神奇的事情就发生了。

注意到,

\[f_{fa,0}=f_{x,1}+g_{x,0}\\ f_{fa,1}=\min(f_{x,0},f_{x,1})+g_{x,1} \]

把它重新写一下:

\[f_{fa,0}=\min(f_{x,0}+\infty ,f_{x,1}+g_{x,0})\\ f_{fa,1}=\min(f_{x,0}+g_{x,1},f_{x,1}+g_{x,1}) \]

再次注意到:\(\min(a,b)+c=\min(a+c,b+c)\),所以先做加法在取最小是满足分配律的,所以可以矩阵乘法。

定义:

\[\left[\begin{matrix} a&b \end{matrix} \right]*\left[\begin{matrix}c&d\\e&f\end{matrix}\right]=\left[\begin{matrix}\min(a+c,a+d)&\min(b+e,b+f)\end{matrix}\right] \]

类似的

\[\left[\begin{matrix}a&b\\c&d\end{matrix}\right]*\left[\begin{matrix}e&f\\g&h\end{matrix}\right]=\left[\begin{matrix}\min(a+e,b+g)&\min(a+f,b+h)\\\min(c+e,d+g)&\min(c+f,d+h)\end{matrix}\right] \]

所以:

\[\left[\begin{matrix}f_{fa,0}&f_{fa_,1}\end{matrix}\right]=\left[\begin{matrix}f_{x,0}&f_{x,1}\end{matrix}\right]*\left[\begin{matrix}\infty&g_{x,1}\\g_{x,0}&g_{x,1}\end{matrix}\right] \]

然后我们就对矩阵 \(\displaystyle\left[\begin{matrix}\infty&g_{x,1}\\g_{x,0}&g_{x,1}\end{matrix}\right]\) 进行树上倍增。每一次修改就是把某一个 \(f_{x,0/1}\) 设置为 \(\infty\)。然后稍微推一下就行了。

代码

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int __int128
namespace gtx{
//	Fast IO
	void read(int &x){
		x = 0;int h = 1;char tmp;
		do{tmp=getchar();if(tmp=='-')h*=-1;}while(!isdigit(tmp));
		while(isdigit(tmp)) x*=10,x+=tmp-'0',tmp=getchar();
		x*=h;
	}
	void read(char &x){do{x=getchar();}while(x==' '||x=='\n'||x=='\r');}
	void write(char x){putchar(x);}
	void write(int x){
		if(x<0) putchar('-'),x=-x;int st[200]={0},tot=0;
		do st[++tot]=x%10,x/=10; while(x);
		while(tot){putchar(st[tot--]+'0');}
	}
	void write(int x,char y){if(x<4e18)write(x);else write((__int128)-1);write(y);}
	#ifndef int
	void read(long long &x){
		x = 0;int h = 1;char tmp;
		do{tmp=getchar();if(tmp=='-')h*=-1;}while(!isdigit(tmp));
		while(isdigit(tmp)) x*=10,x+=tmp-'0',tmp=getchar();
		x*=h;
	}
	void write(long long x){
		if(x<0) putchar('-'),x=-x;int st[200]={0},tot=0;
		do st[++tot]=x%10,x/=10; while(x);
		while(tot){putchar(st[tot--]+'0');}
	}
	void write(long long x,char y){write(x);write(y);}
	#endif
	const int MAXN = 1e5+10;
	const int LOGN = log2((long long)MAXN)+2;
	int n,m,p[MAXN];char TTTTTTTTTTT;
	vector<int> v[MAXN];
	int f[MAXN][2],g[MAXN][2];
	void dfs(int u,int fa){
		for(int y:v[u]){
			int v = y;
			if(v==fa) continue;
			dfs(v,u);
			f[u][0] += f[v][1];
			f[u][1] += min(f[v][0],f[v][1]);
		}
		f[u][1] += p[u];
		for(int y:v[u]){
			int v = y;
			if(v==fa) continue;
			g[v][0] = f[u][0]-f[v][1];
			g[v][1] = f[u][1]-min(f[v][0],f[v][1]);
		}
	}
	struct matrix22{
		int a,b,c,d;
//		a  b
//		c  d
	};
	struct matrix12{
		int a,b;
//		a  b
	};
	matrix12 operator * (matrix12 a,matrix22 b){
		return {min(a.a+b.a,a.b+b.c),min(a.a+b.b,a.b+b.d)};
	}
	matrix22 operator * (matrix22 a,matrix22 b){
		return {min(a.a+b.a,a.b+b.c),min(a.a+b.b,a.b+b.d),
				min(a.c+b.a,a.d+b.c),min(a.c+b.b,a.d+b.d)};
	}
	const int INF = 0x3f3f3f3f3f3f3f3f;
	matrix22 ST[MAXN][LOGN];
	int fath[MAXN][LOGN],dep[MAXN];
	void init(int x,int fa){
		dep[x] = dep[fa]+1;
		fath[x][0] = fa;
		ST[x][0] = {INF,g[x][1],g[x][0],g[x][1]};
		for(int i:v[x]){
			if(i==fa) continue;
			init(i,x);
		}
	}
	void init_ST(){
		for(int j = 1;j<LOGN;j++){
			for(int i = 1;i<=n;i++){
				fath[i][j] = fath[fath[i][j-1]][j-1];
				ST[i][j] = ST[i][j-1]*ST[fath[i][j-1]][j-1];
			}
		}
	}
	matrix22 operator *= (matrix22 &x,matrix22 y){
		x = x*y;
		return x;
	}
	matrix22 climb(int &x,int w){
		matrix22 ans = {0,INF,INF,0};
		int k = 0;
		while(w){
			if(w&1){
				ans *= ST[x][k];
				x = fath[x][k];
			}
			k++;w>>=1;
		}
		return ans;
	}
	matrix12 LCA(int x,int y,matrix12 mx,matrix12 my){
		if(dep[x]<dep[y]) swap(x,y),swap(mx,my);
		auto ans = climb(x,dep[x]-dep[y]);
		mx = mx*ans;
		if(x==y){
			if(my.a==INF) mx.a = INF;
			else mx.b = INF;
			return mx*climb(x,dep[x]-1);
		}
		for(int j = LOGN-1;~j;j--){
			if(fath[x][j]!=fath[y][j]){
				mx = mx*ST[x][j];
				my = my*ST[y][j];
				x = fath[x][j];
				y = fath[y][j];
			}
		}
		int lca = fath[x][0];
		int f0 = f[lca][0]-f[x][1]-f[y][1];
		int f1 = f[lca][1]-min(f[x][1],f[x][0])-min(f[y][1],f[y][0]);
		f0 += mx.b;f0 += my.b;
		f1 += min(mx.a,mx.b);f1 += min(my.a,my.b);
		matrix12 o = {f0,f1};
		return o*climb(lca,dep[lca]-1);
	}
	signed main(){
		read(n);read(m);read(TTTTTTTTTTT);read(TTTTTTTTTTT);
		for(int i = 1;i<=n;i++){
			read(p[i]); 
		}
		for(int i = 1;i<n;i++){
			int a,b;
			read(a);read(b);
			v[a].push_back(b);
			v[b].push_back(a);
		}
		dfs(1,0);
		init(1,0);
		init_ST();
		for(int i = 1;i<=m;i++){
			int a,x,b,y;
			read(a);read(x);read(b);read(y);
			matrix12 A = {f[a][0],f[a][1]};
			matrix12 B = {f[b][0],f[b][1]};
			if(x==1) A.a=INF;
			else A.b=INF;
			if(y==1) B.a=INF;
			else B.b=INF;
			auto tmp = LCA(a,b,A,B);
			write(min(tmp.a,tmp.b),endl);
		}
		return 0;
	}
}
signed main(){
//	freopen("P5024_10.in","r",stdin);
//	freopen("P5024.out","w",stdout);
//	ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	int T = 1;
//	gtx::read(T);
	while(T--) gtx::main();
	return 0;
}

posted @ 2024-11-26 18:44  辜铜星  阅读(11)  评论(0编辑  收藏  举报