浅谈 DDP 与 广义矩阵乘法
浅谈 DDP 与 广义矩阵乘法
更好的阅读体验戳此进入
引入
例题 #1
题目 SP1043 GSS1 - Can you answer these queries I。
题面:对于序列 $ a_n \(,\) q $ 次询问求 $ \left[ l_i, r_i \right] $ 的区间最大子段和。
$ 1 \le n \le 5 \times 10^4, a_i \le 15007 \(。\) q $ 在 int
范围内。
有显然的 DP,令 $ dp(i) $ 表示以 $ a_i $ 结尾的区间 $ \left[ 1, i \right] $ 的最大子段和。
一般的区间最大子段和转移为 $ dp(i) = \max(dp(i - 1) + a_i, a_i) $,对于本题则为:
但是每次询问都重新维护一遍显然会寄,所以考虑一些人类智慧。
广义矩阵乘法
首先对于普通的矩阵乘法,我们随便举个例子:
我们考虑把这玩意省略的东西写全一点:
我们发现这里一共有两种符号,$ + $ 和 $ \times $,考虑替换这两个符号,假设我们存在两个二元运算符 $ \oplus $ 和 $ \otimes $,替换一下就是:
这就是广义矩阵乘法了。
然后为了让这东西能有点用,所以现在我们还需要让广义矩阵乘法具有结合律,也就是:
或者写成:
展开一下有:
不难发现,如果 $ \otimes $ 对 $ \oplus $ 存在分配律,或者说满足 $ (a \oplus b) \otimes c = (a \otimes c) \oplus (b \otimes c) $,那么把上面的两个式子拆开之后,最后就会发现实际上是相同的。
所以我们便可得出这个结论——如果 $ \otimes $ 对 $ \oplus $ 存在分配律,那么形成的广义矩阵乘法矩阵之间就满足结合律。
DDP
为了便于理解,我们先引入一个更简单的例子。
例题 #0
题面:求斐波那契数列。
十分显然的 $ dp(i) = dp(i - 1) + dp(i - 2) $。
我们考虑一些让复杂度更高奇妙的方法。
我们考虑把这个转移矩阵记作 $ X $,转移的起始和终点分别写成一个矩阵,为了凑齐矩阵,可以把终点塞一个 $ 0 $,或者 $ dp(i - 1) $ 这种无关的值,如替换为 $ 0 $,则有:
此时的 $ \otimes $ 就是 $ \times \(,\) \oplus $ 就是 $ + $,也就是说我们这个矩阵之间的乘法就是经典的矩阵乘法,则不难算出我们需要的矩阵 $ X $ 为:
也就是说:
于是我们不如继续尝试把 $ \texttt{LHS} $ 的带有 $ dp $ 的式子接着拆下去:
这东西虽然可以转移的,但是自己观察一下,这东西的形式和最开始的柿子好像形式很相似,于是我们考虑如果不去选择 $ 0 $ 而去选择 $ dp(i - 1) $ 可能会更方便一些(当然选择 $ 0 $ 也是可以转移的,只是会更复杂一些)。所以我们最开始的式子就变成了:
然后这样也很简单就能求出来 $ X $ 了:
然后我们继续往下转移:
此时我们便能发现一些奇妙的推导:
于是这个时候我们便可以矩阵快速幂(因此我们才需要使广义矩阵之间满足结合律,否则用不了矩阵快速幂或者线段树猫树之类的)把原来的线性求,变成现在的 $ \log $ 求了(虽然常数大了点)。
当然这只是最最基础的 DDP。
例题 #0.5
题目:LG-P1115 最大子段和。
最大字段和模板题,因为比较水所以这里就不叙述的太详细了,令 $ dp(i) $ 表示以 $ a_i $ 结尾的区间 $ \left[ 1, i \right] $ 的最大子段和,有转移 $ dp(i) = \max(dp(i - 1) + a_i, a_i) $,但是这样求完是还需要再扫一遍取 $ \max $,这样我们就不太好用矩阵转移,于是考虑再令 $ ans(i) = \max(dp(i), ans(i - 1)) $。于是我们不难推出如下转移:
同时注意这里的矩阵是广义矩阵乘法,$ \otimes $ 为 $ + \(,\) \oplus $ 为 $ \max $,显然 $ + $ 对 $ \max $ 有分配律,则矩阵有结合律。
然后像之前一样拆下去就可以得到一大串矩阵相乘,成功大幅增加常数成功提高了转移的扩展性。
例题 #1
现在让我们回到文章伊始提到的例题,这里我们再回顾一下:
题目 SP1043 GSS1 - Can you answer these queries I。
题面:对于序列 $ a_n \(,\) q $ 次询问求 $ \left[ l_i, r_i \right] $ 的区间最大子段和。
$ 1 \le n \le 5 \times 10^4, a_i \le 15007 \(。\) q $ 在 int
范围内。
有显然的 DP,令 $ dp(i) $ 表示以 $ a_i $ 结尾的区间 $ \left[ 1, i \right] $ 的最大子段和。
一般的区间最大子段和转移为 $ dp(i) = \max(dp(i - 1) + a_i, a_i) $,对于本题则为:
和上一题一样,我们还是考虑令:
$ i = l $ 就是个特判可以不用考虑,考虑把 $ l \lt i \le r $ 的式子变成广义矩乘:
然后继续往下拆:
一直拆到边界,就是:
因为我们要的答案为 $ ans(r) $,所以把 $ i $ 换成 $ r $。并且显然有 $ ans(l) = dp(l) = a_l $,则原式化为:
可以发现我们想要得到 $ ans(r) $,只需要从已知具体值的 $ \begin{bmatrix} a_l & a_l & 0 \end{bmatrix} $,乘上一大堆矩阵。
为了便于书写,我们记:
则:
而在实际使用中我们还有很多小技巧与转化来让我们实现起来更简单:
如对于矩阵我们一般通过新开一个结构体然后重载 *
实现,这个时候我们可以考虑把初始矩阵补齐成 $ 3 \times 3 $ 的,这样最终答案也会变成 $ 3 \times 3 $,我们便只需要维护长宽均为 $ 3 $ 的定大小的矩阵即可,至于补齐的值是什么这个实际上是任意的,因为我们不需要这些补齐的值算出来的答案,所以最终计算出什么也是无关紧要的,计算后直接输出答案矩阵的 $ (0, 0) $ 位置,也就是 $ ans(r) $ 即可。
举个例子:
然后在重载矩阵的乘法的时候,我们应该注意,是否有可能会超出 int
或 long long
等,所以在我们取 $ \max $ 的时候可以考虑额外加一个 $ -\infty $,也就是当有两个数相加比 $ -\infty $ 更小的时候那么其中一定至少有一个数为 $ -\infty $,这个时候我们直接将答案变成 $ -\infty $ 即可。
然后如果严格按照上面的算法实现出来后,会发现一个问题,在某些特殊的时候答案会错误,经过 debug 我们发现当 $ l = r $ 的时候,本来应该只有最初的初始矩阵,但是按照我们的算法会多出来一个 $ \left[ l + 1, r \right] $ 的查询,但此时 $ l + 1 \gt r $,我们当然可以考虑特判这个情况,但是我们实际上还可以再次转换一下思路,让这个东西更泛用,更美观。
我们考虑能否将初始矩阵 $ \begin{bmatrix}a_l & a_l & 0\end{bmatrix} $ 也写成 $ A_l $,然后在最前面再弄一个新的初始矩阵 $ T $,让它和 $ A_l $ 运算后会变成我们想要的 $ \begin{bmatrix}a_l & a_l & 0\end{bmatrix} $,于是很容易写出:
不难算出:
所以我们的式子就最终变成了:
现在这个式子就很好维护了。
这个时候我们观察那一大串 $ A $,不难发现这玩意好像就是把 $ \left[ l, r \right] $ 的 $ A $ 乘到一起了,而且每次询问会有不同的 $ l, r $,这不就是区间查询吗,我们于是可以考虑用一些妙妙数据结构去维护这玩意。
此时我们可能会想到树状数组,但是仔细思考一下就会发现是不行的。众所周知树状数组在运算上有一个类似差分的过程,也就是说需要有逆元的存在,但是在我们定义的矩阵和矩阵之间的广义矩乘构成的群中是否存在逆元呢?显然我们都知道矩阵求逆的一个必要条件是行列式不为 $ 0 $,这东西是否不为 $ 0 $ 是无法确定的,也就是说不一定会存在逆元,所以不能用树状数组或者前缀积之类的东西区间查询。
一个比较显然的东西就是线段树,每个节点存的是一个矩阵,或者说对应区间的矩阵乘起来(看到这里大概就又能意识到让矩阵满足结合律的重要性了)。当然这里的乘仍然指的是 $ \otimes $ 为 $ + \(,\) \oplus $ 为 $ \max $ 的广义矩乘。用线段树维护的话最后的复杂度大概是 $ O(\xi^3 (n + q)) $,其中 $ \xi $ 为矩阵的行数或列数,这里 $ \xi = 3 $。
对于本题来说,我们可以发现是多次查询没有修改,所以也可以考虑使用猫树,使查询复杂度变为 $ O(1) $,这样总复杂度就变成了 $ O(\xi^3(n \log n + q)) $,具体介绍可以看咕咕日报 一种神奇的数据结构——猫树。
当然用平衡树,分块之类的也能做,这里不再过多赘述。
这里提供一个 DDP + 猫树 实现的 GSS1。(注意猫树节点数需要补齐到 $ 2^t $,且要开好数组范围的大小,否则很容易 RE)
#define _USE_MATH_DEFINES
#include <bits/extc++.h>
#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}
using namespace std;
using namespace __gnu_pbds;
mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}
typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;
#define MAXN (int)(5e4)
#define INF (0x3f3f3f3f)
template< typename T = int >
inline T read(void);
int N, M;
int lg[MAXN << 3];
int pos[MAXN << 1];
int a[MAXN];
struct Matrix3{
int val[3][3];
Matrix3(int v00, int v01, int v02, int v10, int v11, int v12, int v20, int v21, int v22):
val{
{v00, v01, v02},
{v10, v11, v12},
{v20, v21, v22}
}{;}
Matrix3(int val[][3]){for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)this->val[i][j] = val[i][j];}
Matrix3(void) = default;
friend const Matrix3 operator * (const Matrix3 &x, const Matrix3 &y){
int val[3][3];
for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)val[i][j] = -INF;
for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)for(int p = 0; p <= 2; ++p)
val[i][j] = max({val[i][j], x.val[i][p] + y.val[p][j], -INF});
return Matrix3(val);
}
void Print(void){
for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)
printf("%d%c", val[i][j], j == 2 ? '\n' : ' ');
}
}mt[MAXN];
class CatTree{
private:
Matrix3 tr[30][MAXN << 1];
#define MID ((gl + gr) >> 1)
#define LS (p << 1)
#define RS (LS | 1)
public:
void Build(int p = 1, int dep = 1, int gl = 1, int gr = N){
// printf("Building l = %d, r = %d, p = %d, dep = %d\n", gl, gr, p, dep);
if(gl == gr)return pos[gl = gr] = p, void();
tr[dep][MID] = mt[MID];
tr[dep][MID + 1] = mt[MID + 1];
//Tips: Matrix operation does not have the commutative law.
for(int i = MID - 1; i >= gl; --i)tr[dep][i] = mt[i] * tr[dep][i + 1];
for(int i = MID + 1 + 1; i <= gr; ++i)tr[dep][i] = tr[dep][i - 1] * mt[i];
Build(LS, dep + 1, gl, MID);
Build(RS, dep + 1, MID + 1, gr);
}
Matrix3 Query(int l, int r){
if(l == r)return mt[l = r];
int dep = lg[pos[l]] - lg[pos[l] ^ pos[r]];
// tr[dep][l].Print(), tr[dep][r].Print();
return tr[dep][l] * tr[dep][r];
}
}ct;
int main(){
N = read();
int rN(1); while(rN < N)rN <<= 1;
for(int i = 1; i <= N; ++i)
a[i] = read(),
mt[i] = Matrix3(0, -INF, -INF, a[i], a[i], -INF, a[i], a[i], 0);
N = rN;
lg[0] = lg[1] = 1;
for(int i = 2; i <= (N << 2) + 10; ++i)lg[i] = lg[i >> 1] + 1;
ct.Build();
M = read();
for(int m = 1; m <= M; ++m){
int l = read(), r = read();
auto ans = Matrix3(-INF, 0, 0, -INF, -INF, -INF, -INF, -INF, -INF) * ct.Query(l, r);
printf("%d\n", ans.val[0][0]);
}
fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
return 0;
}
template < typename T >
inline T read(void){
T ret(0);
short flag(1);
char c = getchar();
while(c != '-' && !isdigit(c))c = getchar();
if(c == '-')flag = -1, c = getchar();
while(isdigit(c)){
ret *= 10;
ret += int(c - '0');
c = getchar();
}
ret *= flag;
return ret;
}
例题 #2
题面:给定长度为 $ N $ 的仅包含 0
,1
,?
的字符串 $ S $,给定 $ Q $ 组询问 $ (x_1, c_1), (x_2, c_2), \cdots, (x_q, c_q) $,每次将原字符串中 $ x_i $ 位置的字符改为 $ c_i $,然后输出 $ S $ 有多少种非空子串,?
需任意替换为 0
或 1
。
$ 1 \le N, Q \le 10^5, 1 \le x_i \le N $。
我们先不考虑修改,思考对于这样一个字符串能有多少种子串,显然这个东西是个 DP。
令 $ dp(i, 0 / 1) $ 表示考虑前 $ i $ 位以 $ 0 $ 或 $ 1 $ 结尾的方案数,显然有如下转移:
若 $ S_i = 1 $,有:
若 $ S_i = 0 $,有:
若 $ S_i = \texttt{?} $,有:
应该不难理解吧?如果状态里是当前这一位,那么可以接到上一个状态任意的结尾,接上这一位之后都会是一个符合要求的新串,或者丢弃以前的直接让这一位成为一个新串。反之就直接由上一次的转移而来,把这一位丢弃,而 $ \texttt{?} $ 可以认为是任意一个,所以可以接在任意一个的后面。
这里有一个细节可以解释一下,在这一位和状态相同的时候我们会 $ +1 $,但是如果之前就有一个孤立的(这时很显然会有的),比如 $ 0 $,那么我们这一位又一个新的 $ 0 $ 难道不会重吗?这显然是不会的,因为上一个状态的 $ 0 $ 在这一次的讨论中已经变成 $ 00 $ 了,只有在讨论到上界的时候才会存在真正的孤立的 $ 0 $ 或 $ 1 $,于是显然可以保证不重不漏。
这么一大坨分类讨论显然很难维护,那么我们尝试把这些缩到一起:
关于式子中的方括号括起来的表达式,这个一般用来表示如果里面的表达式为 $ \texttt{true} $,那么这个东西值为 $ 1 $,反之为 $ 0 $。
现在,我们就可以对这个简洁的式子搞事情了,因为我们后面要有很多次修改,这样的话显然可以考虑 DDP。
我们再次尝试设计出转移的矩阵:
不难算出:不难算出我们算不出来。
于是考虑再加一维
然后尝试推一下 $ T $,这东西个人认为实际上就是个蒙猜凑,总之最后弄一下就能出来这个(注意这里我们用的是普通矩乘规则,非广义):
(不要问我为什么这题用不到广义矩乘还要在前面说一大堆广义矩乘的内容,我记错了,技多不压身)
然后类比之前的过程,令:
那么一直拆下去,一定有:
依然尝试找到一个这样的等式:
所以可以最终化为:
终于推完了,不难发现我们每次的修改就是 $ A_i $ 的值,然后查询整个区间。
这玩意有修改显然用不了猫树了,老老实实写线段树吧。。
复杂度大概是 $ O(\xi^3(n + q \log n)) $,其中 $ \xi = 3 $。
#define _USE_MATH_DEFINES
#include <bits/extc++.h>
#define PI M_PI
#define E M_E
#define npt nullptr
#define SON i->to
#define OPNEW void* operator new(size_t)
#define ROPNEW(arr) void* Edge::operator new(size_t){static Edge* P = arr; return P++;}
using namespace std;
using namespace __gnu_pbds;
mt19937 rnd(random_device{}());
int rndd(int l, int r){return rnd() % (r - l + 1) + l;}
bool rnddd(int x){return rndd(1, 100) <= x;}
typedef unsigned int uint;
typedef unsigned long long unll;
typedef long long ll;
typedef long double ld;
#define MAXN (int)(1e5 + 100)
#define MOD (int)(998244353)
template< typename T = int >
inline T read(void);
int N, Q;
int S[MAXN];
struct Matrix3{
int val[3][3];
Matrix3(int v00, int v01, int v02, int v10, int v11, int v12, int v20, int v21, int v22):
val{
{v00, v01, v02},
{v10, v11, v12},
{v20, v21, v22}
}{;}
Matrix3(int S):
val{
{1, S != 0, 0},
{S != 1, 1, 0},
{S != 1, S != 0, 1}
}{;}
Matrix3(int val[][3]){for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)this->val[i][j] = val[i][j];}
Matrix3(void) = default;
friend const Matrix3 operator * (const Matrix3 &x, const Matrix3 &y){
int val[3][3]; memset(val, 0, sizeof val);
for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)for(int p = 0; p <= 2; ++p)
val[i][j] = ((ll)val[i][j] + (ll)x.val[i][p] * y.val[p][j] % MOD) % MOD;
return Matrix3(val);
}
void Print(void){
for(int i = 0; i <= 2; ++i)for(int j = 0; j <= 2; ++j)
printf("%d%c", val[i][j], j == 2 ? '\n' : ' ');
}
}mt[MAXN];
class SegTree{
private:
Matrix3 tr[MAXN << 2];
#define LS (p << 1)
#define RS (LS | 1)
#define MID ((gl + gr) >> 1)
public:
void Pushup(int p){tr[p] = tr[LS] * tr[RS];}
void Build(int p = 1, int gl = 1, int gr = N){
if(gl == gr)return tr[p] = mt[gl = gr], void();
Build(LS, gl, MID);
Build(RS, MID + 1, gr);
Pushup(p);
}
void Modify(int idx, Matrix3 v, int p = 1, int gl = 1, int gr = N){
if(gl == gr)return tr[p] = v, void();
if(idx <= MID)Modify(idx, v, LS, gl, MID);
else Modify(idx, v, RS, MID + 1, gr);
Pushup(p);
}
Matrix3 Query(void){return tr[1];}
}st;
int main(){
N = read(), Q = read();
string s; cin >> s;
for(int i = 1; i <= (int)s.size(); ++i)
S[i] = s.at(i - 1) == '?' ? -1 : int(s.at(i - 1) - '0'),
mt[i] = Matrix3(S[i]);
st.Build();
Matrix3 origin(0, 0, 1, 0, 0, 0, 0, 0, 0);
while(Q--){
int p = read();
char c = getchar(); while(c != '0' && c != '1' && c != '?')c = getchar();
int flag = c == '?' ? -1 : int(c - '0');
st.Modify(p, Matrix3(flag));
auto ans = origin * st.Query();
printf("%d\n", (int)((ll)(ans.val[0][0] + ans.val[0][1]) % MOD));
}
fprintf(stderr, "Time: %.6lf\n", (double)clock() / CLOCKS_PER_SEC);
return 0;
}
template < typename T >
inline T read(void){
T ret(0);
short flag(1);
char c = getchar();
while(c != '-' && !isdigit(c))c = getchar();
if(c == '-')flag = -1, c = getchar();
while(isdigit(c)){
ret *= 10;
ret += int(c - '0');
c = getchar();
}
ret *= flag;
return ret;
}
例题 #3
正在写~
UPD
update-2022_10_22 初稿