线段树模板

线段树是一种通用的数据结构,能够处理满足结合律的信息。

前置知识

线段树

基础版

// One node of the array-based segment tree (children of u are u << 1 and u << 1 | 1).
// TODO: adapt the information field (val) and the tag field (lazy) to the problem.
struct node {
    int l;     // left end of the interval covered by this node
    int r;     // right end of the interval (closed)
    int lazy;  // pending tag that has not been pushed to the children yet
    int val;   // aggregated information of [l, r]
    // int sum;
} tr[N * 4];  // N (defined elsewhere) is presumably the maximum number of elements

void modify(int p, int l, int r, int v) {
    // tr[p].lazy += d, tr[p].val += (l - r + 1) * v;
}

void pushup(int u) {
    // tr[u].v = tr[u >> 1].val + tr[u >> 1 | 1].val;
}

// Pushes the pending tag of node u down to its children.
// l, r must be node u's own interval [tr[u].l, tr[u].r]; the children cover
// [l, mid] and [mid + 1, r]. Children are u << 1 and u << 1 | 1
// (u >> 1 would be the parent).
void pushdown(int u, int l, int r) {
    if (tr[u].lazy) {
        int mid = (l + r) >> 1;
        modify(u << 1, l, mid, tr[u].lazy);
        modify(u << 1 | 1, mid + 1, r, tr[u].lazy);
        tr[u].lazy = 0;
    }
}

// Builds node u for the interval [l, r] and, recursively, its subtree.
// Aggregate initialisation {l, r} value-initialises lazy and val to 0.
void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l == r) return;  // leaf: nothing below it
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Applies tag v to every position of [l, r], starting from node u.
// modify / pushdown must receive the node's own interval [tr[u].l, tr[u].r],
// not the query range, otherwise the length used for the update is wrong.
void update(int u, int l, int r, int v) {
    if (tr[u].l >= l && tr[u].r <= r) {
        // Node fully covered: tag it and stop.
        modify(u, tr[u].l, tr[u].r, v);
        return;
    }

    pushdown(u, tr[u].l, tr[u].r);
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) update(u << 1, l, r, v);
    if (r > mid) update(u << 1 | 1, l, r, v);
    pushup(u);
}

// Returns the information of [l, r] (default: the sum), starting from node u.
// TODO: when val holds something other than a sum, change the fully-covered
// return value and the way the two halves are combined.
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].val;

    // Pending tags must reach the children before they are read; pass the
    // node's own interval, not the query range.
    pushdown(u, tr[u].l, tr[u].r);
    int mid = (tr[u].l + tr[u].r) >> 1;
    int res = 0;
    if (l <= mid) res += query(u << 1, l, r);
    if (r > mid) res += query(u << 1 | 1, l, r);
    return res;
}

zkw 版

代码转自 有趣的 zkw 线段树(超全详解) + 轻微压行

// Maximum number of elements.
const int M = 1e6;
// Tree storage (node indices 1 .. 2m - 1; see build()). 4 * M + 7 entries cover
// 2m for every n <= M. The parentheses matter: `M << 2 + 7` parses as
// `M << 9` because + binds tighter than <<, which would declare ~2 GB per array.
//   sum[i]: subtree sum, excluding the pending add[] of i and of its ancestors
//   mn[i] / mx[i]: subtree min / max stored as an offset from the parent
//   add[i]: pending range-add tag of node i
int n, m, q, sum[(M << 2) + 7], mn[(M << 2) + 7], mx[(M << 2) + 7], add[(M << 2) + 7];

// Reads a[1..n] from stdin into the leaves m + 1 .. m + n and builds the tree.
// m is the smallest power of two with m >= n + 2. update_part / query_sum walk
// from the open bounds s - 1 + m and t + 1 + m, so t + m + 1 <= n + m + 1 must
// stay below 2m. With the former `m <= n` bound, n = 2^k - 1 reached node 2m.
inline void build() {
    for (m = 1; m <= n + 1;) m <<= 1;
    for (int i = m + 1; i <= m + n; ++i) cin >> mx[i], sum[i] = mn[i] = mx[i];
    for (int i = m - 1; i; --i) {
        sum[i] = sum[i << 1] + sum[i << 1 | 1];
        // Store children's min / max as offsets from the parent; after the loop
        // only the root mn[1] / mx[1] holds an absolute value.
        mn[i] = min(mn[i << 1], mn[i << 1 | 1]), mn[i << 1] -= mn[i], mn[i << 1 | 1] -= mn[i];
        mx[i] = max(mx[i << 1], mx[i << 1 | 1]), mx[i << 1] -= mx[i], mx[i << 1 | 1] -= mx[i];
    }
}

// Adds v to a[x] (1-based position).
// The leaf is updated first; the loop then propagates to the parent of x on
// each step, so every ancestor up to and including the root receives v once.
// (Updating sum[x] inside the loop added v to the leaf twice and skipped the root.)
inline void update_node(int x, int v, int A = 0) {
    x += m, mx[x] += v, mn[x] += v, sum[x] += v;
    for (; x > 1; x >>= 1) {
        sum[x >> 1] += v;
        // Re-normalise the min / max offsets of x and its sibling into the parent.
        A = min(mn[x], mn[x ^ 1]);
        mn[x] -= A, mn[x ^ 1] -= A, mn[x >> 1] += A;
        A = max(mx[x], mx[x ^ 1]), mx[x] -= A, mx[x ^ 1] -= A, mx[x >> 1] += A;
    }
}

// Adds v to every element of a[s..t] (1-based, closed interval).
// Walks up from the open bounds s - 1 and t + 1. When the left walker is a left
// child, its sibling lies inside the range and receives the tag; symmetrically
// for a right-child right walker. lc / rc count the elements tagged so far on
// each side, so ancestors' sums grow by v * lc (resp. v * rc).
inline void update_part(int s, int t, int v) {
    int A = 0, lc = 0, rc = 0, len = 1;  // len = number of leaves under a node of this level
    for (s += m - 1, t += m + 1; s ^ t ^ 1; s >>= 1, t >>= 1, len <<= 1) {
        if ((s & 1) ^ 1) add[s ^ 1] += v, lc += len, mn[s ^ 1] += v, mx[s ^ 1] += v;
        if (t & 1) add[t ^ 1] += v, rc += len, mn[t ^ 1] += v, mx[t ^ 1] += v;
        sum[s >> 1] += v * lc, sum[t >> 1] += v * rc;
        // Re-normalise min / max offsets of both walkers' sibling pairs.
        A = min(mn[s], mn[s ^ 1]), mn[s] -= A, mn[s ^ 1] -= A, mn[s >> 1] += A;
        A = min(mn[t], mn[t ^ 1]), mn[t] -= A, mn[t ^ 1] -= A, mn[t >> 1] += A;
        A = max(mx[s], mx[s ^ 1]), mx[s] -= A, mx[s ^ 1] -= A, mx[s >> 1] += A;
        A = max(mx[t], mx[t ^ 1]), mx[t] -= A, mx[t ^ 1] -= A, mx[t >> 1] += A;
    }
    // s and t are siblings now; propagate to their common ancestors.
    // Stop before s == 1: the root has no sibling. Pairing it with index 0 would
    // shift the root's absolute min / max into mn[0] / mx[0] and lose it.
    for (lc += rc; s > 1; s >>= 1) {
        sum[s >> 1] += v * lc;
        A = min(mn[s], mn[s ^ 1]), mn[s] -= A, mn[s ^ 1] -= A, mn[s >> 1] += A;
        A = max(mx[s], mx[s ^ 1]), mx[s] -= A, mx[s ^ 1] -= A, mx[s >> 1] += A;
    }
}

// Returns a[x] (1-based position).
// mn[] keeps each node's minimum as an offset from its parent (the root keeps
// the absolute value), so adding mn[] from the leaf up to the root yields the
// leaf's own value.
inline int query_node(int x, int ans = 0) {
    int node = x + m;
    while (node != 0) {
        ans += mn[node];
        node >>= 1;
    }
    return ans;
}

// Returns a[s] + ... + a[t] (1-based, closed interval).
// Walks up from the open bounds s - 1 and t + 1 (s ^ t ^ 1 == 0 exactly when
// s and t are siblings). When the left walker is a left child, its sibling
// lies entirely inside the range; symmetrically for a right-child right walker.
// sum[i] excludes the pending add[] tags of i and its ancestors, so:
//   * an included node contributes sum[i] + len * add[i] (len = leaves under it);
//   * lc / rc count the elements included on each side, and every ancestor's
//     add[] is charged once per included element.
// NOTE(review): t + m + 1 must be a valid node (< 2m), i.e. m >= n + 2 in build().
inline int query_sum(int s, int t) {
    int lc = 0, rc = 0, len = 1, ans = 0;
    for (s += m - 1, t += m + 1; s ^ t ^ 1; s >>= 1, t >>= 1, len <<= 1) {
        if ((s & 1) ^ 1) ans += sum[s ^ 1] + len * add[s ^ 1], lc += len;
        if (t & 1) ans += sum[t ^ 1] + len * add[t ^ 1], rc += len;
        // Tags on the parents apply to everything included below them so far.
        if (add[s >> 1]) ans += add[s >> 1] * lc;
        if (add[t >> 1]) ans += add[t >> 1] * rc;
    }
    // s and t are siblings: their common ancestors' tags apply to all included elements.
    for (lc += rc, s >>= 1; s; s >>= 1)
        if (add[s]) ans += add[s] * lc;
    return ans;
}

// Returns min(a[s..t]) (1-based, closed interval).
// mn[] stores each node's minimum as an offset from its parent (the root holds
// the absolute value). L / R hold the best candidate on each side, expressed
// relative to the current walker's node.
inline int query_min(int s, int t, int L = 0, int R = 0, int ans = 0) {
    if (s == t) return query_node(s);
    for (s += m, t += m; s ^ t ^ 1; s >>= 1, t >>= 1) {
        L += mn[s], R += mn[t];  // now relative to the parents of s / t
        if ((s & 1) ^ 1) L = min(L, mn[s ^ 1]);
        if (t & 1) R = min(R, mn[t ^ 1]);
    }
    // s and t are siblings: lift L and R by their own offsets to the common
    // parent before combining (omitting mn[s] / mn[t] gave wrong answers, e.g.
    // a = [1,2,3,4], query_min(1, 3) returned 0), then add the ancestors' offsets.
    for (ans = min(L + mn[s], R + mn[t]), s >>= 1; s; s >>= 1) ans += mn[s];
    return ans;
}

// Returns max(a[s..t]) (1-based, closed interval).
// mx[] stores each node's maximum as an offset from its parent (the root holds
// the absolute value). L / R hold the best candidate on each side, expressed
// relative to the current walker's node.
inline int query_max(int s, int t, int L = 0, int R = 0, int ans = 0) {
    if (s == t) return query_node(s);
    for (s += m, t += m; s ^ t ^ 1; s >>= 1, t >>= 1) {
        L += mx[s], R += mx[t];  // now relative to the parents of s / t
        if ((s & 1) ^ 1) L = max(L, mx[s ^ 1]);
        if (t & 1) R = max(R, mx[t ^ 1]);
    }
    // s and t are siblings: lift L and R by their own offsets to the common
    // parent before combining, then add the ancestors' offsets.
    for (ans = max(L + mx[s], R + mx[t]), s >>= 1; s; s >>= 1) ans += mx[s];
    return ans;
}

STL 版

参考链接 1

参考链接 2

AC-Library

AC-Library Examples

// Returns the smallest x >= 0 such that 2^x >= n (0 for n <= 1).
int ceil_pow2(int n) {
    const unsigned int target = static_cast<unsigned int>(n);
    int exponent = 0;
    for (unsigned int power = 1U; power < target; power <<= 1) ++exponent;
    return exponent;
}

// Segment tree over a monoid (S, op, e) without lazy propagation (adapted from AC-Library).
//   S  : value type stored in every node
//   op : associative operation combining two adjacent ranges
//   e  : identity of op, i.e. op(e(), x) == op(x, e()) == x
// Indices are 0-based and every range argument is half-open [l, r).
// Layout: d[1] is the root, node k has children 2k and 2k + 1, and the leaves
// are d[size .. size + _n - 1]; unused leaves keep the value e().
template <class S, S (*op)(S, S), S (*e)()>
struct segtree {
public:
    // Empty tree.
    segtree() : segtree(0) {}
    // n elements, each initialised to e().
    explicit segtree(int n) : segtree(std::vector<S>(n, e())) {}
    // Builds from v: copies the leaves, then fills the internal nodes bottom-up.
    explicit segtree(const std::vector<S>& v) : _n(int(v.size())) {
        log = ceil_pow2(_n);
        size = 1 << log;  // smallest power of two >= _n
        d = std::vector<S>(2 * size, e());
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }

    // a[p] = x, then recomputes the log ancestors of leaf p. O(log n).
    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }

    // Returns a[p]. O(1).
    S get(int p) const {
        assert(0 <= p && p < _n);
        return d[p + size];
    }

    // Returns op(a[l], ..., a[r - 1]), or e() when l == r. O(log n).
    S prod(int l, int r) const {
        assert(0 <= l && l <= r && r <= _n);
        // The left and right parts are accumulated separately so op is always
        // applied in left-to-right order (op need not be commutative).
        S sml = e(), smr = e();
        l += size;
        r += size;

        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return op(sml, smr);
    }

    // Returns op(a[0], ..., a[n - 1]). O(1).
    S all_prod() const { return d[1]; }

    // Same as max_right(l, f) with the predicate given as a template argument.
    template <bool (*f)(S)>
    int max_right(int l) const {
        return max_right(l, [](S x) { return f(x); });
    }
    // Binary search to the right of l. Returns an r (l <= r <= n) such that
    // f(op(a[l], ..., a[r - 1])) is true and either r == n or
    // f(op(a[l], ..., a[r])) is false; for a monotone f this is the largest
    // such r. Requires f(e()) == true (asserted). O(log n).
    template <class F>
    int max_right(int l, F f) const {
        assert(0 <= l && l <= _n);
        assert(f(e()));
        if (l == _n) return _n;
        l += size;
        S sm = e();  // op of everything accepted so far
        do {
            // Climb while l is a left child: its parent starts at the same position.
            while (l % 2 == 0) l >>= 1;
            if (!f(op(sm, d[l]))) {
                // The answer lies inside node l: descend, absorbing the left
                // child whenever f still holds and moving to the right child.
                while (l < size) {
                    l = (2 * l);
                    if (f(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);  // stop once l is a power of two (right edge reached)
        return _n;
    }

    // Same as min_left(r, f) with the predicate given as a template argument.
    template <bool (*f)(S)>
    int min_left(int r) const {
        return min_left(r, [](S x) { return f(x); });
    }
    // Binary search to the left of r. Returns an l (0 <= l <= r) such that
    // f(op(a[l], ..., a[r - 1])) is true and either l == 0 or
    // f(op(a[l - 1], ..., a[r - 1])) is false; for a monotone f this is the
    // smallest such l. Requires f(e()) == true (asserted). O(log n).
    template <class F>
    int min_left(int r, F f) const {
        assert(0 <= r && r <= _n);
        assert(f(e()));
        if (r == 0) return 0;
        r += size;
        S sm = e();  // op of everything accepted so far (to the right)
        do {
            r--;
            // Climb while r is a right child: its parent ends at the same position.
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(op(d[r], sm))) {
                // The answer lies inside node r: descend, absorbing the right
                // child whenever f still holds and moving to the left child.
                while (r < size) {
                    r = (2 * r + 1);
                    if (f(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);  // stop once r is a power of two (left edge reached)
        return 0;
    }

private:
    int _n, size, log;  // element count; size = 2^log leaf slots (size >= _n)
    std::vector<S> d;   // tree nodes, 1-indexed, 2 * size entries

    // Recomputes node k from its two children.
    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};

// Segment tree with lazy propagation (adapted from AC-Library).
//   S, op, e    : monoid of the stored information (op associative, e its identity)
//   F           : type of the lazy tag (the "map" applied to a range)
//   mapping     : mapping(f, x) = result of applying tag f to the information x
//   composition : composition(f, g) = tag equivalent to applying g first, then f
//   id          : tag with mapping(id(), x) == x
// Indices are 0-based and every range argument is half-open [l, r).
// Layout: d[1] is the root, node k has children 2k and 2k + 1, leaves are
// d[size .. size + _n - 1]; lz[k] is the tag still owed to the children of
// internal node k.
template <class S, S (*op)(S, S), S (*e)(), class F, S (*mapping)(F, S), F (*composition)(F, F), F (*id)()>
struct lazy_segtree {
public:
    // Empty tree.
    lazy_segtree() : lazy_segtree(0) {}
    // n elements, each initialised to e().
    explicit lazy_segtree(int n) : lazy_segtree(std::vector<S>(n, e())) {}
    // Builds from v: copies the leaves, fills the internal nodes bottom-up,
    // and starts every tag at id().
    explicit lazy_segtree(const std::vector<S>& v) : _n(int(v.size())) {
        log = ceil_pow2(_n);
        size = 1 << log;  // smallest power of two >= _n
        d = std::vector<S>(2 * size, e());
        lz = std::vector<F>(size, id());  // only internal nodes (index < size) carry tags
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }

    // a[p] = x. Pushes the tags on the path down first, then recomputes the
    // ancestors. O(log n).
    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        for (int i = log; i >= 1; i--) push(p >> i);
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }

    // Returns a[p]. O(log n) because the tags above the leaf are pushed first.
    S get(int p) {
        assert(0 <= p && p < _n);
        p += size;
        for (int i = log; i >= 1; i--) push(p >> i);
        return d[p];
    }

    // Returns op(a[l], ..., a[r - 1]), or e() when l == r. O(log n).
    S prod(int l, int r) {
        assert(0 <= l && l <= r && r <= _n);
        if (l == r) return e();

        l += size;
        r += size;

        // Push pending tags on the ancestors of the two boundaries, skipping
        // levels where the boundary is aligned to a node edge.
        for (int i = log; i >= 1; i--) {
            if (((l >> i) << i) != l) push(l >> i);
            if (((r >> i) << i) != r) push((r - 1) >> i);
        }

        // Left and right parts are accumulated separately so op is applied
        // in left-to-right order (op need not be commutative).
        S sml = e(), smr = e();
        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }

        return op(sml, smr);
    }

    // Returns op(a[0], ..., a[n - 1]). O(1).
    S all_prod() { return d[1]; }

    // a[p] = mapping(f, a[p]). O(log n).
    void apply(int p, F f) {
        assert(0 <= p && p < _n);
        p += size;
        for (int i = log; i >= 1; i--) push(p >> i);
        d[p] = mapping(f, d[p]);
        for (int i = 1; i <= log; i++) update(p >> i);
    }
    // a[i] = mapping(f, a[i]) for every i in [l, r). O(log n).
    void apply(int l, int r, F f) {
        assert(0 <= l && l <= r && r <= _n);
        if (l == r) return;

        l += size;
        r += size;

        // Push pending tags above the boundaries before tagging new nodes.
        for (int i = log; i >= 1; i--) {
            if (((l >> i) << i) != l) push(l >> i);
            if (((r >> i) << i) != r) push((r - 1) >> i);
        }

        {
            // Tag the O(log n) canonical nodes covering [l, r); l and r are
            // restored afterwards for the bottom-up pass.
            int l2 = l, r2 = r;
            while (l < r) {
                if (l & 1) all_apply(l++, f);
                if (r & 1) all_apply(--r, f);
                l >>= 1;
                r >>= 1;
            }
            l = l2;
            r = r2;
        }

        // Recompute the ancestors of the boundaries that were only partly covered.
        for (int i = 1; i <= log; i++) {
            if (((l >> i) << i) != l) update(l >> i);
            if (((r >> i) << i) != r) update((r - 1) >> i);
        }
    }

    // Same as max_right(l, g) with the predicate given as a template argument.
    template <bool (*g)(S)>
    int max_right(int l) {
        return max_right(l, [](S x) { return g(x); });
    }
    // Binary search to the right of l. Returns an r (l <= r <= n) such that
    // g(op(a[l], ..., a[r - 1])) is true and either r == n or
    // g(op(a[l], ..., a[r])) is false; for a monotone g this is the largest
    // such r. Requires g(e()) == true (asserted). O(log n).
    template <class G>
    int max_right(int l, G g) {
        assert(0 <= l && l <= _n);
        assert(g(e()));
        if (l == _n) return _n;
        l += size;
        for (int i = log; i >= 1; i--) push(l >> i);
        S sm = e();  // op of everything accepted so far
        do {
            // Climb while l is a left child: its parent starts at the same position.
            while (l % 2 == 0) l >>= 1;
            if (!g(op(sm, d[l]))) {
                // The answer lies inside node l: descend (pushing tags on the
                // way), absorbing the left child whenever g still holds.
                while (l < size) {
                    push(l);
                    l = (2 * l);
                    if (g(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);  // stop once l is a power of two (right edge reached)
        return _n;
    }

    // Same as min_left(r, g) with the predicate given as a template argument.
    template <bool (*g)(S)>
    int min_left(int r) {
        return min_left(r, [](S x) { return g(x); });
    }
    // Binary search to the left of r. Returns an l (0 <= l <= r) such that
    // g(op(a[l], ..., a[r - 1])) is true and either l == 0 or
    // g(op(a[l - 1], ..., a[r - 1])) is false; for a monotone g this is the
    // smallest such l. Requires g(e()) == true (asserted). O(log n).
    template <class G>
    int min_left(int r, G g) {
        assert(0 <= r && r <= _n);
        assert(g(e()));
        if (r == 0) return 0;
        r += size;
        for (int i = log; i >= 1; i--) push((r - 1) >> i);
        S sm = e();  // op of everything accepted so far (to the right)
        do {
            r--;
            // Climb while r is a right child: its parent ends at the same position.
            while (r > 1 && (r % 2)) r >>= 1;
            if (!g(op(d[r], sm))) {
                // The answer lies inside node r: descend (pushing tags on the
                // way), absorbing the right child whenever g still holds.
                while (r < size) {
                    push(r);
                    r = (2 * r + 1);
                    if (g(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);  // stop once r is a power of two (left edge reached)
        return 0;
    }

private:
    int _n, size, log;  // element count; size = 2^log leaf slots (size >= _n)
    std::vector<S> d;   // tree nodes, 1-indexed, 2 * size entries
    std::vector<F> lz;  // pending tags of internal nodes, size entries

    // Recomputes node k from its two children.
    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
    // Applies tag f to node k; internal nodes also remember it for their children.
    void all_apply(int k, F f) {
        d[k] = mapping(f, d[k]);
        if (k < size) lz[k] = composition(f, lz[k]);
    }
    // Hands node k's pending tag to both children and clears it.
    void push(int k) {
        all_apply(2 * k, lz[k]);
        all_apply(2 * k + 1, lz[k]);
        lz[k] = id();
    }
};

传入参数解释

S 为线段树每个节点维护的信息类型,推荐用 struct,例如 struct S { int sum, min, max; };

操作 \(op\) 用于合并左右两个相邻区间的信息,即为基础版中的 \(pushup\) 函数;它必须满足结合律

幺元 \(e\) 满足 \(op(e, x) = op(x, e) = x\),同时也是线段树的默认初始值(例如求和时为 0,求最小值时为 \(+\infty\))

F 相当于线段树的懒标记

S mapping(F f, S x) 将懒标记 f 作用到区间信息 x 上,即为基础版中的 \(modify\) 函数

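F composition(F f, F g) 用于合并两个懒标记:返回“先作用 g、再作用 f”的复合标记,即 mapping(composition(f, g), x) = mapping(f, mapping(g, x))。它对应基础版 modify 中更新 lazy 的那一步(如 lazy += v),而不是 \(pushup\)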

F id() 为恒等映射,即满足 mapping(id(), x) = x 的“空”标记

void seg.set(int p, S x) \(\Leftrightarrow\) a[p] = x

时间复杂度
\(O(\log{n})\)

S seg.get(int p) \(\Leftrightarrow\) return a[p]

时间复杂度
\(O(1)\)(lazy_segtree 需要先下传路径上的标记,为 \(O(\log{n})\))

S seg.prod(int l, int r) \(\Leftrightarrow\) return op(a[l]...a[r - 1])

时间复杂度
\(O(\log{n})\)

S seg.all_prod() \(\Leftrightarrow\) return op(a[0]...a[n - 1])

时间复杂度
\(O(1)\)

注意一切下标从 0 开始,且所有区间均为左闭右开 \([l, r)\)!

区间加&区间和

using ll = long long;

// Information of a segment for "range add, range sum": its sum and its length.
struct S {
	ll sum;
	ll len;
};

// Merges two adjacent segments.
S op(S l, S r) {
	S merged;
	merged.sum = l.sum + r.sum;
	merged.len = l.len + r.len;
	return merged;
}

// Identity: the empty segment.
S e() {
	S empty{};
	return empty;
}

// Tag: the amount added to every element of the segment.
using F = long long;

// Adding f to each of x.len elements raises the sum by f * x.len.
S mapping(F f, S x) {
	x.sum += f * x.len;
	return x;
}

// Two pending additions simply accumulate.
F composition(F l, F r) { return r + l; }

// "Add nothing".
F id() { return F{}; }

加上下面的代码可以 AC 【模板】线段树 1 - 洛谷

int n, m, p, x, y, k;
vector<S> a;
int main() {
	cin >> n >> m;
	for (int i = 0; i < n; i++) {
		int x; cin >> x;
		a.push_back({ x,1 });
	}
	lazy_segtree<S, op, e, F, mapping, composition, id> seg(a);
	while (m--) {
		cin >> p;
		if (p == 1) {
			cin >> x >> y >> k;
			seg.apply(x - 1, y, k);
		} else {
			cin >> x >> y;
			cout << seg.prod(x - 1, y).sum << endl;
		}
	}

	return 0;
}

区间乘、区间加(仿射变换 \(x \to mul \cdot x + add\))&区间和(对 mod 取模)

typedef long long ll;
const int mod = 998244353;

// Segment information: sum (mod `mod`) and number of elements.
struct S {
    ll sum;
    ll size;
};

// Tag: every element v becomes mul * v + add (mod `mod`).
struct F {
    ll mul;
    ll add;
};

// Merges two adjacent segments.
S op(S l, S r) {
    const ll total = (l.sum + r.sum) % mod;
    return S{ total, l.size + r.size };
}

// Identity: the empty segment.
S e() { return S{ 0, 0 }; }

// sum' = mul * sum + add * size.
S mapping(F f, S x) {
    const ll scaled = x.sum * f.mul % mod;
    const ll shifted = x.size * f.add % mod;
    return S{ (scaled + shifted) % mod, x.size };
}

// Apply g first, then f: v -> f.mul * (g.mul * v + g.add) + f.add.
F composition(F f, F g) {
    const ll mul = f.mul * g.mul % mod;
    const ll add = (f.mul * g.add % mod + f.add) % mod;
    return F{ mul, add };
}

// The identity map v -> 1 * v + 0.
F id() { return F{ 1, 0 }; }

加上上文中的代码(需相应修改主函数:P3373 的模数由输入给出,应把 mod 改为读入的变量,并分别处理“区间乘”“区间加”“区间求和”三种操作),可以 AC 【模板】线段树 2 - 洛谷

区间赋值&区间和

typedef long long ll;  // same typedef as the other snippets; repeating it is legal

// Range assignment, range sum: segment sum and length.
struct S {
	ll sum, len;
};

// Tag: the value assigned to every element of the segment.
typedef long long  F;

// Sentinel meaning "no pending assignment" (== LLONG_MIN, spelled out so no
// header is needed). The previous sentinel -1 made assigning the value -1 a
// silent no-op.
const F NO_ASSIGN = -0x7fffffffffffffffLL - 1;

S op(S l, S r) { return S{ l.sum + r.sum, l.len + r.len }; }

S e() { return S{ 0, 0 }; }

// Assigning f to all x.len elements makes the sum f * x.len.
S mapping(F f, S x) {
	if (f == NO_ASSIGN) return x;
	return S{ f * x.len, x.len };
}

// f is the newer tag: it overrides g unless it is "no assignment".
F composition(F f, F g) {
	if (f == NO_ASSIGN) return g;
	return f;
}

F id() { return NO_ASSIGN; }

区间乘、区间加&区间和与区间平方和

typedef long long ll;  // same typedef as the other snippets; repeating it is legal

// Segment information: sum, sum of squares, and number of elements.
struct S {
	ll sum, sq, len;
};

// Tag: every element v becomes mul * v + add.
struct F {
	ll mul, add;
};

S op(S l, S r) { return S{ l.sum + r.sum , l.sq + r.sq, l.len + r.len }; }

S e() { return S{ 0, 0, 0 }; }

// With v -> mul * v + add:
//   sum' = mul * sum + add * len
//   sq'  = sum over v of (mul * v + add)^2
//        = mul^2 * sq + 2 * mul * add * sum + add^2 * len
// (the last term previously carried a spurious extra factor mul).
S mapping(F f, S x) {
	return S{ x.sum * f.mul + x.len * f.add,
	          f.mul * f.mul * x.sq + 2 * f.mul * x.sum * f.add + f.add * f.add * x.len,
	          x.len };
}

// Apply g first, then f: v -> f.mul * (g.mul * v + g.add) + f.add.
F composition(F f, F g) { return F{ f.mul * g.mul , f.mul * g.add + f.add }; }

F id() { return F{ 1, 0 }; }

二分线段树

typedef long long ll;
const int N = 5e5;
int n, m, k, a[N + 7], cnt[N + 7], limit = 0;
typedef int S;

// Monoid for the search: maximum remaining capacity, identity 0.
S op(S l, S r) { return l < r ? r : l; }
S e() {	return 0; }

// Predicate for max_right: the prefix maximum is still below the current item.
bool f(S x) { return limit > x; }

// n bins, each with capacity m and k slots. Every item goes to the first bin
// whose remaining capacity is at least its size (prints the 1-based bin, or -1).
void slove() {
	vector<S> caps;
	cin >> n >> m >> k;
	for (int i = 0; i < n; ++i) {
		caps.push_back(m);  // remaining capacity of bin i
		cnt[i] = k;         // remaining slots of bin i
	}
	segtree<int, op, e> seg(caps);
	for (int i = 0; i < n; ++i) {
		cin >> a[i];
		if (a[i] > seg.all_prod()) {
			cout << -1 << '\n';
			continue;
		}
		limit = a[i];
		// First index whose capacity is >= a[i] (all earlier ones are smaller).
		const int bin = seg.max_right(0, f);
		cout << bin + 1 << '\n';
		--cnt[bin];
		seg.set(bin, seg.get(bin) - a[i]);
		if (cnt[bin] == 0) seg.set(bin, 0);  // no slots left: close the bin
	}
}

posted @ 2022-10-15 22:54  CKB2008  阅读(38)  评论(0)  编辑  收藏  举报