从矩阵角度理解吉司机线段树
前几天打算学写吉司机线段树,写到区间历史最值的时候炸了,这些标记的复杂性让我有点望而却步,但是当我看到 warzone 大佬的矩阵角度理解吉司机线段树时,我知道这就是我想看的东西。
作为我学完之后的总结,我决定写一篇学习笔记。
warzone 的题解更为简洁,而我写的这篇会稍微更加详细一点,毕竟还是要写给自己看的。
默认读者会区间最值操作,不会的左转 OI-WIKI。话说只有会了的才会来看这个罢。
我不太会写引入,引入部分建议直接看 warzone 的题解。
好罢,在多次阅读 warzone 的题解后我发现他的题解实际上已经写的很详细了,我只是看过题解后有不懂的地方,还是我菜,现在我觉得我能写的就只有把自己的疑问写下来,然后以问答的形式来完成这篇学习笔记。
\(\text Q1:\)
为什么对于 \(\overrightarrow{mx}\) 和 \(\overrightarrow{se}\) 都各只需要一个标记来进行维护,区间加操作以及区间最值操作标记不会冲突吗?
\(\text A1:\)
题解里其实准确说过,我们可以通过将区间最值操作转为区间加操作进行,所以区间最值操作的标记也就是可以转化为加法标记(对应到矩阵上是乘法标记)。
\(\text Q2:\)
这个回答是不完整的,如果线段树当前所辖区间下面有些节点是无法被最值操作所修改的,那么推到乘法标记上一起下传不就让一些本不会被修改的值被修改了吗?
\(\text A2:\)
实际上我们下传标记时是有对这种情况进行判断的,设要被传的节点 \(y\) 的父亲是 \(x\),如果 \(x\) 的最大值减去当前的标记会等于 \(y\) 的最大值,那么就把 \(\overrightarrow{mx}\) 的乘法标记下传,否则下传 \(\overrightarrow{se}\) 的标记。因为次大值是不会含有区间最值操作标记的,所以就相当于只传了区间加操作的标记。至于为什么可以这样判断,是因为 \(x\) 是已经加过之前的加法操作了,而标记也就含有之前的加法操作,相减就相当于恢复了原来的最大值,也就可以判断了。
\(\text Q3:\)
集成求和操作?
\(\text A3:\)
最大值更改加上个数乘权值,次大值更改加上区间长减最大值个数再乘上权值。
\(\text Q4:\)
为什么求 \(B\) 的时候还要结合 \(\overrightarrow{se}\) 的然后再上传?
\(\text A4:\)
因为我们上传信息的时候只能保证最大值划入 \(\overrightarrow{mx}\) 不能保证历史最大值划入。但如果强行划入也是可以的,但这样子会加一些小常数,因为上传信息是很多次的,不如只在询问的时候结合。反正我改成强行划入后跑出了我的最劣解。
注意这里的矩阵乘法是不满足交换律的,看错了会使脑子混乱。
原题解写的是 zkw 线段树,对于我这种没学过的人属实有点不友好,我写的函数式线段树,供大家参考。
warzone 大佬的代码使用 zkw 本来应该要更快的,但始终没有跑进 8s 档,我个人认为应该是 define l k<<1
这种含计算的替换的大量使用导致增加了常数。
采用 VSCode 格式化码风,不喜轻喷。
#include <bits/stdc++.h>
#include <bits/extc++.h>
#define ll long long
#define ull unsigned long long
#define m_p make_pair
#define m_t make_tuple
#define N 500010
#define inf 0x7f7f7f7f
#define Minf ~0x7f7f7f7f
using namespace std;
using namespace __gnu_pbds;
struct Mat
{
int a, b;
Mat operator+(Mat B)
{
Mat Ans;
Ans.a = max(a, B.a);
Ans.b = max(b, B.b);
return Ans;
}
Mat operator*(Mat B)
{
Mat Ans;
if (a == Minf || B.a == Minf)
Ans.a = Minf;
else
Ans.a = a + B.a;
if (b == Minf || B.a == Minf)
Ans.b = B.b;
else
Ans.b = max(b + B.a, B.b);
return Ans;
}
} I = {0, Minf}, INF = {Minf, Minf};
struct Seg
{
ll Sum;
int Cnt;
Mat mx, se, tgm, tgs;
} tr[N << 2];
int A[N], X, Y, W;
Seg merinf(Seg x, Seg y)
{
if (x.mx.a < y.mx.a)
{
x.Cnt = y.Cnt;
x.se = x.se + y.se + x.mx;
x.mx = y.mx;
}
else if (x.mx.a == y.mx.a)
{
x.Cnt += y.Cnt;
x.mx = x.mx + y.mx;
x.se = x.se + y.se;
}
else
x.se = x.se + y.se + y.mx;
x.tgm = x.tgs = I;
x.Sum += y.Sum;
return x;
}
void build(int k, int l, int r)
{
if (l == r)
{
tr[k] = {A[l], 1, {A[l], A[l]}, INF, I, I};
return;
}
int mid = l + r >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
tr[k] = merinf(tr[k << 1], tr[k << 1 | 1]);
return;
}
void base_cmx(int k, Mat w)
{
tr[k].mx = w * tr[k].mx;
tr[k].tgm = w * tr[k].tgm;
tr[k].Sum += 1ll * tr[k].Cnt * w.a;
return;
}
void base_cse(int k, Mat w, int siz)
{
tr[k].se = w * tr[k].se;
tr[k].tgs = w * tr[k].tgs;
tr[k].Sum += 1ll * (siz - tr[k].Cnt) * w.a;
return;
}
void tr_pd(int k, int l, int r, int mid)
{
base_cmx(k << 1, tr[k << 1].mx.a == tr[k].mx.a - tr[k].tgm.a ? tr[k].tgm : tr[k].tgs);
base_cmx(k << 1 | 1, tr[k << 1 | 1].mx.a == tr[k].mx.a - tr[k].tgm.a ? tr[k].tgm : tr[k].tgs);
base_cse(k << 1, tr[k].tgs, mid - l + 1);
base_cse(k << 1 | 1, tr[k].tgs, r - mid );
tr[k].tgm = tr[k].tgs = I;
return;
}
void tr_min(int k, int l, int r)
{
if (tr[k].mx.a <= W || l > Y || r < X)
return;
if (l >= X && r <= Y && tr[k].se.a < W)
{
base_cmx(k, {W - tr[k].mx.a, W - tr[k].mx.a});
return;
}
int mid = l + r >> 1;
tr_pd(k, l, r, mid);
if (X <= mid)
tr_min(k << 1, l, mid);
if (Y > mid)
tr_min(k << 1 | 1, mid + 1, r);
tr[k] = merinf(tr[k << 1], tr[k << 1 | 1]);
return;
}
void tr_p(int k, int l, int r)
{
if (l > Y || r < X)
return;
if (l >= X && r <= Y)
{
base_cmx(k, {W, W});
base_cse(k, {W, W}, r - l + 1);
return;
}
int mid = l + r >> 1;
tr_pd(k, l, r, mid);
if (X <= mid)
tr_p(k << 1, l, mid);
if (Y > mid)
tr_p(k << 1 | 1, mid + 1, r);
tr[k] = merinf(tr[k << 1], tr[k << 1 | 1]);
return;
}
ll tr_s(int k, int l, int r)
{
if (l > Y || r < X)
return 0;
if (l >= X && r <= Y)
return tr[k].Sum;
int mid = l + r >> 1;
ll ans = 0;
tr_pd(k, l, r, mid);
if (X <= mid)
ans += tr_s(k << 1, l, mid);
if (Y > mid)
ans += tr_s(k << 1 | 1, mid + 1, r);
return ans;
}
int tr_maxa(int k, int l, int r)
{
if (l > Y || r < X)
return Minf;
if (l >= X && r <= Y)
return tr[k].mx.a;
int mid = l + r >> 1, ans = Minf;
tr_pd(k, l, r, mid);
if (X <= mid)
ans = max(ans, tr_maxa(k << 1, l, mid));
if (Y > mid)
ans = max(ans, tr_maxa(k << 1 | 1, mid + 1, r));
return ans;
}
int tr_maxb(int k, int l, int r)
{
if (l > Y || r < X)
return Minf;
if (l >= X && r <= Y)
return max(tr[k].mx.b, tr[k].se.b);
int mid = l + r >> 1, ans = Minf;
tr_pd(k, l, r, mid);
if (X <= mid)
ans = max(ans, tr_maxb(k << 1, l, mid));
if (Y > mid)
ans = max(ans, tr_maxb(k << 1 | 1, mid + 1, r));
return ans;
}
signed main()
{
// freopen("dc.in", "r", stdin);
// freopen("test.out", "w", stdout);
ios::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int n, m, opt;
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> A[i];
build(1, 1, n);
while (m--)
{
cin >> opt >> X >> Y;
if (opt == 1)
{
cin >> W;
tr_p(1, 1, n);
}
else if (opt == 2)
{
cin >> W;
tr_min(1, 1, n);
}
else if (opt == 3)
cout << tr_s(1, 1, n) << "\n";
else if (opt == 4)
cout << tr_maxa(1, 1, n) << "\n";
else
cout << tr_maxb(1, 1, n) << "\n";
}
return 0;
}
终究只是对题解的拙劣模仿罢了。