「笔记」广义矩阵乘法与 DP
写在前面
广义矩阵乘法和 DP 有什么关系?
对于一类序列上的 DP 问题,当使用矩阵的形式表示 DP 状态,用矩阵运算表示转移方程时,若发现此时的矩阵运算满足广义矩阵乘法,则可以通过预处理区间矩阵乘积来实现带元素修改的 DP 问题,因此又称动态 DP。
广义矩阵乘法
对于一 \(p\times m\) 的矩阵 \(A\),与 \(m\times q\) 的矩阵 \(B\),定义广义矩阵乘法 \(A\times B = C\) 的结果是一个 \(p\times q\) 的矩阵 \(C\),满足:
其中 \(\oplus\) 与 \(\otimes\) 是两种二元运算。
考察这种广义矩阵乘法是否满足结合律:
观察上式可知,当\(\oplus\) 运算满足交换律,\(\otimes\) 运算满足交换律、结合律,且 \(\otimes\) 对 \(\oplus\) 存在分配律,即存在 \(\left(\bigoplus a\right)\otimes b = \bigoplus \left( a\otimes b \right)\) 时,广义矩阵乘法满足结合律。根据上述运算规律,对二式进行 \(\oplus\) 的交换后有:
维护 DP
以 P1115 最大子段和 为例。
给定一个长度为 \(n\) 的数列 \(a\),选出其中连续且非空的一段使得这段和最大。
\(1\le n\le 2\times 10^5\),\(-10^4\le a_i\le 10^4\)。
1S,128MB。
记 \(f_i\) 表示以 \(a_i\) 结尾的最大子段和,初始化 \(f_0 = -\infin\)。转移时考察是否要加上前面一段的贡献。前面一段的最大贡献为 \(f_{i-1}\)。则显然有:
定义 \(g\) 为 \(f\) 的前缀最大值,答案即为 \(g_n\)。算法总时间复杂度 \(O(n)\) 级别。
考虑加法运算运算与取 \(\max\) 运算的性质:发现取 \(\max\) 满足交换律与结合律,且加法对取 \(\max\) 满足分配率,即有:
考虑定义一种广义矩阵乘法 \(A\times B = C\),满足:
考虑将上述状态转移方程写成广义矩阵乘法形式。当从 \(i-1\) 转移到 \(i\) 时,显然有:
根据上述分析,显然该运算满足结合律,则有:
其中 \(\prod\) 表示连续广义矩阵乘法。预处理整个序列的广义矩阵乘积后,根据上式即得 答案 \(g_{n}\)。总复杂度 \(O\left(3^3\times n\right)\) 级别。
静态区间查询
SP1043 GSS1 - Can you answer these queries I
给定一个长度为 \(n\) 的数列 \(a\),给定 \(m\) 次询问。
每次询问给定区间 \([l,r]\),要求选出区间 \([l,r]\) 中连续且非空的一段使得这段和最大,输出最大子段和。
\(1\le n\le 5\times 10^4\),\(-15007\le a_i\le 15007\)。
原题面中并没有给出 \(m\) 的范围,此处根据实际测试情况推断 \(m\) 与 \(n\) 同阶。
230ms,1.46G。
发现上述题目中广义矩阵乘法做法 复杂度比直接做还劣 有着很好的扩展性。对于任意区间,预处理区间对应的广义矩阵乘积后即得该区间的最大子段和。
问题变为如何快速求得区间广义矩阵乘积。广义矩阵乘法满足结合律,且本题中没有修改操作,考虑对于每个位置 \(i\) 预处理以 \(i\) 为左端点的长度为 \(2\) 的幂的区间的广义矩阵乘积。回答询问时倍增拼凑区间即可。总时间复杂度 \(O\left(3^3 (n+m)\log n\right)\) 级别。
不用维护一堆乱七八糟的玩意,个人认为比隔壁直接上线段树好写(
//知识点:矩阵乘法,倍增
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 5e4 + 10;
const int kL = 3;
const LL kInf = 1e9 + 2077;
//=============================================================
int n, m;
struct Matrix {
LL a[kL][kL];
Matrix() {
memset(a, 0, sizeof (a));
}
void build() {
for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
}
Matrix operator * (const Matrix &b_) const {
Matrix ret;
memset(ret.a, 128, sizeof (ret.a));
for (int k = 0; k < kL; ++ k) {
for (int i = 0; i < kL; ++ i) {
for (int j = 0; j < kL; ++ j) {
ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
}
}
}
return ret;
}
} f[kN][21];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
LL Query(int l, int r) {
Matrix ans;
ans.a[0][0] = ans.a[0][1] = -kInf;
for (int i = 20; i >= 0; -- i) {
if (l + (1 << i) - 1 <= r) {
ans = ans * f[l][i];
l += (1 << i);
}
}
return ans.a[0][1];
}
//=============================================================
int main() {
n = read();
for (int i = 1; i <= n; ++ i) {
f[i][0].a[0][0] = f[i][0].a[2][0] = f[i][0].a[0][1] = f[i][0].a[2][1]
= read();
f[i][0].a[1][0] = f[i][0].a[0][2] = f[i][0].a[1][2] = -kInf;
}
for (int i = 1; i <= 20; ++ i) {
for (int j = 1; j + (1 << i) - 1 <= n; ++ j) {
f[j][i] = f[j][i - 1] * f[j + (1 << (i - 1))][i - 1];
}
}
m = read();
for (int i = 1; i <= m; ++ i) {
int l = read(), r = read();
printf("%lld\n", Query(l, r));
}
return 0;
}
动态区间查询
SP1716 GSS3 - Can you answer these queries III
给定一个长度为 \(n\) 的数列 \(a\),给定 \(m\) 次操作:
- 单点修改。
- 给定区间 \([l,r]\),要求选出区间 \([l,r]\) 中连续且非空的一段使得这段和最大,输出最大子段和。
\(1\le n,m\le 5\times 10^4\),\(-10^4\le a_i\le 10^4\)。
330ms,1.46G。
在上题的基础上加入了单点修改操作。发现每次修改仅会影响对应位置的矩阵,以及包含该位置的区间的广义矩阵乘积,考虑线段树维护广义矩阵乘积,每次修改仅需更新自叶到根的 \(\log n\) 个位置的对应区间。总时间复杂度 \(O\left(3^3 (n+m)\log n\right)\)。
//知识点:矩阵乘法,线段树
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 5e4 + 10;
const int kL = 3;
const LL kInf = 1e9 + 2077;
//=============================================================
int n, m, a[kN];
struct Matrix {
LL a[kL][kL];
Matrix() {
memset(a, 0, sizeof (a));
}
void build() {
for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
}
Matrix operator * (const Matrix &b_) const {
Matrix ret;
memset(ret.a, 128, sizeof (ret.a));
for (int k = 0; k < kL; ++ k) {
for (int i = 0; i < kL; ++ i) {
for (int j = 0; j < kL; ++ j) {
ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
}
}
}
return ret;
}
};
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
namespace Seg {
#define ls (now_<<1)
#define rs (now_<<1|1)
#define mid ((L_+R_)>>1)
Matrix sum[kN << 2];
void Pushup(int now_) {
sum[now_] = sum[ls] * sum[rs];
}
void Build(int now_, int L_, int R_) {
if (L_ == R_) {
sum[now_].a[0][0] = sum[now_].a[2][0] = sum[now_].a[0][1]
= sum[now_].a[2][1] = a[L_];
sum[now_].a[1][0] = sum[now_].a[0][2] = sum[now_].a[1][2] = -kInf;
return ;
}
Build(ls, L_, mid), Build(rs, mid + 1, R_);
Pushup(now_);
}
void Modify(int now_, int L_, int R_, int pos_, LL val_) {
if (L_ == R_) {
sum[now_].a[0][0] = sum[now_].a[2][0] = sum[now_].a[0][1]
= sum[now_].a[2][1] = val_;
return ;
}
if (pos_ <= mid) Modify(ls, L_, mid, pos_, val_);
else Modify(rs, mid + 1, R_, pos_, val_);
Pushup(now_);
}
Matrix Query(int now_, int L_, int R_, int l_, int r_) {
if (l_ == L_ && R_ == r_) return sum[now_];
if (r_ <= mid) return Query(ls, L_, mid, l_, r_);
if (l_ > mid) return Query(rs, mid + 1, R_, l_, r_);
return Query(ls, L_, mid, l_, mid) * Query(rs, mid + 1, R_, mid + 1, r_);
}
#undef ls
#undef rs
#undef mid
}
int Query(int l_, int r_) {
Matrix ans;
ans.a[0][0] = ans.a[0][1] = -kInf;
return (ans * Seg::Query(1, 1, n, l_, r_)).a[0][1];
}
//=============================================================
int main() {
n = read();
for (int i = 1; i <= n; ++ i) a[i] = read();
Seg::Build(1, 1, n);
m = read();
for (int i = 1; i <= m; ++ i) {
int opt = read(), x = read(), y = read();
if (opt == 0) Seg::Modify(1, 1, n, x, y);
if (opt == 1) printf("%d\n", Query(x, y));
}
return 0;
}
动态树形 DP
给定一棵 \(n\) 个点的树,点有点权。给定 \(m\) 次点权修改操作,求每次操作后整棵树的 最大点权独立集 的权值。
一棵树的独立集定义为满足任意一条边的两端点都不同时存在于集合中的树的一个点集,一个独立集的价值定义为集合中所有点的点权之和。
\(1\le n,m\le 10^5\),\(-100\le\) 点权 \(\le 100\)。
1S,256MB。
先考虑朴素 DP。钦定 1 为根,设 \(f_{u,0/1}\) 表示钦定点 \(u\) 不在/在 独立集时以 \(u\) 为根的子树的最大点权独立集的权值,显然有:
答案即为 \(\max (f_{1,0}, f_{1, 1})\)。
要求支持修改,又树的形态不变,考虑用树链剖分维护。但发现每个节点的 DP 值与其所有儿子有关,而树剖只能支持修改重链/子树信息。于是考虑对于每个节点,先将其轻儿子的贡献求和,再考虑其重儿子的贡献,使得可以通过对重链的修改/查询来维护上述信息。这种思想在 LCT 维护子树信息时也有所应用。
记 \(g_{u,0/1}\) 表示钦定 \(u\) 的重儿子不在独立集,点 \(u\) 不在/在 独立集时以 \(u\) 为根的子树的最大点权独立集的权值。记 \(\operatorname{s}_u\) 表示 \(u\) 的重儿子,显然有:
则对 \(f\) 的转移可以改写成下列形式:
出现了一个熟悉的形式,套路地定义广义矩阵乘法 \(A\times B = C\),满足:
根据上述转移方程,有下列关系成立。
于是可以考虑先预处理出 \(g\) 数组初始化转移矩阵,再使用线段树维护区间矩阵乘积。转移矩阵写在前面是因为 dfs 序列中深度较浅的点在前,转移矩阵写在前面可以直接按 dfs 序求得区间矩阵乘积并转移。若转移矩阵写在后面,需要先将区间内的元素顺序反转。经过预处理后,求得以 1 为根的重链对应区间的矩阵乘积,即得 \(f_{u,0}\) 与 \(f_{u,1}\)。正确性显然,重链一定以某叶节点为链底,以 1 为根的重链上所有轻儿子子树信息的并即为整棵树的信息。
考虑修改操作对哪些位置的 \(g\) 会产生影响。考虑其实际含义,\(g\) 维护的是轻儿子子树信息。被影响的节点显然为指定的修改位置 \(x\),以及子树中包含被修改位置,且为轻儿子的节点的父亲,后者可以通过从被修改位置不断跳重链来进行遍历。每次跳到的重链的顶的父亲,即为对应节点。
每次更新上述节点时先求得修改前以该节点的对应轻儿子的子树信息,修改子树中的节点后再求得该节点的对应轻儿子子树信息。根据两次求得的子树信息的差更新该节点的 \(g\),并将即将被修改的节点调整为当前节点。建议结合代码理解。
总复杂度 \(O(8n\log n + 8m\log^2 n)\) 级别。
//知识点:树形 DP,矩阵乘法,重链剖分,线段树
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 1e5 + 10;
const int kL = 2;
const int kInf = 1e9 + 2077;
//=============================================================
int n, m, e_num, head[kN], val[kN], v[kN << 1], ne[kN << 1];
int dfn_num, dfn[kN], id[kN], f[kN][2], g[kN][2];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
struct Matrix {
int a[kL][kL];
Matrix() {
memset(a, 0, sizeof (a));
}
void build() {
for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
}
Matrix operator * (const Matrix &b_) const {
Matrix ret;
memset(ret.a, 128, sizeof (ret.a));
for (int k = 0; k < kL; ++ k) {
for (int i = 0; i < kL; ++ i) {
for (int j = 0; j < kL; ++ j) {
ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
}
}
}
return ret;
}
} matrix[kN];
void Add(int u_, int v_) {
v[++ e_num] = v_;
ne[e_num] = head[u_];
head[u_] = e_num;
}
namespace Seg { //维护区间矩阵乘积
#define ls (now_<<1)
#define rs (now_<<1|1)
#define mid ((L_+R_)>>1)
Matrix sum[kN << 2];
void Pushup(int now_) {
sum[now_] = sum[ls] * sum[rs];
}
void Build(int now_, int L_, int R_) {
if (L_ == R_) {
sum[now_] = matrix[L_];
return ;
}
Build(ls, L_, mid), Build(rs, mid + 1, R_);
Pushup(now_);
}
void Modify(int now_, int L_, int R_, int pos_) {
if (L_ == R_) {
sum[now_] = matrix[pos_];
return ;
}
if (pos_ <= mid) Modify(ls, L_, mid, pos_);
else Modify(rs, mid + 1, R_, pos_);
Pushup(now_);
}
Matrix Query(int now_, int L_, int R_, int l_, int r_) {
if (l_ == L_ && R_ == r_) return sum[now_];
if (r_ <= mid) return Query(ls, L_, mid, l_, r_);
if (l_ > mid) return Query(rs, mid + 1, R_, l_, r_);
return Query(ls, L_, mid, l_, mid) * Query(rs, mid + 1, R_, mid + 1, r_);
}
#undef ls
#undef rs
#undef mid
}
namespace HLD {
int fa[kN], sz[kN], son[kN], dep[kN], top[kN], end[kN];
void Dfs1(int u_, int fa_) {
sz[u_] = 1;
fa[u_] = fa_;
f[u_][1] = val[u_];
dep[u_] = dep[fa_] + 1;
for (int i = head[u_]; i; i = ne[i]) { //预处理 f
int v_ = v[i];
if (v_ == fa_) continue ;
Dfs1(v_, u_);
sz[u_] += sz[v_];
if (sz[v_] > sz[son[u_]]) son[u_] = v_;
f[u_][0] += std::max(f[v_][0], f[v_][1]);
f[u_][1] += f[v_][0];
}
}
void Dfs2(int u_, int top_) {
dfn[u_] = ++ dfn_num;
id[dfn_num] = u_;
top[u_] = top_;
Chkmax(end[top_], dfn_num);
if (son[u_]) Dfs2(son[u_], top_);
g[u_][1] = val[u_];
for (int i = head[u_]; i; i = ne[i]) { //预处理 g
int v_ = v[i];
if (v_ == fa[u_] || v_ == son[u_]) continue ;
Dfs2(v_, v_);
g[u_][0] += std::max(f[v_][0], f[v_][1]);
g[u_][1] += f[v_][0];
}
}
void Modify(int u_, int val_) {
matrix[dfn[u_]].a[1][0] += val_ - val[u_]; //修改 u_ 的 g[u_][1]
val[u_] = val_; //更新点权
while (u_) { //u_ 不断上跳
Matrix old = Seg::Query(1, 1, n, dfn[top[u_]], end[top[u_]]); //以 top[u_] 为根的子树的信息
Seg::Modify(1, 1, n, dfn[u_]); //修改节点 u_ 的信息(单点修改矩阵)
Matrix newone = Seg::Query(1, 1, n, dfn[top[u_]], end[top[u_]]); //更新后以 top[u_] 为根的子树的信息
u_ = fa[top[u_]]; //更新轻儿子 u_ 的父亲的 g
//注意下文的赋值还未更新到线段树上,上面需要求得未修改之前的信息,更新线段树信息要在之后进行
matrix[dfn[u_]].a[0][0] += std::max(newone.a[0][0], newone.a[1][0]) -
std::max(old.a[0][0], old.a[1][0]);
matrix[dfn[u_]].a[0][1] = matrix[dfn[u_]].a[0][0];
matrix[dfn[u_]].a[1][0] += newone.a[0][0] - old.a[0][0];
}
}
int Query() { //求得以 1 为根的重链对应区间的矩阵乘积,即得答案
//重链一定以某叶节点为链底,以 1 为根的重链上所有轻儿子子树信息的并即为整棵树的信息。
Matrix ans = Seg::Query(1, 1, n, 1, end[1]);
return std::max(ans.a[0][0], ans.a[1][0]);
}
}
//=============================================================
int main() {
n = read(), m = read();
for (int i = 1; i <= n; ++ i) val[i] = read();
for (int i = 1; i < n; ++ i) {
int u_ = read(), v_ = read();
Add(u_, v_), Add(v_, u_);
}
HLD::Dfs1(1, 0), HLD::Dfs2(1, 1);
for (int i = 1; i <= n; ++ i) { //构造转移矩阵
matrix[dfn[i]].a[0][0] = matrix[dfn[i]].a[0][1] = g[i][0];
matrix[dfn[i]].a[1][0] = g[i][1], matrix[dfn[i]].a[1][1] = -kInf;
}
Seg::Build(1, 1, n);
while (m --) {
int x_ = read(), y_ = read();
HLD::Modify(x_, y_);
printf("%d\n", HLD::Query());
}
return 0;
}
例题
先咕着。
写在最后
鸣谢:
参考: