SegmentTreeBeats 简单学习笔记
SegmentTreeBeats 简单学习笔记
有一天补 \(\text{CF}\) 做到一个题,转化一波题意以后变成要求维护一个序列 \(a\) 。
-
对于 \(i \in [l,r], a_i =a_i+x\) 。
-
对于 \(i \in [l,r], a_i =\min(a_i, x)\) 。
-
求 \(\sum_{i=l}^r a_i\) 。
其实就是 \(\text{Segment Tree Beats}\) 的模板题,也就是那年吉老师营员交流课件的例题, 用线段树维护区间最大值 \(mx\) ,区间次大值 \(se\) ,区间和 \(sum\) ,区间最大值出现次数 \(cnt\) ,加法标记 \(tag\) 。
对于第二种操作,如果一个区间 \(mx \leq x\) 那么无事发生,可以跳过其所有子区间。如果 \(se < x < mx\) ,那么 \(sum = sum - (mx-x)\times cnt, mx = x\) ,注意这里子区间的 \(mx, sum\) 并没有更改,相当于 \(mx\) 同时作为一个修改标记,当前区间比子区间的 \(mx\) 小时,要进行 \(sum = sum - (mx-mx[fa]) \times cnt\) 的 \(\text{pushdown}\) 操作,打完标记之后就可以跳过了。对于其它情况,暴力对其子区间求解。
到这一步位置算法流程不难理解,但算法的复杂度证明比较难懂,目前 \(\mathcal O(n\log n)\) 的证明我还不会,只能理解 \(\mathcal O(n \log^2 n)\) 的证明,在这里写一下简要证明:
定义势能函数 \(\Phi\) 为线段树中 \(mx\) 不等于其父亲节点 \(mx\) 的节点数量,考虑一次第二操作过程的任一终止节点 \(v\) 。如果 \(v\) 对 \(\Phi\) 有贡献,假设这一类节点的数量为 \(A\) ,到达这些节点的复杂度为 \(\mathcal O (A\log n)\) ,结束后这些节点都对势能没贡献了,也就是说用了 \(\mathcal O(A\log n)\) 的时间让势能减小了 \(A\) 。
如果 \(v\) 对 \(\Phi\) 没贡献,记 \(u\) 为 \(v\) 的父亲,\(u\) 的另外一儿子为 \(c\) ,那么 \(mx[u] = mx[v], se[u] \neq se[v]\) ,也就是说 \(se[u] = mx[c]\) 。那么 \(c\) 的子树一定会被访问, 并在访问结束后 \(c\) 对 \(\Phi\) 没有贡献,假设这一类节点数量为 \(A\) ,同样也用 \(\mathcal O(A\log n)\) 的时间让势能减小了 \(A\) 。也就是说对于修改操作,实际上是每减小一个势能用了 \(\mathcal O(\log n)\) 的代价。
考虑修改操作,每次只会修改 \(\mathcal O(\log n)\) 节点,最多使势能增加 \(\mathcal O(\log n)\) 所以总复杂度是 \(\mathcal O(n\log^2 n)\)。
code: Codeforces 1290 E
/*program by mangoyang*/ #pragma GCC optimize("Ofast", "inline") #include<bits/stdc++.h> #define inf (0x7f7f7f7f) #define Max(a, b) ((a) > (b) ? (a) : (b)) #define Min(a, b) ((a) < (b) ? (a) : (b)) typedef long long ll; using namespace std; template <class T> inline void read(T &x){ int ch = 0, f = 0; x = 0; for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1; for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48; if(f) x = -x; } #define int ll const int N = 150005; int a[N], b[N], ans[N], n; namespace Seg{ #define lson (u << 1) #define rson (u << 1 | 1) #define mid ((l + r) >> 1) int mx[N<<2], se[N<<2], sz[N<<2], cnt[N<<2], sum[N<<2], tag[N<<2]; inline void clear(){ memset(mx, 0, sizeof(mx)); memset(se, 0, sizeof(se)); memset(sz, 0, sizeof(sz)); memset(cnt, 0, sizeof(cnt)); memset(sum, 0, sizeof(sum)); memset(tag, 0, sizeof(tag)); } inline void update(int u){ if(mx[lson] > mx[rson]) mx[u] = mx[lson], cnt[u] = cnt[lson]; else mx[u] = mx[rson], cnt[u] = cnt[rson]; if(mx[lson] == mx[rson]) cnt[u] += cnt[lson]; se[u] = max(se[lson], se[rson]); if(mx[lson] != mx[rson]){ int x = min(mx[lson], mx[rson]); se[u] = max(se[u], x); } sum[u] = sum[lson] + sum[rson]; sz[u] = sz[lson] + sz[rson]; } inline void pushdown(int u){ if(tag[u]){ if(mx[lson]) mx[lson] += tag[u]; if(se[lson]) se[lson] += tag[u]; if(mx[rson]) mx[rson] += tag[u]; if(se[rson]) se[rson] += tag[u]; sum[lson] += tag[u] * sz[lson]; sum[rson] += tag[u] * sz[rson]; tag[lson] += tag[u]; tag[rson] += tag[u]; tag[u] = 0; } if(mx[lson] > mx[u]){ sum[lson] -= (mx[lson] - mx[u]) * cnt[lson]; mx[lson] = mx[u]; } if(mx[rson] > mx[u]){ sum[rson] -= (mx[rson] - mx[u]) * cnt[rson]; mx[rson] = mx[u]; } } inline void ins(int u, int l, int r, int pos, int x){ if(l == r){ mx[u] = sum[u] = x; sz[u] = cnt[u] = 1; return; } pushdown(u); if(pos <= mid) ins(lson, l, mid, pos, x); else ins(rson, mid + 1, r, pos, x); update(u); } inline void gao(int u, int l, int r, int L, int R, int x){ if(l >= L && r <= R){ if(mx[u] <= x) return; if(se[u] < x){ sum[u] -= (mx[u] - x) * cnt[u]; mx[u] = x; return; } pushdown(u); gao(lson, l, mid, L, R, x); gao(rson, mid + 1, r, L, R, x); update(u); return; } pushdown(u); if(L <= mid) gao(lson, l, mid, L, R, x); if(mid < R) gao(rson, mid + 1, r, L, R, x); update(u); } inline void add(int u, int l, int r, int L, int R){ if(l >= L && r <= R){ if(mx[u]) mx[u]++; if(se[u]) se[u]++; sum[u] += sz[u], tag[u]++; return; } pushdown(u); if(L <= mid) add(lson, l, mid, L, R); if(mid < R) add(rson, mid + 1, r, L, R); update(u); } inline int query(int u, int l, int r, int L, int R){ if(l >= L && r <= R) return sz[u]; int res = 0; pushdown(u); if(L <= mid) res += query(lson, l, mid, L, R); if(mid < R) res += query(rson, mid + 1, r, L, R); return res; } } signed main(){ read(n); for(int i = 1; i <= n; i++) read(a[i]); for(int i = 1; i <= n; i++) b[a[i]] = i; for(int i = 1; i <= n; i++){ Seg::add(1, 1, n, b[i] + 1, n); int sz = Seg::query(1, 1, n, 1, b[i]); if(sz) Seg::gao(1, 1, n, 1, b[i], sz); Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1); ans[i] = Seg::sum[1] + Seg::sz[1]; } reverse(a + 1, a + n + 1); for(int i = 1; i <= n; i++) b[a[i]] = i; Seg::clear(); for(int i = 1; i <= n; i++){ Seg::add(1, 1, n, b[i] + 1, n); int sz = Seg::query(1, 1, n, 1, b[i]); if(sz) Seg::gao(1, 1, n, 1, b[i], sz); Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1); ans[i] -= Seg::sz[1] * (Seg::sz[1] + 1) - Seg::sum[1]; } for(int i = 1; i <= n; i++) printf("%lld\n", ans[i]); return 0; }