SegmentTreeBeats 简单学习笔记

SegmentTreeBeats 简单学习笔记

​ 有一天补 \(\text{CF}\) 做到一个题,转化一波题意以后变成要求维护一个序列 \(a\)

  1. 对于 \(i \in [l,r], a_i =a_i+x\)

  2. 对于 \(i \in [l,r], a_i =\min(a_i, x)\)

  3. \(\sum_{i=l}^r a_i\)

    ​ 其实就是 \(\text{Segment Tree Beats}\) 的模板题,也就是那年吉老师营员交流课件的例题, 用线段树维护区间最大值 \(mx\) ,区间次大值 \(se\) ,区间和 \(sum\) ,区间最大值出现次数 \(cnt\) ,加法标记 \(tag\)

    ​ 对于第二种操作,如果一个区间 \(mx \leq x\) 那么无事发生,可以跳过其所有子区间。如果 \(se < x < mx\) ,那么 \(sum = sum - (mx-x)\times cnt, mx = x\) ,注意这里子区间的 \(mx, sum\) 并没有更改,相当于 \(mx\) 同时作为一个修改标记,当前区间比子区间的 \(mx\) 小时,要进行 \(sum = sum - (mx-mx[fa]) \times cnt\)\(\text{pushdown}\) 操作,打完标记之后就可以跳过了。对于其它情况,暴力对其子区间求解。

    ​ 到这一步位置算法流程不难理解,但算法的复杂度证明比较难懂,目前 \(\mathcal O(n\log n)\) 的证明我还不会,只能理解 \(\mathcal O(n \log^2 n)\) 的证明,在这里写一下简要证明:

    ​ 定义势能函数 \(\Phi\) 为线段树中 \(mx\) 不等于其父亲节点 \(mx\) 的节点数量,考虑一次第二操作过程的任一终止节点 \(v\) 。如果 \(v\)\(\Phi\) 有贡献,假设这一类节点的数量为 \(A\) ,到达这些节点的复杂度为 \(\mathcal O (A\log n)\) ,结束后这些节点都对势能没贡献了,也就是说用了 \(\mathcal O(A\log n)\) 的时间让势能减小了 \(A\)

    ​ 如果 \(v\)\(\Phi\) 没贡献,记 \(u\)\(v\) 的父亲,\(u\) 的另外一儿子为 \(c\) ,那么 \(mx[u] = mx[v], se[u] \neq se[v]\) ,也就是说 \(se[u] = mx[c]\) 。那么 \(c\) 的子树一定会被访问, 并在访问结束后 \(c\)\(\Phi\) 没有贡献,假设这一类节点数量为 \(A\) ,同样也用 \(\mathcal O(A\log n)\) 的时间让势能减小了 \(A\) 。也就是说对于修改操作,实际上是每减小一个势能用了 \(\mathcal O(\log n)\) 的代价。

    ​ 考虑修改操作,每次只会修改 \(\mathcal O(\log n)\) 节点,最多使势能增加 \(\mathcal O(\log n)\) 所以总复杂度是 \(\mathcal O(n\log^2 n)\)

    code: Codeforces 1290 E

    /*program by mangoyang*/
    #pragma GCC optimize("Ofast", "inline")
    #include<bits/stdc++.h>
    #define inf (0x7f7f7f7f)
    #define Max(a, b) ((a) > (b) ? (a) : (b))
    #define Min(a, b) ((a) < (b) ? (a) : (b))
    typedef long long ll;
    using namespace std;
    template <class T>
    inline void read(T &x){
        int ch = 0, f = 0; x = 0;
        for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
        for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
        if(f) x = -x;
    }
    #define int ll
    const int N = 150005;
    int a[N], b[N], ans[N], n;
    namespace Seg{
        #define lson (u << 1)
        #define rson (u << 1 | 1)
        #define mid ((l + r) >> 1)
        int mx[N<<2], se[N<<2], sz[N<<2], cnt[N<<2], sum[N<<2], tag[N<<2];
        inline void clear(){
            memset(mx, 0, sizeof(mx));
            memset(se, 0, sizeof(se));
            memset(sz, 0, sizeof(sz));
            memset(cnt, 0, sizeof(cnt));
            memset(sum, 0, sizeof(sum));
            memset(tag, 0, sizeof(tag));
        }
        inline void update(int u){
            if(mx[lson] > mx[rson])
                mx[u] = mx[lson], cnt[u] = cnt[lson];
            else
                mx[u] = mx[rson], cnt[u] = cnt[rson];
            if(mx[lson] == mx[rson]) cnt[u] += cnt[lson];
            se[u] = max(se[lson], se[rson]);
            if(mx[lson] != mx[rson]){
                int x = min(mx[lson], mx[rson]);
                se[u] = max(se[u], x);
            }
            sum[u] = sum[lson] + sum[rson];
            sz[u] = sz[lson] + sz[rson];
        }
        inline void pushdown(int u){
            if(tag[u]){
                if(mx[lson]) mx[lson] += tag[u];
                if(se[lson]) se[lson] += tag[u];
                if(mx[rson]) mx[rson] += tag[u];
                if(se[rson]) se[rson] += tag[u];
                sum[lson] += tag[u] * sz[lson];
                sum[rson] += tag[u] * sz[rson];
                tag[lson] += tag[u];
                tag[rson] += tag[u];
                tag[u] = 0;
            }
            if(mx[lson] > mx[u]){
                sum[lson] -= (mx[lson] - mx[u]) * cnt[lson];
                mx[lson] = mx[u];
            }
            if(mx[rson] > mx[u]){
                sum[rson] -= (mx[rson] - mx[u]) * cnt[rson];
                mx[rson] = mx[u];
            }
        }
        inline void ins(int u, int l, int r, int pos, int x){
            if(l == r){
                mx[u] = sum[u] = x;
                sz[u] = cnt[u] = 1;
                return;
            }
            pushdown(u);
            if(pos <= mid) ins(lson, l, mid, pos, x);
            else ins(rson, mid + 1, r, pos, x);
            update(u);
        }
        inline void gao(int u, int l, int r, int L, int R, int x){
            if(l >= L && r <= R){
                if(mx[u] <= x) return;
                if(se[u] < x){
                    sum[u] -= (mx[u] - x) * cnt[u];
                    mx[u] = x;
                    return;
                }
                pushdown(u);
                gao(lson, l, mid, L, R, x);
                gao(rson, mid + 1, r, L, R, x);
                update(u);
                return;
            }
            pushdown(u);
            if(L <= mid) gao(lson, l, mid, L, R, x);
            if(mid < R) gao(rson, mid + 1, r, L, R, x);
            update(u);
        }
        inline void add(int u, int l, int r, int L, int R){
            if(l >= L && r <= R){
                if(mx[u]) mx[u]++;
                if(se[u]) se[u]++;
                sum[u] += sz[u], tag[u]++;
                return;
            }
            pushdown(u);
            if(L <= mid) add(lson, l, mid, L, R);
            if(mid < R) add(rson, mid + 1, r, L, R);
            update(u);
        }
        inline int query(int u, int l, int r, int L, int R){
            if(l >= L && r <= R) return sz[u];
            int res = 0; pushdown(u);
            if(L <= mid) res += query(lson, l, mid, L, R);
            if(mid < R) res += query(rson, mid + 1, r, L, R);
            return res;
        }
    }
    signed main(){
        read(n);
        for(int i = 1; i <= n; i++) read(a[i]);
        for(int i = 1; i <= n; i++) b[a[i]] = i;
        for(int i = 1; i <= n; i++){
            Seg::add(1, 1, n, b[i] + 1, n);
            int sz = Seg::query(1, 1, n, 1, b[i]);
            if(sz) Seg::gao(1, 1, n, 1, b[i], sz);
            Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1);
            ans[i] = Seg::sum[1] + Seg::sz[1];
        }
        reverse(a + 1, a + n + 1);
        for(int i = 1; i <= n; i++) b[a[i]] = i;
        Seg::clear();
            for(int i = 1; i <= n; i++){
            Seg::add(1, 1, n, b[i] + 1, n);
            int sz = Seg::query(1, 1, n, 1, b[i]);
            if(sz) Seg::gao(1, 1, n, 1, b[i], sz);
            Seg::ins(1, 1, n, b[i], Seg::sz[1] + 1);
            ans[i] -= Seg::sz[1] * (Seg::sz[1] + 1) - Seg::sum[1];
        }
        for(int i = 1; i <= n; i++)
            printf("%lld\n", ans[i]);
        return 0;
    }
    
posted @ 2020-03-25 17:51  Joyemang33  阅读(1136)  评论(0编辑  收藏  举报