【DP优化技巧】2. (广义)矩阵加快
例题
来看一道例题。P5024 [NOIP2018 提高组] 保卫王国
对于这道题,首先如果没有国王的询问,可以设定状态:\(f_{i,0/1}\) 代表以 \(i\) 为根的子树里面,自己选/不选的最小花费。
易得状态转移方程:
\[f_{u,0}=\sum_{v\in son_u} f_{v,1}\\
f_{u,1}=p_u+\sum_{v\in son_u} \min(f_{v,0},f_{v,1})
\]
这时候多了一个询问。这是有两种方法。一种是离线下来使用线段树合并优化 DP。另外一种就是这篇文章要讲解的方法。
我们多开一个状态(怎么想到的,我至今还没有明白):\(g_{i,0/1}\) 代表 \(i\) 的父亲节点必须选/不选,除了 \(i\) 这颗子树,其他子树的花费。
易得转移状态:
\[g_{u,0}=\sum_{v\in brother_u} f_{v,1}\\
g_{u,1}=p_{fa}+\sum_{v\in brother_u} \min(f_{v,0},f_{v,1})
\]
这样转移完全要 TLE
,然而我们可以直接从大的里面减去,就能求出来了。
接着神奇的事情就发生了。
注意到,
\[f_{fa,0}=f_{x,1}+g_{x,0}\\
f_{fa,1}=\min(f_{x,0},f_{x,1})+g_{x,1}
\]
把它重新写一下:
\[f_{fa,0}=\min(f_{x,0}+\infty ,f_{x,1}+g_{x,0})\\
f_{fa,1}=\min(f_{x,0}+g_{x,1},f_{x,1}+g_{x,1})
\]
再次注意到:\(\min(a,b)+c=\min(a+c,b+c)\),所以先做加法在取最小是满足分配律的,所以可以矩阵乘法。
定义:
\[\left[\begin{matrix}
a&b
\end{matrix}
\right]*\left[\begin{matrix}c&d\\e&f\end{matrix}\right]=\left[\begin{matrix}\min(a+c,a+d)&\min(b+e,b+f)\end{matrix}\right]
\]
类似的
\[\left[\begin{matrix}a&b\\c&d\end{matrix}\right]*\left[\begin{matrix}e&f\\g&h\end{matrix}\right]=\left[\begin{matrix}\min(a+e,b+g)&\min(a+f,b+h)\\\min(c+e,d+g)&\min(c+f,d+h)\end{matrix}\right]
\]
所以:
\[\left[\begin{matrix}f_{fa,0}&f_{fa_,1}\end{matrix}\right]=\left[\begin{matrix}f_{x,0}&f_{x,1}\end{matrix}\right]*\left[\begin{matrix}\infty&g_{x,1}\\g_{x,0}&g_{x,1}\end{matrix}\right]
\]
然后我们就对矩阵 \(\displaystyle\left[\begin{matrix}\infty&g_{x,1}\\g_{x,0}&g_{x,1}\end{matrix}\right]\) 进行树上倍增。每一次修改就是把某一个 \(f_{x,0/1}\) 设置为 \(\infty\)。然后稍微推一下就行了。
代码
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int __int128
namespace gtx{
// Fast IO
void read(int &x){
x = 0;int h = 1;char tmp;
do{tmp=getchar();if(tmp=='-')h*=-1;}while(!isdigit(tmp));
while(isdigit(tmp)) x*=10,x+=tmp-'0',tmp=getchar();
x*=h;
}
void read(char &x){do{x=getchar();}while(x==' '||x=='\n'||x=='\r');}
void write(char x){putchar(x);}
void write(int x){
if(x<0) putchar('-'),x=-x;int st[200]={0},tot=0;
do st[++tot]=x%10,x/=10; while(x);
while(tot){putchar(st[tot--]+'0');}
}
void write(int x,char y){if(x<4e18)write(x);else write((__int128)-1);write(y);}
#ifndef int
void read(long long &x){
x = 0;int h = 1;char tmp;
do{tmp=getchar();if(tmp=='-')h*=-1;}while(!isdigit(tmp));
while(isdigit(tmp)) x*=10,x+=tmp-'0',tmp=getchar();
x*=h;
}
void write(long long x){
if(x<0) putchar('-'),x=-x;int st[200]={0},tot=0;
do st[++tot]=x%10,x/=10; while(x);
while(tot){putchar(st[tot--]+'0');}
}
void write(long long x,char y){write(x);write(y);}
#endif
const int MAXN = 1e5+10;
const int LOGN = log2((long long)MAXN)+2;
int n,m,p[MAXN];char TTTTTTTTTTT;
vector<int> v[MAXN];
int f[MAXN][2],g[MAXN][2];
void dfs(int u,int fa){
for(int y:v[u]){
int v = y;
if(v==fa) continue;
dfs(v,u);
f[u][0] += f[v][1];
f[u][1] += min(f[v][0],f[v][1]);
}
f[u][1] += p[u];
for(int y:v[u]){
int v = y;
if(v==fa) continue;
g[v][0] = f[u][0]-f[v][1];
g[v][1] = f[u][1]-min(f[v][0],f[v][1]);
}
}
struct matrix22{
int a,b,c,d;
// a b
// c d
};
struct matrix12{
int a,b;
// a b
};
matrix12 operator * (matrix12 a,matrix22 b){
return {min(a.a+b.a,a.b+b.c),min(a.a+b.b,a.b+b.d)};
}
matrix22 operator * (matrix22 a,matrix22 b){
return {min(a.a+b.a,a.b+b.c),min(a.a+b.b,a.b+b.d),
min(a.c+b.a,a.d+b.c),min(a.c+b.b,a.d+b.d)};
}
const int INF = 0x3f3f3f3f3f3f3f3f;
matrix22 ST[MAXN][LOGN];
int fath[MAXN][LOGN],dep[MAXN];
void init(int x,int fa){
dep[x] = dep[fa]+1;
fath[x][0] = fa;
ST[x][0] = {INF,g[x][1],g[x][0],g[x][1]};
for(int i:v[x]){
if(i==fa) continue;
init(i,x);
}
}
void init_ST(){
for(int j = 1;j<LOGN;j++){
for(int i = 1;i<=n;i++){
fath[i][j] = fath[fath[i][j-1]][j-1];
ST[i][j] = ST[i][j-1]*ST[fath[i][j-1]][j-1];
}
}
}
matrix22 operator *= (matrix22 &x,matrix22 y){
x = x*y;
return x;
}
matrix22 climb(int &x,int w){
matrix22 ans = {0,INF,INF,0};
int k = 0;
while(w){
if(w&1){
ans *= ST[x][k];
x = fath[x][k];
}
k++;w>>=1;
}
return ans;
}
matrix12 LCA(int x,int y,matrix12 mx,matrix12 my){
if(dep[x]<dep[y]) swap(x,y),swap(mx,my);
auto ans = climb(x,dep[x]-dep[y]);
mx = mx*ans;
if(x==y){
if(my.a==INF) mx.a = INF;
else mx.b = INF;
return mx*climb(x,dep[x]-1);
}
for(int j = LOGN-1;~j;j--){
if(fath[x][j]!=fath[y][j]){
mx = mx*ST[x][j];
my = my*ST[y][j];
x = fath[x][j];
y = fath[y][j];
}
}
int lca = fath[x][0];
int f0 = f[lca][0]-f[x][1]-f[y][1];
int f1 = f[lca][1]-min(f[x][1],f[x][0])-min(f[y][1],f[y][0]);
f0 += mx.b;f0 += my.b;
f1 += min(mx.a,mx.b);f1 += min(my.a,my.b);
matrix12 o = {f0,f1};
return o*climb(lca,dep[lca]-1);
}
signed main(){
read(n);read(m);read(TTTTTTTTTTT);read(TTTTTTTTTTT);
for(int i = 1;i<=n;i++){
read(p[i]);
}
for(int i = 1;i<n;i++){
int a,b;
read(a);read(b);
v[a].push_back(b);
v[b].push_back(a);
}
dfs(1,0);
init(1,0);
init_ST();
for(int i = 1;i<=m;i++){
int a,x,b,y;
read(a);read(x);read(b);read(y);
matrix12 A = {f[a][0],f[a][1]};
matrix12 B = {f[b][0],f[b][1]};
if(x==1) A.a=INF;
else A.b=INF;
if(y==1) B.a=INF;
else B.b=INF;
auto tmp = LCA(a,b,A,B);
write(min(tmp.a,tmp.b),endl);
}
return 0;
}
}
signed main(){
// freopen("P5024_10.in","r",stdin);
// freopen("P5024.out","w",stdout);
// ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
int T = 1;
// gtx::read(T);
while(T--) gtx::main();
return 0;
}