Segment Tree Beats 学习笔记
A simple introduction to "Segment tree beats"
区间最值
以「 区间取 \(\min\),查询区间和」为例,线段树节点需储存 \(mx,smx,cnt,sum\) 四个信息,即最大值,严格次大值,最大值个数,区间和。更新信息:
void update(int x, int l, int r, int L, int R, int val){
if(t[x].mx <= val) return;
if(l >= L && r <= R && t[x].smx < val){ addtag(x, val); return; }
int mid = (l+r)>>1;
pushdown(x);
if(mid >= L) update(x*2, l, mid, L, R, val);
if(mid < R) update(x*2+1, mid+1, r, L, R, val);
pushup(x);
}
在只有区间 \(\min,\max\) 操作时,时间复杂度为 \(O(n\log n)\),当有其他区间修改操作时,时间复杂度为 \(O(n\log^2n)\),但实际表现和 \(1\) 个 \(\log\) 差不多。
这样的处理方式本质上就是对最大值或最小值专门进行维护,于是可以将信息分成两类,最值和非最值,两种分开维护,而区间 \(\min,\max\) 操作可以转化为对最值的区间加减操作。
历史最值
此处考虑的完整问题是:区间取 \(\min,\max\)、区间加、区间历史最大值、区间历史最大值之和。记 \(A_i\) 为原数组,\(B_i\) 为历史最大值数组。
区间加,区间最大历史最大值:
在加法懒标记 \(Add\) 之外再维护一个历史最大加减标记 \(Pre\),表示从上一次标记下传至今 \(Add\) 达到过的最大值,合并:\(Pre_{son}=\max(Pre_{son},Add_{son}+Pre_x),\;Add_{son}=Add_{son}+Add_x\),\(O(n\log n)\) 。
只有区间查询历史最小值/最大值:
将最值和非最值分开维护,则操作全部转化为区间加减操作,可以沿用 \(Pre\) 懒标记进行维护,\(O(n\log^2n)\) 。
无区间 \(\min,\max\) 操作:
- 区间历史最大值 & 历史最大值之和:记 \(C_i=A_i-B_i\),区间加转化为 \(C_i\rightarrow \min(C_i+x,0)\),\(O(n\log^2n)\)
- 区间历史版本之和:令 \(t\) 为当前已结束的操作数,记 \(C_i=B_i+t\cdot A_i\),区间加转化为 \(C_i\rightarrow C_i-x\cdot t\),\(O(n\log n)\)。
有区间 \(\min,\max\) 操作:
将最值和非最值分开维护 \(C_i\),则转化为上面的「无区间 \(\min,\max\) 操作」问题,在分开维护部分会多出 \(1\) 个 \(\log\),从而「区间历史最大值」和「历史最大值之和」 \(O(n\log^3n)\),「区间历史版本之和」 \(O(n\log^2n)\) 。
一些简单的例题&实现
洛谷P6242 【模板】线段树 3
即前面区间最值中的「无区间 \(\min,\max\) 操作」,线段树节点维护 \(A_i\) 最大值、\(B_i\) 最大值、\(A_i\) 严格次小值、\(A_i\) 最大值个数、\(A_i\) 最大值和非最大值的加法标记,\(B_i\) 最大值和非最大值的历史加法标记(即 \(Pre\)),然后就可以了。
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 505000
#define ll long long
#define Inf 0x7fffffff
using namespace std;
int n, m, a[N];
struct SegmentTreeBeats{
struct node{
int mx_a, smx, cnt, mx_b;
ll sum, add_a1, add_a2, add_b1, add_b2;
// add_a1, add_a2: lazy add tag add_b1, add_b2: historical tag
} t[N<<2];
void pushup(int x){
t[x].mx_a = max(t[x*2].mx_a, t[x*2+1].mx_a);
t[x].mx_b = max(t[x*2].mx_b, t[x*2+1].mx_b);
t[x].sum = t[x*2].sum + t[x*2+1].sum;
if(t[x*2].mx_a == t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
else if(t[x*2].mx_a > t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].mx_a);
else t[x].smx = max(t[x*2].mx_a, t[x*2+1].smx);
t[x].cnt = (t[x*2].mx_a >= t[x*2+1].mx_a) * t[x*2].cnt + (t[x*2+1].mx_a >= t[x*2].mx_a) * t[x*2+1].cnt;
}
// a: to max, b: to historical max, c: to non max, d: to historical non max
void addtag(int x, int l, int r, ll a, ll b, ll c, ll d){
t[x].sum += a*t[x].cnt + c*(r-l+1-t[x].cnt);
t[x].mx_b = max((ll)t[x].mx_b, t[x].mx_a + b);
t[x].add_b1 = max(t[x].add_b1, t[x].add_a1 + b);
t[x].add_b2 = max(t[x].add_b2, t[x].add_a2 + d);
t[x].mx_a += a, t[x].add_a1 += a, t[x].add_a2 += c;
if(t[x].smx > -Inf) t[x].smx += c;
}
void pushdown(int x, int l, int r){
int mid = (l+r)>>1, mx = max(t[x*2].mx_a, t[x*2+1].mx_a);
ll add_a1 = t[x].add_a1, add_a2 = t[x].add_a2, add_b1 = t[x].add_b1, add_b2 = t[x].add_b2;
if(t[x*2].mx_a == mx) addtag(x*2, l, mid, add_a1, add_b1, add_a2, add_b2);
else addtag(x*2, l, mid, add_a2, add_b2, add_a2, add_b2);
if(t[x*2+1].mx_a == mx) addtag(x*2+1, mid+1, r, add_a1, add_b1, add_a2, add_b2);
else addtag(x*2+1, mid+1, r, add_a2, add_b2, add_a2, add_b2);
t[x].add_a1 = t[x].add_b1 = t[x].add_a2 = t[x].add_b2 = 0;
}
void build(int x, int l, int r){
if(l == r){
t[x].sum = t[x].mx_a = t[x].mx_b = a[l];
t[x].smx = -Inf, t[x].cnt = 1;
return;
}
int mid = (l+r)>>1;
build(x*2, l, mid), build(x*2+1, mid+1, r);
pushup(x);
}
void update_add(int x, int l, int r, int L, int R, int k){
if(l >= L && r <= R){ addtag(x, l, r, k, k, k, k); return; }
int mid = (l+r)>>1;
pushdown(x, l, r);
if(mid >= L) update_add(x*2, l, mid, L, R, k);
if(mid < R) update_add(x*2+1, mid+1, r, L, R, k);
pushup(x);
}
void update_min(int x, int l, int r, int L, int R, int val){
if(t[x].mx_a <= val) return;
if(l >= L && r <= R && t[x].smx < val){ addtag(x, l, r, val-t[x].mx_a, val-t[x].mx_a, 0, 0); return; }
int mid = (l+r)>>1;
pushdown(x, l, r);
if(mid >= L) update_min(x*2, l, mid, L, R, val);
if(mid < R) update_min(x*2+1, mid+1, r, L, R, val);
pushup(x);
}
ll query(int x, int l, int r, int L, int R, int id){
if(l >= L && r <= R) return id == 1 ? t[x].sum : (id == 2 ? t[x].mx_a : t[x].mx_b);
int mid = (l+r)>>1;
pushdown(x, l, r);
ll a = (mid >= L) ? query(x*2, l, mid, L, R, id) : (id == 1 ? 0 : -Inf);
ll b = (mid < R) ? query(x*2+1, mid+1, r, L, R, id) : (id == 1 ? 0 : -Inf);
return (id == 1 ? a+b : max(a, b));
}
} T;
int main(){
ios::sync_with_stdio(false);
cin>>n>>m;
rep(i,1,n) cin>>a[i];
T.build(1, 1, n);
int type, l, r, k;
while(m--){
cin>>type>>l>>r;
switch(type){
case 1 : cin>>k, T.update_add(1, 1, n, l, r, k); break;
case 2 : cin>>k, T.update_min(1, 1, n, l, r, k); break;
default : cout<< T.query(1, 1, n, l, r, type-2) <<endl;
}
}
return 0;
}
Codeforces 1290E Cartesian Tree
给定一个长为 \(n\) 的排列的 \(a_i\),对于每个 \(k\in[1,n]\),以前 \(k\) 小的值(保持排列中的顺序)建笛卡尔树,求出所有子树大小之和。\(1\leq n\leq 150000\)
记 \(l_i,r_i\) 分别为 \(a_i\) 从左右得到的最后一个比其小的数的位置(保留前 \(k\) 小的值组成序列的位置),则节点 \(i\) 的子树大小即 \(r_i-l_i+1\),于是答案即为 \(\sum r_i-\sum l_i+k\)。分开维护 \(l_i\) 和 \(r_i\),以 \(r_i\) 为例,容易发现当 \(k\) 增加 \(1\),\(k+1\) 插入到序列中时,\(pos_{k+1}\) 左侧的 \(r_i\rightarrow \min(r_i,pos_{k+1})\),而右侧的 \(r_i\) 因为左边添加了元素,\(r_i\rightarrow r_i+1\),\(l_i\) 也是类似的处理。所以我们使用 Segment Tree Beats 维护区间 \(\min\)、区间加、全局和即可,\(O(n\log^2n)\) 。
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 160000
#define Inf 0x3f3f3f3f
#define ll long long
#define lowbit(x) (x&-x)
using namespace std;
int n;
int a[N], pos[N];
struct Segment_Tree_Beats{
struct node{
int mx, smx, cnt, num, tag = -1, lazy;
ll sum;
} t[N<<2];
void pushup(int x){
t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
t[x].num = t[x*2].num + t[x*2+1].num;
t[x].sum = t[x*2].sum + t[x*2+1].sum;
if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
}
void pushadd(int x, int k){
t[x].sum += (ll)k * t[x].num;
if(t[x].num){
t[x].mx += k, t[x].lazy += k;
if(t[x].smx > 0) t[x].smx += k;
if(~t[x].tag) t[x].tag += k;
}
}
void pushmin(int x, int k){
if(t[x].mx <= k) return;
t[x].sum -= (ll)t[x].cnt * (t[x].mx - k);
t[x].mx = k, t[x].tag = k;
}
void pushdown(int x){
if(t[x].lazy) pushadd(x*2, t[x].lazy), pushadd(x*2+1, t[x].lazy);
if(~t[x].tag) pushmin(x*2, t[x].tag), pushmin(x*2+1, t[x].tag);
t[x].tag = -1, t[x].lazy = 0;
}
void insert(int x, int l, int r, int pos, int val){
if(l == r){ t[x].sum = t[x].mx = val, t[x].cnt = t[x].num = 1; return; }
int mid = (l+r)>>1;
pushdown(x);
if(mid >= pos) insert(x*2, l, mid, pos, val);
else insert(x*2+1, mid+1, r, pos, val);
pushup(x);
}
void update(int x, int l, int r, int L, int R, int k, int id){
if(L > R) return;
if(id && t[x].mx <= k) return;
if(l >= L && r <= R){
if(id && t[x].smx < k){ pushmin(x, k); return; }
else if(!id){ pushadd(x, k); return; }
}
int mid = (l+r)>>1;
pushdown(x);
if(mid >= L) update(x*2, l, mid, L, R, k, id);
if(mid < R) update(x*2+1, mid+1, r, L, R, k, id);
pushup(x);
}
} LB, RB;
struct Fenwick{
int t[N];
void update(int pos, int k){
while(pos <= n) t[pos] += k, pos += lowbit(pos);
}
int get(int pos){
int ret = 0;
while(pos) ret += t[pos], pos -= lowbit(pos);
return ret;
}
} T;
int main(){
ios::sync_with_stdio(false);
cin>>n;
rep(i,1,n) cin>>a[i], pos[a[i]] = i;
rep(i,1,n){
RB.update(1, 1, n, pos[i]+1, n, 1, 0);
RB.update(1, 1, n, 1, pos[i]-1, T.get(pos[i]), 1);
LB.update(1, 1, n, 1, n-pos[i]+1, -1, 0);
LB.update(1, 1, n, 1, n-pos[i]+1, n-T.get(pos[i])-1, 1);
RB.insert(1, 1, n, pos[i], i), LB.insert(1, 1, n, n-pos[i]+1, n);
cout<< RB.t[1].sum - ((ll)i*(n+1) - LB.t[1].sum) + i <<endl;
T.update(pos[i], 1);
}
return 0;
}
Codeforces 1572F Stations
有 \(n\) 个城市,每个城市有两个属性 \(h_i,w_i\),第 \(i\) 个城市的广播可以覆盖到所有 \([i,w_i]\) 中满足 \(\max_{i<k\leq j}\{h_k\}<h_i\) 的城市 \(j\)。
初始时所有 \(h_i=0,w_i=i\),每次操作可以使得城市 \(c_i\) 的 \(h\) 成为全局严格最大值,并修改 \(w_{c_i}\),或者询问 \([l,r]\) 中,对于每个城市来说广播可以覆盖到它的城市个数之和。
\(1\leq n\leq 2\times 10^5\)
显然每个城市覆盖的区域是一个从 \(i\) 开始的区间,而维护区间右端点是一个「 单点修改,区间取 \(\min\)」问题,另外用一棵线段树维护区间覆盖,注意到 Segment Tree Beats 里值的修改是直接进行的,所以修改时可以顺带在另一棵线段树上进行区间修改,即打 \(tag\) 时维护即可。询问时直接在线段树查询区间和。时间复杂度 \(O(n\log^2n)\) 。
#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 200021
#define ll long long
using namespace std;
int n, q;
struct SegmentTree{
struct node{
ll sum, lazy;
} t[N<<2];
void pushdown(int x, int l, int r){
int mid = (l+r)>>1;
t[x*2].sum += (mid-l+1) * t[x].lazy, t[x*2].lazy += t[x].lazy;
t[x*2+1].sum += (r-mid) * t[x].lazy, t[x*2+1].lazy += t[x].lazy;
t[x].lazy = 0;
}
void update(int x, int l, int r, int L, int R, ll k){
if(L > R) return;
if(l >= L && r <= R){ t[x].sum += (r-l+1) * k, t[x].lazy += k; return; }
int mid = (l+r)>>1;
pushdown(x, l, r);
if(mid >= L) update(x*2, l, mid, L, R, k);
if(mid < R) update(x*2+1, mid+1, r, L, R, k);
t[x].sum = t[x*2].sum + t[x*2+1].sum;
}
ll get(int x, int l, int r, int L, int R){
if(l >= L && r <= R) return t[x].sum;
int mid = (l+r)>>1; ll ret = 0;
pushdown(x, l, r);
if(mid >= L) ret += get(x*2, l, mid, L, R);
if(mid < R) ret += get(x*2+1, mid+1, r, L, R);
return ret;
}
} BIT;
struct Segment_Tree_Beats{
struct node{
int mx, smx, cnt, tag = -1;
} t[N<<2];
void pushup(int x){
t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
}
void addtag(int x, int k, bool frt){
if(t[x].mx <= k) return;
if(frt) BIT.update(1, 1, n, k+1, t[x].mx, -t[x].cnt);
t[x].mx = t[x].tag = k;
}
void pushdown(int x){
if(~t[x].tag)
addtag(x*2, t[x].tag, 0), addtag(x*2+1, t[x].tag, 0);
t[x].tag = -1;
}
void insert(int x, int l, int r, int pos, int val){
if(l == r){
BIT.update(1, 1, n, l, t[x].mx, -t[x].cnt), BIT.update(1, 1, n, l, val, 1);
t[x].mx = val, t[x].cnt = 1; return;
}
int mid = (l+r)>>1;
pushdown(x);
if(mid >= pos) insert(x*2, l, mid, pos, val);
else insert(x*2+1, mid+1, r, pos, val);
pushup(x);
}
void update(int x, int l, int r, int L, int R, int k){
if(t[x].mx <= k || L > R) return;
if(l >= L && r <= R && t[x].smx < k){ addtag(x, k, 1); return; }
int mid = (l+r)>>1;
pushdown(x);
if(mid >= L) update(x*2, l, mid, L, R, k);
if(mid < R) update(x*2+1, mid+1, r, L, R, k);
pushup(x);
}
} T;
int main(){
ios::sync_with_stdio(false);
cin>>n>>q;
rep(i,1,n) T.insert(1, 1, n, i, i);
int type, c, g, l, r;
while(q--){
cin>>type;
if(type == 1){
cin>>c>>g;
T.insert(1, 1, n, c, g), T.update(1, 1, n, 1, c-1, c-1);
} else cin>>l>>r, cout<< BIT.get(1, 1, n, l, r) <<endl;
}
return 0;
}