Segment Tree Beats 学习笔记

2016集训队论文 吉如一《区间最值操作与历史最值问题》

A simple introduction to "Segment tree beats"

区间最值

以「 区间取 \(\min\),查询区间和」为例,线段树节点需储存 \(mx,smx,cnt,sum\) 四个信息,即最大值,严格次大值,最大值个数,区间和。更新信息:

void update(int x, int l, int r, int L, int R, int val){
    if(t[x].mx <= val) return;
    if(l >= L && r <= R && t[x].smx < val){ addtag(x, val); return; }
    int mid = (l+r)>>1;
    pushdown(x);
    if(mid >= L) update(x*2, l, mid, L, R, val);
    if(mid < R) update(x*2+1, mid+1, r, L, R, val);
    pushup(x);
}

在只有区间 \(\min,\max\) 操作时,时间复杂度为 \(O(n\log n)\),当有其他区间修改操作时,时间复杂度为 \(O(n\log^2n)\),但实际表现和 \(1\)\(\log\) 差不多。

这样的处理方式本质上就是对最大值或最小值专门进行维护,于是可以将信息分成两类,最值和非最值,两种分开维护,而区间 \(\min,\max\) 操作可以转化为对最值的区间加减操作。

历史最值

此处考虑的完整问题是:区间取 \(\min,\max\)、区间加、区间历史最大值、区间历史最大值之和。记 \(A_i\) 为原数组,\(B_i\) 为历史最大值数组。

区间加,区间最大历史最大值:

  在加法懒标记 \(Add\) 之外再维护一个历史最大加减标记 \(Pre\),表示从上一次标记下传至今 \(Add\) 达到过的最大值,合并:\(Pre_{son}=\max(Pre_{son},Add_{son}+Pre_x),\;Add_{son}=Add_{son}+Add_x\)\(O(n\log n)\)

只有区间查询历史最小值/最大值:

  将最值和非最值分开维护,则操作全部转化为区间加减操作,可以沿用 \(Pre\) 懒标记进行维护,\(O(n\log^2n)\)

无区间 \(\min,\max\) 操作:

  • 区间历史最大值 & 历史最大值之和:记 \(C_i=A_i-B_i\),区间加转化为 \(C_i\rightarrow \min(C_i+x,0)\)\(O(n\log^2n)\)
  • 区间历史版本之和:令 \(t\) 为当前已结束的操作数,记 \(C_i=B_i+t\cdot A_i\),区间加转化为 \(C_i\rightarrow C_i-x\cdot t\)\(O(n\log n)\)

有区间 \(\min,\max\) 操作:

  将最值和非最值分开维护 \(C_i\),则转化为上面的「无区间 \(\min,\max\) 操作」问题,在分开维护部分会多出 \(1\)\(\log\),从而「区间历史最大值」和「历史最大值之和」 \(O(n\log^3n)\),「区间历史版本之和」 \(O(n\log^2n)\)


 

一些简单的例题&实现

洛谷P6242 【模板】线段树 3

即前面区间最值中的「无区间 \(\min,\max\) 操作」,线段树节点维护 \(A_i\) 最大值、\(B_i\) 最大值、\(A_i\) 严格次小值、\(A_i\) 最大值个数、\(A_i\) 最大值和非最大值的加法标记,\(B_i\) 最大值和非最大值的历史加法标记(即 \(Pre\)),然后就可以了。

#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 505000
#define ll long long
#define Inf 0x7fffffff
using namespace std;

int n, m, a[N];

struct SegmentTreeBeats{
    struct node{
        int mx_a, smx, cnt, mx_b;
        ll sum, add_a1, add_a2, add_b1, add_b2;
        // add_a1, add_a2: lazy add tag   add_b1, add_b2: historical tag
    } t[N<<2];

    void pushup(int x){
        t[x].mx_a = max(t[x*2].mx_a, t[x*2+1].mx_a);
        t[x].mx_b = max(t[x*2].mx_b, t[x*2+1].mx_b);
        t[x].sum = t[x*2].sum + t[x*2+1].sum;
        if(t[x*2].mx_a == t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
        else if(t[x*2].mx_a > t[x*2+1].mx_a) t[x].smx = max(t[x*2].smx, t[x*2+1].mx_a);
        else t[x].smx = max(t[x*2].mx_a, t[x*2+1].smx);
        t[x].cnt = (t[x*2].mx_a >= t[x*2+1].mx_a) * t[x*2].cnt + (t[x*2+1].mx_a >= t[x*2].mx_a) * t[x*2+1].cnt;
    }

    // a: to max, b: to historical max, c: to non max, d: to historical non max
    void addtag(int x, int l, int r, ll a, ll b, ll c, ll d){
        t[x].sum += a*t[x].cnt + c*(r-l+1-t[x].cnt);
        t[x].mx_b = max((ll)t[x].mx_b, t[x].mx_a + b);
        t[x].add_b1 = max(t[x].add_b1, t[x].add_a1 + b);
        t[x].add_b2 = max(t[x].add_b2, t[x].add_a2 + d);
        t[x].mx_a += a, t[x].add_a1 += a, t[x].add_a2 += c;
        if(t[x].smx > -Inf) t[x].smx += c;
    }

    void pushdown(int x, int l, int r){
        int mid = (l+r)>>1, mx = max(t[x*2].mx_a, t[x*2+1].mx_a);
        ll add_a1 = t[x].add_a1, add_a2 = t[x].add_a2, add_b1 = t[x].add_b1, add_b2 = t[x].add_b2;
        if(t[x*2].mx_a == mx) addtag(x*2, l, mid, add_a1, add_b1, add_a2, add_b2);
        else addtag(x*2, l, mid, add_a2, add_b2, add_a2, add_b2);
        if(t[x*2+1].mx_a == mx) addtag(x*2+1, mid+1, r, add_a1, add_b1, add_a2, add_b2);
        else addtag(x*2+1, mid+1, r, add_a2, add_b2, add_a2, add_b2);
        t[x].add_a1 = t[x].add_b1 = t[x].add_a2 = t[x].add_b2 = 0;
    }

    void build(int x, int l, int r){
        if(l == r){
            t[x].sum = t[x].mx_a = t[x].mx_b = a[l];
            t[x].smx = -Inf, t[x].cnt = 1;
            return;
        }
        int mid = (l+r)>>1;
        build(x*2, l, mid), build(x*2+1, mid+1, r);
        pushup(x);
    }

    void update_add(int x, int l, int r, int L, int R, int k){
        if(l >= L && r <= R){ addtag(x, l, r, k, k, k, k); return; }
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        if(mid >= L) update_add(x*2, l, mid, L, R, k);
        if(mid < R) update_add(x*2+1, mid+1, r, L, R, k);
        pushup(x);
    }

    void update_min(int x, int l, int r, int L, int R, int val){
        if(t[x].mx_a <= val) return;
        if(l >= L && r <= R && t[x].smx < val){ addtag(x, l, r, val-t[x].mx_a, val-t[x].mx_a, 0, 0); return; }
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        if(mid >= L) update_min(x*2, l, mid, L, R, val);
        if(mid < R) update_min(x*2+1, mid+1, r, L, R, val);
        pushup(x);
    }

    ll query(int x, int l, int r, int L, int R, int id){
        if(l >= L && r <= R) return id == 1 ? t[x].sum : (id == 2 ? t[x].mx_a : t[x].mx_b); 
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        ll a = (mid >= L) ? query(x*2, l, mid, L, R, id) : (id == 1 ? 0 : -Inf);
        ll b = (mid < R) ? query(x*2+1, mid+1, r, L, R, id) : (id == 1 ? 0 : -Inf);
        return (id == 1 ? a+b : max(a, b));
    }
} T;

int main(){
    ios::sync_with_stdio(false);
    cin>>n>>m;
    rep(i,1,n) cin>>a[i];
    T.build(1, 1, n);
    int type, l, r, k;
    while(m--){
        cin>>type>>l>>r;
        switch(type){
            case 1 : cin>>k, T.update_add(1, 1, n, l, r, k); break;
            case 2 : cin>>k, T.update_min(1, 1, n, l, r, k); break;
            default : cout<< T.query(1, 1, n, l, r, type-2) <<endl;
        }
    }
    return 0;
}

 

Codeforces 1290E Cartesian Tree

给定一个长为 \(n\) 的排列的 \(a_i\),对于每个 \(k\in[1,n]\),以前 \(k\) 小的值(保持排列中的顺序)建笛卡尔树,求出所有子树大小之和。\(1\leq n\leq 150000\)

\(l_i,r_i\) 分别为 \(a_i\) 从左右得到的最后一个比其小的数的位置(保留前 \(k\) 小的值组成序列的位置),则节点 \(i\) 的子树大小即 \(r_i-l_i+1\),于是答案即为 \(\sum r_i-\sum l_i+k\)。分开维护 \(l_i\)\(r_i\),以 \(r_i\) 为例,容易发现当 \(k\) 增加 \(1\)\(k+1\) 插入到序列中时,\(pos_{k+1}\) 左侧的 \(r_i\rightarrow \min(r_i,pos_{k+1})\),而右侧的 \(r_i\) 因为左边添加了元素,\(r_i\rightarrow r_i+1\)\(l_i\) 也是类似的处理。所以我们使用 Segment Tree Beats 维护区间 \(\min\)、区间加、全局和即可,\(O(n\log^2n)\)

#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 160000
#define Inf 0x3f3f3f3f
#define ll long long
#define lowbit(x) (x&-x)
using namespace std; 

int n;
int a[N], pos[N];

struct Segment_Tree_Beats{
    struct node{
	int mx, smx, cnt, num, tag = -1, lazy;
	ll sum;
    } t[N<<2];

    void pushup(int x){
	t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
	t[x].num = t[x*2].num + t[x*2+1].num;
	t[x].sum = t[x*2].sum + t[x*2+1].sum;
	if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
	else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
	else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
	t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
    }

    void pushadd(int x, int k){
	t[x].sum += (ll)k * t[x].num;
	if(t[x].num){
	    t[x].mx += k, t[x].lazy += k;
	    if(t[x].smx > 0) t[x].smx += k;
	    if(~t[x].tag) t[x].tag += k;
	}
    }

    void pushmin(int x, int k){
	if(t[x].mx <= k) return;
	t[x].sum -= (ll)t[x].cnt * (t[x].mx - k);
	t[x].mx = k, t[x].tag = k;
    }

    void pushdown(int x){
	if(t[x].lazy) pushadd(x*2, t[x].lazy), pushadd(x*2+1, t[x].lazy);
	if(~t[x].tag) pushmin(x*2, t[x].tag), pushmin(x*2+1, t[x].tag);
	t[x].tag = -1, t[x].lazy = 0;
    }

    void insert(int x, int l, int r, int pos, int val){
	if(l == r){ t[x].sum = t[x].mx = val, t[x].cnt = t[x].num = 1; return; }
	int mid = (l+r)>>1;
	pushdown(x);
	if(mid >= pos) insert(x*2, l, mid, pos, val);
	else insert(x*2+1, mid+1, r, pos, val);
	pushup(x);
    }

    void update(int x, int l, int r, int L, int R, int k, int id){
	if(L > R) return;
	if(id && t[x].mx <= k) return;
	if(l >= L && r <= R){
	    if(id && t[x].smx < k){ pushmin(x, k); return; }
	    else if(!id){ pushadd(x, k); return; }
	}
	int mid = (l+r)>>1;
	pushdown(x);
	if(mid >= L) update(x*2, l, mid, L, R, k, id);
	if(mid < R) update(x*2+1, mid+1, r, L, R, k, id);
	pushup(x);
    }
} LB, RB;

struct Fenwick{
    int t[N];
    void update(int pos, int k){
	while(pos <= n) t[pos] += k, pos += lowbit(pos);
    }
    int get(int pos){
	int ret = 0;
	while(pos) ret += t[pos], pos -= lowbit(pos);
	return ret;
    }
} T;

int main(){
    ios::sync_with_stdio(false);
    cin>>n;
    rep(i,1,n) cin>>a[i], pos[a[i]] = i;

    rep(i,1,n){
	RB.update(1, 1, n, pos[i]+1, n, 1, 0);
	RB.update(1, 1, n, 1, pos[i]-1, T.get(pos[i]), 1);
	LB.update(1, 1, n, 1, n-pos[i]+1, -1, 0);
	LB.update(1, 1, n, 1, n-pos[i]+1, n-T.get(pos[i])-1, 1);
	RB.insert(1, 1, n, pos[i], i), LB.insert(1, 1, n, n-pos[i]+1, n);
	cout<< RB.t[1].sum - ((ll)i*(n+1) - LB.t[1].sum) + i <<endl;
	T.update(pos[i], 1);
    }
    return 0;
}

 

Codeforces 1572F Stations

\(n\) 个城市,每个城市有两个属性 \(h_i,w_i\),第 \(i\) 个城市的广播可以覆盖到所有 \([i,w_i]\) 中满足 \(\max_{i<k\leq j}\{h_k\}<h_i\) 的城市 \(j\)

初始时所有 \(h_i=0,w_i=i\),每次操作可以使得城市 \(c_i\)\(h\) 成为全局严格最大值,并修改 \(w_{c_i}\),或者询问 \([l,r]\) 中,对于每个城市来说广播可以覆盖到它的城市个数之和。

\(1\leq n\leq 2\times 10^5\)

显然每个城市覆盖的区域是一个从 \(i\) 开始的区间,而维护区间右端点是一个「 单点修改,区间取 \(\min\)」问题,另外用一棵线段树维护区间覆盖,注意到 Segment Tree Beats 里值的修改是直接进行的,所以修改时可以顺带在另一棵线段树上进行区间修改,即打 \(tag\) 时维护即可。询问时直接在线段树查询区间和。时间复杂度 \(O(n\log^2n)\)


#include<iostream>
#define rep(i,a,b) for(int i = (a); i <= (b); i++)
#define per(i,b,a) for(int i = (b); i >= (a); i--)
#define N 200021
#define ll long long
using namespace std;

int n, q;

struct SegmentTree{
    struct node{
        ll sum, lazy;
    } t[N<<2];

    void pushdown(int x, int l, int r){
        int mid = (l+r)>>1;
        t[x*2].sum += (mid-l+1) * t[x].lazy, t[x*2].lazy += t[x].lazy;
        t[x*2+1].sum += (r-mid) * t[x].lazy, t[x*2+1].lazy += t[x].lazy;
        t[x].lazy = 0;
    }

    void update(int x, int l, int r, int L, int R, ll k){
        if(L > R) return;
        if(l >= L && r <= R){ t[x].sum += (r-l+1) * k, t[x].lazy += k; return; }
        int mid = (l+r)>>1;
        pushdown(x, l, r);
        if(mid >= L) update(x*2, l, mid, L, R, k);
        if(mid < R) update(x*2+1, mid+1, r, L, R, k);
        t[x].sum = t[x*2].sum + t[x*2+1].sum;
    }

    ll get(int x, int l, int r, int L, int R){
        if(l >= L && r <= R) return t[x].sum;
        int mid = (l+r)>>1; ll ret = 0;
        pushdown(x, l, r);
        if(mid >= L) ret += get(x*2, l, mid, L, R);
        if(mid < R) ret += get(x*2+1, mid+1, r, L, R);
        return ret;
    }
} BIT;

struct Segment_Tree_Beats{
    struct node{
        int mx, smx, cnt, tag = -1;
    } t[N<<2];

    void pushup(int x){
        t[x].mx = max(t[x*2].mx, t[x*2+1].mx);
        if(t[x*2].mx == t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].smx);
        else if(t[x*2].mx > t[x*2+1].mx) t[x].smx = max(t[x*2].smx, t[x*2+1].mx);
        else t[x].smx = max(t[x*2].mx, t[x*2+1].smx);
        t[x].cnt = (t[x*2].mx >= t[x*2+1].mx) * t[x*2].cnt + (t[x*2+1].mx >= t[x*2].mx) * t[x*2+1].cnt;
    }

    void addtag(int x, int k, bool frt){
        if(t[x].mx <= k) return;
        if(frt) BIT.update(1, 1, n, k+1, t[x].mx, -t[x].cnt);
        t[x].mx = t[x].tag = k;
    }

    void pushdown(int x){
        if(~t[x].tag)
            addtag(x*2, t[x].tag, 0), addtag(x*2+1, t[x].tag, 0);
        t[x].tag = -1;
    }

    void insert(int x, int l, int r, int pos, int val){
        if(l == r){ 
            BIT.update(1, 1, n, l, t[x].mx, -t[x].cnt), BIT.update(1, 1, n, l, val, 1);
            t[x].mx = val, t[x].cnt = 1; return; 
        }
        int mid = (l+r)>>1;
        pushdown(x);
        if(mid >= pos) insert(x*2, l, mid, pos, val);
        else insert(x*2+1, mid+1, r, pos, val);
        pushup(x);
    }

    void update(int x, int l, int r, int L, int R, int k){
        if(t[x].mx <= k || L > R) return;
        if(l >= L && r <= R && t[x].smx < k){ addtag(x, k, 1); return; }
        int mid = (l+r)>>1;
        pushdown(x);
        if(mid >= L) update(x*2, l, mid, L, R, k);
        if(mid < R) update(x*2+1, mid+1, r, L, R, k);
        pushup(x);
    }
} T;

int main(){
    ios::sync_with_stdio(false);
    cin>>n>>q;
    rep(i,1,n) T.insert(1, 1, n, i, i);
    int type, c, g, l, r;
    while(q--){
        cin>>type;
        if(type == 1){
            cin>>c>>g;
            T.insert(1, 1, n, c, g), T.update(1, 1, n, 1, c-1, c-1);
        } else cin>>l>>r, cout<< BIT.get(1, 1, n, l, r) <<endl;
    }
    return 0;
}

posted @ 2021-12-15 23:21  Neal_lee  阅读(1058)  评论(0编辑  收藏  举报