【对不同形式矩阵的总结】WC 2009 最短路径问题(线段树+矩阵乘法)
题意
题目链接:https://www.luogu.org/problem/P4150
一个 \(6\times n\) 的网格图,每个格点有一个初始权值。有两种操作:
-
修改一个格子的权值
-
求两个格子之间的最短路的权值。
\(1 \leq n \leq 10^5\)
引言
显然这种题目肯定是要用线段树了,对于每一个线段树区间,我们考虑开三个 \(6\times 6\) 的数组,分别表示从左边第 \(i\) 行走到左边第 \(j\) 行、右边第 \(i\) 行走到右边第 \(j\) 行、左边第 \(i\) 行走到右边第 \(j\) 行的最小代价。这个转移是从 \((i,j),(j,k)\) 转移到 \((i,k)\) 的,和矩乘的形式有点像。首先,我们总结一下不同形式的矩阵。
正文
乘加矩阵
数字排成一个 \(n\) 行 \(m\) 列的数阵被称为矩阵,我们在若干计数问题中已经遇到过一类常见的乘加矩阵,一般可以定义如下的运算:
运算
-
加法运算
设 \(A,B,C\) 是三个 \(n\) 行 \(m\) 列的矩阵,若 \(C=A+B\) ,则 \(C_{i,j}=A_{i,j}+B_{i,j}\)。
-
乘法运算
设 \(A\) 为 \(n\) 行 \(p\) 列的矩阵, \(B\) 为 \(p\) 行 \(m\) 列的矩阵, \(C\) 为 \(n\) 行 \(m\) 列的矩阵,若 \(C = AB\) ,则 \(\displaystyle C_{i,j}=\sum_{k=1}^pA_{i,k}B_{k,j}\) 。
性质
我们不难发现这样定义运算,乘加矩阵是满足结合律和分配律,不满足交换律的(下文将给出统一证明)。
加min矩阵
我们可以将乘加矩阵的乘法换成加法,把加法换成求 \(\min\) ,就可以类似的定义出一类新的矩阵,加\(\min\)矩阵,运算定义如下:
运算
-
“加法”运算
设 \(A,B,C\) 是三个 \(n\) 行 \(m\) 列的矩阵,若 \(C=A+B\) ,则 \(C_{i,j}=\min\{A_{i,j},B_{i,j}\}\)。
-
“乘法”运算
设 \(A\) 为 \(n\) 行 \(p\) 列的矩阵, \(B\) 为 \(p\) 行 \(m\) 列的矩阵, \(C\) 为 \(n\) 行 \(m\) 列的矩阵,若 \(C = AB\) ,则 \(\displaystyle C_{i,j}=\min_{k=1}^p\{A_{i,k}+B_{k,j}\}\) 。
这类矩阵在转移是加\(\min\)形式的时候有奇效。
性质
加\(\min\)矩阵也是满足结合律和分配律,不满足交换律的。
矩阵的统一性质
矩阵统一定义
规定“乘法”和“加法”含义之后,就可以重新定义出一类矩阵,我们设元素 \(x,y\) 加法运算为 \(f(x,y)\) ,元素 \(x,y\) 的乘法运算为 \(g(x,y)\) ,在此意义下,元素的累和为 \(F_{i=lower}^{upper}\) 。
我们可以由此得到新定义矩阵的加法和乘法:
-
“加法”运算
设 \(A,B,C\) 是三个 \(n\) 行 \(m\) 列的矩阵,若 \(C=A+B\) ,则 \(C_{i,j}=f(A_{i,j},B_{i,j})\)。
-
“乘法”运算
设 \(A\) 为 \(n\) 行 \(p\) 列的矩阵, \(B\) 为 \(p\) 行 \(m\) 列的矩阵, \(C\) 为 \(n\) 行 \(m\) 列的矩阵,若 \(C = AB\) ,则 \(\displaystyle C_{i,j}=F_{k=1}^pg(A_{i,k},B_{k,j})\) 。
结论
观察上面两类常见矩阵的共同点,我们可以得到一个比较普遍的结论:
如果新定义的“乘法”对新定义的“加法”满足分配律的话,那么由该乘法和加法定义出的矩阵满足结合律和分配律。
形式的说, \(\displaystyle g(x,F_{i=lower}^{upper}y_i)=\displaystyle F_{i=lower}^{upper}g(x,y_i)\) 。
实际上,我们常说的分配律也就是指提取 \(\displaystyle\sum\) 的性质,可能说的不大清楚,我们举乘加矩阵和加\(\min\)矩阵为例(下文再证明该结论):
- (乘加矩阵)由于 \(\displaystyle x\sum_{i}y_i=\displaystyle \sum_{i}xy_i\) ,那么由数学上我们通常说的乘法和加法定义出的矩阵是满足结合律和分配律的。
- (加\(\min\)矩阵)由于 \(\displaystyle x+\min_{i}\{y_i\}= \min_{i}\{x+y_i\}\) ,那么由加法和 \(\min\) 运算定义出来的矩阵也是满足结合律和分配律的。
- (我也不知道有没有这种矩阵)由于 \(\displaystyle x+\prod_{i}y_i\neq\displaystyle \prod_{i}(x+y_i)\) ,所以把数学上的乘法当成元素的加法,把数学上的加法当成元素的乘法,如此定义出的矩阵不满足结合律和分配律(这个矩阵的加法为 \(\displaystyle C_{i,j}=A_{i,j}B_{i,j}\) ,乘法为 \(\displaystyle C_{i,j}=\prod_{k=1}^p(A_{i,k}+B_{k,j})\) ,随便举了一个例子,为了说明一下结论的普适性)。
该结论可以作为判断新定义出的矩阵是否满足结合律与分配律的条件。
为了方便起见,我们还是令 \(+\) 表示矩阵元素的“加法”运算, \(\times\) 表示矩阵元素的“乘法”运算。类似的, \(\displaystyle\sum\) 表示在新定义的加法下的累和。令新定义出的“加法”和“乘法”满足上述的分配律,我们来证明以此定义出的矩阵满足结合律和分配律。
结合律
设 \(A\) 为 \(n\) 行 \(p\) 列的矩阵, \(B\) 为 \(p\) 行 \(q\) 列的矩阵, \(C\) 为 \(q\) 行 \(m\) 列的矩阵,有 \((AB)C=A(BC)\) 。
证明:
得证,需要用到分配律提取 \(\displaystyle \sum\) 。
分配律
设 \(A\) 为 \(n\) 行 \(p\) 列的矩阵, \(B^{[1]}\) 到 \(B^{[c]}\) 均为 \(p\) 行 \(m\) 列的矩阵,有 \(\displaystyle A\sum_{a=1}^cB^{[i]}=\sum_{a=1}^cAB^{[i]}\) 。
证明:
得证,同样需要用到分配律提取 \(\displaystyle \sum\) 。
到此为止,我们已经证明了乘加矩阵,加\(\min\)矩阵,以及对乘加可分配的矩阵存在结合律和分配律。而矩阵的优化都依赖于结合律和分配律,比如常见的矩阵快速幂,预处理矩阵乘积等技巧,没有这两个定律的支撑都是一纸空谈。
本题思路
回到本题,我们发现每个线段树节点上的 \(6\times 6\) 的数组其实就是一个 \(6\times 6\) 的加\(\min\)矩阵。两个线段树节点合并的函数可以写的非常简单:
node operator + (const node &_)const
{
static node res;
res.A = A + C * _.A * (~C);
res.B = _.B + (~_.C) * B * _.C;
res.C = C * (MT1 + _.A * B) * _.C;
return res;
}
这里的取反符号表示矩阵行列翻转,在本题里表示从左边走到右边翻转成从右边走到左边。
\(MT1\) 表示单位 \(1\) 矩阵,即乘上去不变的矩阵,加\(\min\)矩阵的单位 \(1\) 矩阵大概长成这个样子:
比较好理解,注意结合矩阵的结合律和分配律。求最优解中的加号对应计数中的加法原理(并列),乘号对应乘法原理(分步)。
而对于询问,设起点在终点左边,称与起点横坐标相同的点形成的线为起点线,与种点横坐标相同的点形成的线为种点线,则可以把过程分成三步:
-
先在起点线右边兜一圈回来,再在起点线左边兜一圈回来;
-
兜到终点线;
-
先在终点线右边兜一圈回来,再在终点线左边兜一圈回来。
注意行数为 \(6\) ,一三步最多只能执行一次,就算可以执行多次,套上矩阵快速幂就行了(感受到矩阵的优美了没有)。
当然一三步可以不走,体现在状态里就是加上一个单位 \(1\) 矩阵,表示这步不执行。
详见代码,矩阵的好处之一就是写在代码中短而优雅。
代码
#pragma GCC optimize(3)
#include<bits/stdc++.h>
#define FOR(i, x, y) for(int i = (x), i##END = (y);i <= i##END; ++i)
#define DOR(i, x, y) for(int i = (x), i##END = (y);i >= i##END; --i)
template<typename T, typename _T>inline bool chk_min(T &x, const _T y){return y < x? x = y, 1 : 0;}
template<typename T, typename _T>inline bool chk_max(T &x, const _T y){return x < y? x = y, 1 : 0;}
typedef long long ll;
const int N = 1e5 + 5;
struct Matrix
{
int a[6][6];
Matrix() {asn0();}
void asn0() {FOR(i, 0, 5) FOR(j, 0, 5) a[i][j] = 2e9;}
void asn1() {FOR(i, 0, 5) FOR(j, 0, 5) a[i][j] = 2e9 * (i != j);}
int *operator [](const int x) {return a[x];}
Matrix operator ~() const
{
Matrix res = (*this);
FOR(i, 0, 5) FOR(j, 0, i - 1) std::swap(res.a[i][j], res.a[j][i]);
return res;
}
Matrix operator + (const Matrix _) const
{
Matrix res;
FOR(i, 0, 5) FOR(j, 0, 5) res.a[i][j] = std::min(a[i][j], _.a[i][j]);
return res;
}
Matrix operator *(const Matrix _) const
{
Matrix res;
FOR(i, 0, 5) FOR(j, 0, 5) FOR(k, 0, 5) chk_min(res.a[i][j], a[i][k] + _.a[k][j]);
return res;
}
void operator *=(const Matrix _)
{
(*this) = (*this) * _;
}
};
Matrix MT1;
struct node
{
Matrix A, B, C;
void reset(int *a)
{
static int _s[8], *s = _s + 1;
s[-1] = 0; FOR(i, 0, 5) s[i] = s[i - 1] + a[i];
FOR(i, 0, 5) FOR(j, 0, 5)
A[i][j] = B[i][j] = C[i][j] = (i > j ? s[i] - s[j - 1] : s[j] - s[i - 1]);
}
node operator + (const node &_)const
{
static node res;
res.A = A + C * _.A * (~C);
res.B = _.B + (~_.C) * B * _.C;
res.C = C * (MT1 + _.A * B) * _.C;
return res;
}
};
node nd[N << 2];
int a[N][6];
int n, q;
void push_up(int k) {nd[k] = nd[k << 1] + nd[k << 1 | 1];}
void build(int k, int a[N][6], int l = 1, int r = n)
{
if(l == r) {nd[k].reset(a[l]); return;}
int mid = (l + r) >> 1;
build(k << 1, a, l, mid);
build(k << 1 | 1, a, mid + 1, r);
push_up(k);
}
void update(int k, int x, int a[6], int l = 1, int r = n)
{
if(l == r) {nd[k].reset(a); return;}
int mid = (l + r) >> 1;
if(x <= mid) update(k << 1, x, a, l, mid);
else update(k << 1 | 1, x, a, mid + 1, r);
push_up(k);
}
node query(int k, int L, int R, int l = 1, int r = n)
{
if(L <= l && r <= R) return nd[k];
int mid = (l + r) >> 1;
if(R <= mid) return query(k << 1, L, R, l, mid);
else if(L > mid) return query(k << 1 | 1, L, R, mid + 1, r);
else return query(k << 1, L, R, l, mid) + query(k << 1 | 1, L, R, mid + 1, r);
}
int solve(int l, int a, int r, int b)
{
Matrix L, M, R, LM, MR, res;
M = query(1, l, r).C;
LM = query(1, 1, r).B;
MR = query(1, l, n).A;
if(l > 1) L = query(1, 1, l - 1).B;
else L = MT1;
if(r < n) R = query(1, r + 1, n).A;
else R = MT1;
res = (MT1 + MR * L) * M * (MT1 + R * LM);
return res[a][b];
}
int main()
{
MT1.asn1();
scanf("%d", &n);
FOR(j, 0, 5) FOR(i, 1, n)scanf("%d", &a[i][j]);
build(1, a);
scanf("%d", &q);
while(q--)
{
int cmd, x, y, p, q, c;
scanf("%d", &cmd);
if(cmd == 1)
{
scanf("%d%d%d", &x, &y, &c);
x--;
a[y][x] = c;
update(1, y, a[y]);
}
else
{
scanf("%d%d%d%d", &x, &y, &p, &q);
x--, p--;
if(y > q) std::swap(x, p), std::swap(y, q);
printf("%d\n", solve(y, x, q, p));
}
}
return 0;
}