【数据结构】线段树-区间修改

新模板,使用了模板和默认值、幺元等,分离node中的数据维护部分。当维护的数据与l,r有关时,节点内需要存l,r,可能导致空间开销比较大。当query的merge的成本太高时,不能用下面这个模板,必须用旧模板。旧模板中query的流程和每个节点中的数据相关,也就是说query是可以理解sum[u]和sum[ls], sum[rs]的,方便进行短路运算以及与l,r有关的操作(如CF-1902D)以及线段树二分等需要控制短路计算左右子树的问题。

新模板适用于常规的区间和、最小值、最大值、区间反转、字符串哈希等,特征为:query中合并两个部分结果的代价是O(1)或O(log n)的。使用新模板时,数据的维护全部交给node自行计算与合并,query只需取出合并后的node即可。线段树本身不理解node中存储的数据。

// Sentinel constants for a value type T:
//   empty_value - the "zero" (identity for sums; segment_tree also uses it as
//                 the "no pending tag" marker),
//   min_value   - intended to be below every stored value (identity for max),
//   max_value   - intended to be above every stored value (identity for min).
// Only int and ll are specialized. INF / LINF must be defined before this
// point; it is assumed they exceed every value the tree will hold.
template<typename T>
struct default_value;

template<>
struct default_value<int> {
    static constexpr int max_value = INF;
    static constexpr int min_value = -INF;
    static constexpr int empty_value = 0;
};

template<>
struct default_value<ll> {
    static constexpr ll max_value = LINF;
    static constexpr ll min_value = -LINF;
    static constexpr ll empty_value = 0LL;
};


// Smallest power of two that is >= 2 * n (1 when n == 0). A recursive segment
// tree over [1, n] with children 2u / 2u+1 has every node index below this
// value, so it is a sufficient length for the node array.
static constexpr int calc_segment_tree_size (int n) {
    const int target = n << 1;
    int size = 1;
    while (size < target)
        size *= 2;
    return size;
}


struct segment_tree {
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

    static const int MAXN = 2e5 + 10;
    static constexpr int MAXN_SEGMENT_TREE_SIZE = calc_segment_tree_size (MAXN);

    using value_type = ll;

    static const value_type empty_value = default_value<value_type>::empty_value;
    static const value_type min_value = default_value<value_type>::min_value;
    static const value_type max_value = default_value<value_type>::max_value;

    int n;

    struct node_type {
        int id;
        int len;
        value_type tag;
        value_type sum;
        value_type mi;
        value_type mx;

        node_type (int u = 0, int l = 0, int r = 0) {
            id = u;
            len = (r - l + 1);
            tag = empty_value;
            sum = empty_value;
            mi = max_value;
            mx = min_value;
        }

        node_type (int u, int l, int r, value_type value) {
            id = u;
            len = (r - l + 1);
            tag = empty_value;
            sum = value;
            mi = value;
            mx = value;
        }

        void merge (const node_type& left_node, const node_type& right_node) {
            len = left_node.len + right_node.len;
            sum = left_node.sum + right_node.sum;
            mi = min (left_node.mi, right_node.mi);
            mx = max (left_node.mx, right_node.mx);
        }

        // 加法
        void add (value_type value) {
            tag += value;
            sum += 1LL * len * value;
            mi += value;
            mx += value;
        }

        // 设值
//        void set (value_type value) {
//            tag = value;
//            sum = 1LL * len * value;
//            mi = value;
//            mx = value;
//        }

    } node[MAXN_SEGMENT_TREE_SIZE];

    void push_up (int u, int l, int r) {
        node[u].merge (node[ls], node[rs]);
    }

    void push_down (int u, int l, int r) {
        value_type t = node[u].tag;
        if (t != empty_value) {
            // 区间加法
            node[ls].add (t);
            node[rs].add (t);
            // 区间设值
//            node[ls].set (t);
//            node[rs].set (t);
            // 区间设值的懒标记要用特殊标记
            node[u].tag = empty_value;
        }
    }

    void build (int u, int l, int r, value_type* a) {
        if (l == r) {
            if (a == nullptr) {
                node[u] = node_type (u, l, r);
            } else {
                node[u] = node_type (u, l, r, a[l]);
            }
            return;
        }
        build (ls, l, mid, a);
        build (rs, mid + 1, r, a);
        push_up (u, l, r);
    }

    // 区间加法
    void add (int u, int l, int r, int L, int R, value_type value) {
        if (L > R || L > r || R < l) {
            return;
        }
        if (L <= l && r <= R) {
            node[u].add (value);
            return;
        }
        push_down (u, l, r);
        add (ls, l, mid, L, R, value);
        add (rs, mid + 1, r, L, R, value);
        push_up (u, l, r);
    }

    // 区间设值
//    void add (int u, int l, int r, int L, int R, value_type value) {
//        if (L > R || L > r || R < l) {
//            return;
//        }
//        if (L <= l && r <= R) {
//            node[u].set (value);
//            return;
//        }
//        push_down (u, l, r);
//        set (ls, l, mid, L, R, value);
//        set (rs, mid + 1, r, L, R, value);
//        push_up (u, l, r);
//    }

    node_type query (int u, int l, int r, int L, int R) {
        if (L > R || L > r || R < l) {
            return node_type();
        }
        if (L <= l && r <= R)
            return node[u];
        push_down (u, l, r);
        node_type l_res = query (ls, l, mid, L, R);
        node_type r_res = query (rs, mid + 1, r, L, R);
        node_type res;
        res.merge (l_res, r_res);
        return res;
    }

    void Build (int n, value_type* a = nullptr) {
        this->n = n;
        build (1, 1, n, a);
    }

    // 区间加法
    void Add (int L, int R, value_type value) {
        add (1, 1, n, L, R, value);
    }

    // 区间设值
//    void Set (int L, int R, value_type value) {
//        set (1, 1, n, L, R, value);
//    }

    node_type Query (int L, int R) {
        return query (1, 1, n, L, R);
    }

    value_type QueryMin (int L, int R) {
        return Query (L, R).mi;
    }

    value_type QueryMax (int L, int R) {
        return Query (L, R).mx;
    }

    value_type QuerySum (int L, int R) {
        return Query (L, R).sum;
    }

#undef ls
#undef rs
#undef mid

} st;

旧模板,比较熟悉比较好改动。

区间加法

struct SegmentTree {

#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

    // Range add; range sum / min / max over positions 1..n, root u = 1. Every
    // call passes the node's own interval [l, r] and the request [L, R].
    // Uses long long / std:: directly so the template does not depend on an
    // outside `ll` typedef or `using namespace std`.

    static const int MAXN = 2e5 + 10;
    // Identity for min/max queries. (1e18 + 10 is evaluated in double and
    // rounds to exactly 1e18.)
    static const long long LINF = 1e18 + 10;

    long long tag[MAXN << 2];  // pending add for the children of each node
    long long mi[MAXN << 2];
    long long mx[MAXN << 2];
    long long sum[MAXN << 2];

    // Recomputes node u from its two children.
    void PushUp(int u) {
        mi[u] = std::min(mi[ls], mi[rs]);
        mx[u] = std::max(mx[ls], mx[rs]);
        sum[u] = sum[ls] + sum[rs];
    }

    // Applies node u's pending add to both children and clears it.
    void PushDown(int u, int l, int r) {
        long long t = tag[u];
        if(t != 0) {
            tag[ls] += t, mi[ls] += t, mx[ls] += t;
            sum[ls] += 1LL * (mid - l + 1) * t;
            tag[rs] += t, mi[rs] += t, mx[rs] += t;
            sum[rs] += 1LL * (r - mid) * t;
            tag[u] = 0;
        }
    }

    // Builds [l, r] with every position equal to 0. Call as Build(1, 1, n).
    void Build(int u, int l, int r) {
        if(l > r) return;  // n == 0: nothing to build (used to recurse forever)
        tag[u] = 0;
        if(l == r) {
            mi[u] = mx[u] = sum[u] = 0;
            return;
        }
        Build(ls, l, mid);
        Build(rs, mid + 1, r);
        PushUp(u);
    }

    // a[L..R] += v. Empty (L > R) or disjoint requests are ignored; without
    // this guard they reached a leaf, pushed down from it and recursed on the
    // same interval forever, or walked past the leaves.
    void UpdateAdd(int u, int l, int r, int L, int R, long long v) {
        if(L > R || L > r || R < l) return;
        if(L <= l && r <= R) {
            tag[u] += v, mi[u] += v, mx[u] += v;
            sum[u] += 1LL * (r - l + 1) * v;
            return;
        }
        PushDown(u, l, r);
        if(L <= mid) UpdateAdd(ls, l, mid, L, R, v);
        if(R >= mid + 1) UpdateAdd(rs, mid + 1, r, L, R, v);
        PushUp(u);
    }

    // Minimum of a[L..R]; LINF for an empty or disjoint request.
    long long QueryMin(int u, int l, int r, int L, int R) {
        if(L > R || L > r || R < l) return LINF;
        if(L <= l && r <= R)
            return mi[u];
        PushDown(u, l, r);
        long long res = LINF;
        if(L <= mid) res = std::min(res, QueryMin(ls, l, mid, L, R));
        if(R >= mid + 1) res = std::min(res, QueryMin(rs, mid + 1, r, L, R));
        return res;
    }

    // Maximum of a[L..R]; -LINF for an empty or disjoint request.
    long long QueryMax(int u, int l, int r, int L, int R) {
        if(L > R || L > r || R < l) return -LINF;
        if(L <= l && r <= R)
            return mx[u];
        PushDown(u, l, r);
        long long res = -LINF;
        if(L <= mid) res = std::max(res, QueryMax(ls, l, mid, L, R));
        if(R >= mid + 1) res = std::max(res, QueryMax(rs, mid + 1, r, L, R));
        return res;
    }

    // Sum of a[L..R]; 0 for an empty or disjoint request.
    long long QuerySum(int u, int l, int r, int L, int R) {
        if(L > R || L > r || R < l) return 0;
        if(L <= l && r <= R)
            return sum[u];
        PushDown(u, l, r);
        long long res = 0;
        if(L <= mid) res += QuerySum(ls, l, mid, L, R);
        if(R >= mid + 1) res += QuerySum(rs, mid + 1, r, L, R);
        return res;
    }

#undef ls
#undef rs
#undef mid

} st;

区间设值

struct SegmentTree {

#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

    // Range assignment; range sum / min / max over positions 1..n, root u = 1.
    // Every call passes the node's own interval [l, r] and the request [L, R].
    // Uses long long / std:: directly so the template does not depend on an
    // outside `ll` typedef or `using namespace std`.

    static const int MAXN = 2e5 + 10;
    // Identity for min/max queries AND the "no pending assignment" tag marker,
    // so an assignment of exactly LINF cannot be represented. (1e18 + 10 is
    // evaluated in double and rounds to exactly 1e18.)
    static const long long LINF = 1e18 + 10;

    long long tag[MAXN << 2];  // pending assignment for the children, LINF if none
    long long mi[MAXN << 2];
    long long mx[MAXN << 2];
    long long sum[MAXN << 2];

    // Recomputes node u from its two children.
    void PushUp(int u) {
        mi[u] = std::min(mi[ls], mi[rs]);
        mx[u] = std::max(mx[ls], mx[rs]);
        sum[u] = sum[ls] + sum[rs];
    }

    // Applies node u's pending assignment to both children and clears it.
    void PushDown(int u, int l, int r) {
        long long t = tag[u];
        if(t != LINF) {
            tag[ls] = mi[ls] = mx[ls] = t;
            sum[ls] = 1LL * (mid - l + 1) * t;
            tag[rs] = mi[rs] = mx[rs] = t;
            sum[rs] = 1LL * (r - mid) * t;
            tag[u] = LINF;
        }
    }

    // Builds [l, r] with every position equal to 0. Call as Build(1, 1, n).
    void Build(int u, int l, int r) {
        if(l > r) return;  // n == 0: nothing to build (used to recurse forever)
        tag[u] = LINF;
        if(l == r) {
            mi[u] = mx[u] = sum[u] = 0;
            return;
        }
        Build(ls, l, mid);
        Build(rs, mid + 1, r);
        PushUp(u);
    }

    // a[L..R] = v. Empty (L > R) or disjoint requests are ignored; without this
    // guard they reached a leaf, pushed down from it and recursed on the same
    // interval forever, or walked past the leaves.
    void UpdateSet(int u, int l, int r, int L, int R, long long v) {
        if(L > R || L > r || R < l) return;
        if(L <= l && r <= R) {
            tag[u] = mi[u] = mx[u] = v;
            sum[u] = 1LL * (r - l + 1) * v;
            return;
        }
        PushDown(u, l, r);
        if(L <= mid) UpdateSet(ls, l, mid, L, R, v);
        if(R >= mid + 1) UpdateSet(rs, mid + 1, r, L, R, v);
        PushUp(u);
    }

    // Minimum of a[L..R]; LINF for an empty or disjoint request.
    long long QueryMin(int u, int l, int r, int L, int R) {
        if(L > R || L > r || R < l) return LINF;
        if(L <= l && r <= R)
            return mi[u];
        PushDown(u, l, r);
        long long res = LINF;
        if(L <= mid) res = std::min(res, QueryMin(ls, l, mid, L, R));
        if(R >= mid + 1) res = std::min(res, QueryMin(rs, mid + 1, r, L, R));
        return res;
    }

    // Maximum of a[L..R]; -LINF for an empty or disjoint request.
    long long QueryMax(int u, int l, int r, int L, int R) {
        if(L > R || L > r || R < l) return -LINF;
        if(L <= l && r <= R)
            return mx[u];
        PushDown(u, l, r);
        long long res = -LINF;
        if(L <= mid) res = std::max(res, QueryMax(ls, l, mid, L, R));
        if(R >= mid + 1) res = std::max(res, QueryMax(rs, mid + 1, r, L, R));
        return res;
    }

    // Sum of a[L..R]; 0 for an empty or disjoint request.
    long long QuerySum(int u, int l, int r, int L, int R) {
        if(L > R || L > r || R < l) return 0;
        if(L <= l && r <= R)
            return sum[u];
        PushDown(u, l, r);
        long long res = 0;
        if(L <= mid) res += QuerySum(ls, l, mid, L, R);
        if(R >= mid + 1) res += QuerySum(rs, mid + 1, r, L, R);
        return res;
    }

#undef ls
#undef rs
#undef mid

} st;

混合标记:
加法和乘法的混合标记

超省心版(注意:下面的代码实际上只实现了区间加法标记,维护区间和、最小值、最大值,并没有乘法标记):

struct SegmentTree {

private:

#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

    // Range add; range sum / min / max over positions 1..n. Uses long long /
    // std:: directly so the template does not depend on an outside `ll`
    // typedef or `using namespace std`.

    static const int MAXN = 3e5 + 10;
    // Identity for min/max. (1e18 + 10 is evaluated in double and rounds to 1e18.)
    static const long long LINF = 1e18 + 10;

    int n;
    // Aggregate of one node: range sum, min, max, covered length, pending add.
    struct Node {
        long long sum, mi, ma;
        long long len, tag;
        // A single position holding 0 (used for the leaves in iBuild).
        Node() {
            sum = 0, mi = 0, ma = 0;
            len = 1, tag = 0;
        }
        // Identity for merge(): covers no positions and leaves sum / min / max
        // unchanged. iQuery returns it for parts outside the request; the old
        // code returned Node(), whose mi = ma = 0 leaked into every partial
        // min/max query.
        static Node identity() {
            Node res;
            res.sum = 0, res.mi = LINF, res.ma = -LINF;
            res.len = 0, res.tag = 0;
            return res;
        }
        // Overwrites every covered position with val. Not used by the tree; it
        // clears tag instead of leaving a pending assignment, so on an internal
        // node the new value would not reach the children.
        Node& operator=(long long val) {
            sum = 1LL * len * val, mi = val, ma = val;
            tag = 0;
            return *this;
        }
        // Adds val to every covered position and records it as pending for
        // the children.
        Node& operator+=(long long val) {
            sum += 1LL * len * val, mi += val, ma += val;
            tag += val;
            return *this;
        }
    } node[MAXN << 2];

    // Combines two adjacent nodes (either may be Node::identity()).
    Node merge(const Node &x, const Node &y) {
        Node res;
        res.sum = x.sum + y.sum;
        res.mi = std::min(x.mi, y.mi);
        res.ma = std::max(x.ma, y.ma);  // was min(), which broke every range max
        res.len = x.len + y.len;
        res.tag = 0;
        return res;  // plain return keeps NRVO; std::move(res) disabled it
    }

    // Recomputes node u from its children.
    void pull(int u) {
        node[u] = merge(node[ls], node[rs]);
        return;
    }

    // Hands node u's pending add to its children and clears it.
    void push(int u) {
        if (node[u].tag != 0) {
            node[ls] += node[u].tag;
            node[rs] += node[u].tag;
            node[u].tag = 0;
        }
        return;
    }

    // Builds [l, r] with every position equal to 0.
    void iBuild(int u, int l, int r) {
        if (l > r)
            return;
        if (l == r) {
            node[u] = Node();
            return;
        }
        iBuild(ls, l, mid);
        iBuild(rs, mid + 1, r);
        pull(u);
        return;
    }

    // a[lpos..rpos] += val; empty or disjoint requests are ignored.
    void iUpdate(int u, int l, int r, int lpos, int rpos, long long val) {
        if (l > r || lpos > r || rpos < l)
            return;
        if (lpos <= l && r <= rpos) {
            node[u] += val;
            return;
        }
        push(u);
        iUpdate(ls, l, mid, lpos, rpos, val);
        iUpdate(rs, mid + 1, r, lpos, rpos, val);
        pull(u);
        return;
    }

    // Merged node for [lpos, rpos] restricted to [l, r]; identity if disjoint.
    Node iQuery(int u, int l, int r, int lpos, int rpos) {
        if (l > r || lpos > r || rpos < l)
            return Node::identity();
        if (lpos <= l && r <= rpos)
            return node[u];
        push(u);
        Node resL = iQuery(ls, l, mid, lpos, rpos);
        Node resR = iQuery(rs, mid + 1, r, lpos, rpos);
        return merge(resL, resR);
    }

#undef ls
#undef rs
#undef mid

public:

    void build(int n) {
        this->n = n;
        iBuild(1, 1, n);
        return;
    }

    // Range add.
    void update(int lpos, int rpos, long long val) {
        iUpdate(1, 1, n, lpos, rpos, val);
        return;
    }

    // Aggregate of [lpos, rpos]; an empty range gives sum 0, min LINF, max -LINF.
    Node query(int lpos, int rpos) {
        return iQuery(1, 1, n, lpos, rpos);
    }

    long long querySum(int lpos, int rpos) {
        return query(lpos, rpos).sum;
    }

    long long queryMin(int lpos, int rpos) {
        return query(lpos, rpos).mi;
    }

    long long queryMax(int lpos, int rpos) {
        return query(lpos, rpos).ma;
    }

} st;
posted @ 2020-11-26 22:22  purinliang  阅读(119)  评论(0编辑  收藏  举报