线段树 Segment Tree

Question

已知一个数列,你需要进行下面两种操作:

  1. 将某区间每一个数加上 k
  2. 求出某区间每一个数的和。

Input

第一行两个整数 n,m(1n,m105), 分别表示该数列数字的个数和操作的总个数;
第二行包含 n 个用空格分隔的整数,其中第 i 个数字表示数列第 i 项的初始值;
接下来 m 行每行包含 34 个整数,表示一个操作,具体如下:

  1. 1 x y k:将区间 [x,y] 内每个数加上 k
  2. 2 x y:输出区间 [x,y] 内每个数的和。

Output

输出包含若干行整数,即为所有操作 2 的结果。

Example

5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4
11
8
20



这是一个被称为“区间和”的线段树模板题。 数据范围不是很大时,我们本能地考虑直接算法。
  1. 遍历区间 [x,y],进行元素的更新,时间复杂度 O(mn) .
  2. 遍历区间 [x,y],对元素进行加和,时间复杂度 O(mn) .
点击查看代码
#include <bits/stdc++.h>
using namespace std;
int n, m, x, y, k, f, s, a[500000];

int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    while (m--)
    {
        cin >> f;
        if (f == 1)
        {
            cin >> x >> y >> k;
            for (int i = x; i <= y; i++)
                a[i] += k;
        }
        if (f == 2)
        {
            s = 0;
            cin >> x >> y;
            for (int i = x; i <= y; i++)
                s += a[i];
            cout << s << endl;
        }
    }
}

显然这种形式对于有大量查询操作的情况下,耗时会很严重。

如果能预处理存储区间信息,更新和查询时不再是对每个数一一处理,而是通过分治将时间复杂度降低为 O(mlogn),问题便得以解决。实现这种思路的工具被称作线段树。

线段树是一种二叉树,它的节点存储以下信息:

  1. 区间范围;
  2. 区间和;
  3. 延迟标记。

我们一一分析。



对于区间 [1,5],根据左闭右开原则划分为 [1,3][4,5]
对于区间 [1,3],可以划分为 [1,2][3,3]
对于区间 [1,2],可以划分为 [1,1][2,2]
对于区间 [4,5],可以划分为 [4,4][5,5]
这样我们得到了一个二叉树的结构,它的每个节点存储了区间范围的信息。

对于区间 [1,1],由于它的左端点与右端点相等,区间和即为左端点元素的值;
对于区间 [1,2],它的区间和等于 [1,1] 的区间和加 [2,2] 的区间和;
对于区间 [1,3],它的区间和等于 [1,2] 的区间和加 [3,3] 的区间和;
这样我们得到了每一个区间对应的区间和,线段树建立完成:

void build(int node, int left, int right)
{
    if (left == right)
    {
        sum[node] = a[left];
        return;
    }
    build(node * 2, left, mid);
    build(node * 2 + 1, mid + 1, right);
    sum[node] = sum[node * 2] + sum[node * 2 + 1];
}

延迟标记的作用体现在记录前一次更新的数据,我们不需要每次元素更新时都对线段树搜索到底。
比如要求更新 [3,5] 的值,当更新到 [4,5] 就不必再进行下去了,因为我们知道 [4,4][5,5] 一定也会被更新。这时将更新信息标记在 [4,5] 上,下次更新或查询时再加上标记即可。这个方法被称作延迟标记。
将延迟标记下传给子节点:

void pushdown(int node, int left, int right)
{
    lazy[node * 2] += lazy[node];
    lazy[node * 2 + 1] += lazy[node];
    sum[node * 2] += lazy[node] * (mid - left + 1);
    sum[node * 2 + 1] += lazy[node] * (right - mid);
    lazy[node] = 0;
}

更新区间:

void update(int node, int left, int right)
{
    if (left >= x && right <= y)
    {
        lazy[node] += k;
        sum[node] += k * (right - left + 1);
        return;
    }
    if (lazy[node])
        pushdown(node, left, right);
    int mid = (left + right) / 2;
    if (mid >= x)
        update(node * 2, left, mid);
    if (mid < y)
        update(node * 2 + 1, mid + 1, right);
    sum[node] = sum[node * 2] + sum[node * 2 + 1];
}

查询区间:

int query(int node, int left, int right)
{
    if (left >= x && right <= y)
        return sum[node];
    if (left > y || right < x)
        return 0;
    if (lazy[node])
        pushdown(node, left, right);
    int mid = (left + right) / 2;
    return query(node * 2, left, mid) + query(node * 2 + 1, mid + 1, right);
}



Answer

#include <bits/stdc++.h>
#define mid (left + right >> 1)
#define L(x) (x << 1)
#define R(x) (x << 1 | 1)
using namespace std;
typedef long long ll;
const int N = 500000;
ll n, m, x, y, k, f, a[N], sum[N], lazy[N];

inline int read()
{
    int s = 1, v = 0;
    char c = getchar();
    while (c < '0' || c > '9')
    {
        if (c == '-')
            s = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        v = (v << 3) + (v << 1) + (c ^ 48);
        c = getchar();
    }
    return s * v;
}

void build(ll node, ll left, ll right)
{
    if (left == right)
    {
        sum[node] = a[left];
        return;
    }
    build(L(node), left, mid);
    build(R(node), mid + 1, right);
    sum[node] = sum[L(node)] + sum[R(node)];
}

void pushdown(ll node, ll left, ll right)
{
    lazy[L(node)] += lazy[node];
    lazy[R(node)] += lazy[node];
    sum[L(node)] += lazy[node] * (mid - left + 1);
    sum[R(node)] += lazy[node] * (right - mid);
    lazy[node] = 0;
}

void update(ll node, ll left, ll right)
{
    if (left >= x && right <= y)
    {
        lazy[node] += k;
        sum[node] += k * (right - left + 1);
        return;
    }
    if (lazy[node])
        pushdown(node, left, right);
    if (mid >= x)
        update(L(node), left, mid);
    if (mid < y)
        update(R(node), mid + 1, right);
    sum[node] = sum[L(node)] + sum[R(node)];
}

ll query(ll node, ll left, ll right)
{
    if (left >= x && right <= y)
        return sum[node];
    if (left > y || right < x)
        return 0;
    if (lazy[node])
        pushdown(node, left, right);
    return query(L(node), left, mid) + query(R(node), mid + 1, right);
}

int main()
{
    n = read(), m = read();
    for (int i = 1; i <= n; i++)
        a[i] = read();
    build(1, 1, n);
    while (m--)
    {
        f = read();
        if (f == 1)
        {
            x = read(), y = read(), k = read();
            update(1, 1, n);
        }
        if (f == 2)
        {
            x = read(), y = read();
            cout << query(1, 1, n) << endl;
        }
    }
}
posted @   rzk_零月  阅读(61)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示