[模板]线段树

线段树利用分治思想在区间上统计信息.

一颗线段树有着如下的结构:

1.线段树的每一个节点表示一段区间,保存着这段区间的左端点l,右端点r,以及该段区间的某些信息(如最值,和).

2.作为二叉树,对于每一个非叶节点,若其代表了区间[l,r],则其左儿子代表区间[l,mid],右儿子代表区间[mid+1,r](其中mid=⌊(l+r)/2⌋).

  显然,线段树的根节点代表整个统计范围[1,n].

3.对于每一个叶节点,其代表的区间长度为1.

直观感受一下:(图-<<指南>>)

 会发现,对于建立在区间[1,N]上的线段树,如果把最后一层补全会使该层有N~2N-1(不确定具体值,但一定是O(N))个节点,而对于这样的一颗二叉树,其高度为O(logN).

由此可知这个二叉树中会有O(N)个节点,实践中往往需要4*N的空间来存储才能保证足够.

基于二叉树结构,线段树可以方便地上下传递,整合信息,这里的信息必须是具有"结合律"的.

 

线段树分为两大类:

①单点修改+区间查询型

这种线段树支持的操作有:

1.建树O(N)

利用数组存储一颗二叉树,回忆一下手写堆是怎么做的:

对于节点p,其左儿子表示为p*2,右儿子表示为p*2+1.

struct ST{      // Segment Tree
    int l, r, big;
}t[4 * N + 10];
int a[N + 10];    // 需要统计的数据区间

void build(int p, int l, int r){
    t[p].l = l, t[p].r = r;
    if(l == r){
        t[p].big = a[l];
        return;
    }
    int mid = t[p].l + t[p].r >> 1;
    build(p * 2, l, mid), build(p * 2 + 1, mid + 1, r);
    // 下面维护所需的区间信息, 这里以区间最大值为例
    t[p].big = max(t[p * 2].big, t[p * 2 + 1].big);
}

// 调用一次build(1,1,n)即可建树

观察这个build,会发现它对时间是零浪费的(递归的下一个节点总是未遍历过的),因此时间复杂度为O(N).

此后,便可以用t[p]表示线段树的p号节点并访问其区间信息.

 

2.单点修改O(logN)

(以维护区间最大值为例)

假设现在有如下线段树(圆圈表示节点,其中的数字表示该节点代表的区间元素的最大值):

 

 现在需要把最左下角节点(因为是叶节点,它对应原始数据中的一个元素)的数据改为10,会发现需要依次更新其父节点"③⑤⑨"为"⑩".

 这个过程花费O(logN)时间.

不过需要先从根节点出发,找到需要修改的位置后再执行上述操作,这个过程也是O(logN)的.

void change(int p, int x, int v){
    if(t[p].l == t[p].r){
        t[p].big = v;
        return;
    }
    int mid = t[p].l + t[p].r >> 1;
    if(x <= mid) change(p * 2, x, v);
    else change(p * 2 + 1, x, v);
    t[p].big = max(t[p * 2].big, t[p * 2 + 1].big);
}
// 调用change(1, x, v)将位于原始数据区间中位置为x的线段树节点信息更改为v

 

3.区间查询O(logN)

(仍是区间最大值)

想要获取给定区间[l,r]的信息,大多数情况下不存在某个线段树节点刚好存储着[l,r]的信息,因此需要整合多个节点的信息.

由于线段树的二分性质,总是可以用若干个节点不重不漏地表示范围内的任意区间.只需要进行如下操作:

检查区间[a,b]:

  若被[l,r]包含,直接返回此区间(节点)的信息.

  若与[l,r]不沾边,舍弃.(对于求区间最大值来说,实现方法是返回一个极小的值)

  否则,把它从中间一刀两断为[a,mid],[mid+1,b]:

    若mid>=l,那么递归地检查区间[a,mid]并整合信息.

    若mid<r,那么递归地检查区间[mid+1,r]并整合信息.

返回整合后的信息.

这样的分治花费O(logN),递归终点总是检查区间被[l,r]包含的情况.

对于不同的区间信息,这里的整合有不同的方式,这里是查询区间最大值的实现,其中舍弃区间通过返回0实现:

int ask(int p, int l, int r){
    if(t[p].r <= r && t[p].l >= l) return t[p].big;
    int ret = 0, mid = t[p].l + t[p].r >> 1;
    if(mid >= l) ret = max(ret, ask(p * 2, l, r));
    if(mid < r) ret = max(ret, ask(p * 2 + 1, l, r));
    return ret;
}
// 调用ask(1, l, r)以查询区间[l, r]的最大值

 

现在就可以构造出一颗支持单点修改,查询区间最大值的线段树了,这里是模板题:

I Hate It

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

struct ST{
    int l, r, big;
}t[800010];
int n, m, a[200010];

void build(int p, int l, int r){
    t[p].l = l, t[p].r = r;
    if(l == r){
        t[p].big = a[l];
        return;
    }
    int mid = t[p].l + t[p].r >> 1;
    build(p * 2, l, mid), build(p * 2 + 1, mid + 1, r);
    t[p].big = max(t[p * 2].big, t[p * 2 + 1].big);
}
void change(int p, int x, int v){
    if(t[p].l == t[p].r){
        t[p].big = v;
        return;
    }
    int mid = t[p].l + t[p].r >> 1;
    if(x <= mid) change(p * 2, x, v);
    else change(p * 2 + 1, x, v);
    t[p].big = max(t[p * 2].big, t[p * 2 + 1].big);
}
int ask(int p, int l, int r){
    if(t[p].r <= r && t[p].l >= l) return t[p].big;
    int ret = 0, mid = t[p].l + t[p].r >> 1;
    if(mid >= l) ret = max(ret, ask(p * 2, l, r));
    if(mid < r) ret = max(ret, ask(p * 2 + 1, l, r));
    return ret;
}

void solve(){
    for(int i = 1; i <= n; i++) scanf("%d", a + i);
    build(1, 1, n);
    while(m--){
        char ch;
        int x, y;
        cin >> ch;
        scanf("%d%d", &x, &y);
        if(ch == 'Q') printf("%d\n", ask(1, x, y));
        else change(1, x, y);
    }
}


int main(){
    while(scanf("%d%d", &n, &m) != EOF) solve();

    return 0;
}
单点修改,区间查询最大值

 

在这个模板的基础上,对线段树的每种操作都稍加修改可以得到支持单点修改,查询区间和的线段树,总共只改了不到十行.

敌兵布阵

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
using namespace std;

struct ST {
    int l, r, sum;
} t[200010];
int n, a[50010];

void build(int p, int l, int r) {
    t[p].l = l, t[p].r = r;
    if (l == r) {
        t[p].sum = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(p * 2, l, mid), build(p * 2 + 1, mid + 1, r);
    t[p].sum = t[p * 2].sum + t[p * 2 + 1].sum;
}
void change(int p, int x, int v) {
    if (t[p].l == t[p].r) {
        t[p].sum += v;
        return;
    }
    int mid = t[p].l + t[p].r >> 1;
    if (x <= mid)
        change(p * 2, x, v);
    else
        change(p * 2 + 1, x, v);
    t[p].sum = t[p * 2].sum + t[p * 2 + 1].sum;
}
int ask(int p, int l, int r) {
    if (l <= t[p].l && r >= t[p].r) return t[p].sum;
    int ret = 0;
    int mid = t[p].l + t[p].r >> 1;
    if (mid >= l) ret += ask(p * 2, l, r);
    if (mid < r) ret += ask(p * 2 + 1, l, r);
    return ret;
}

void solve() {
    string s;
    scanf("%d", &n);
    if (n == 0) {
        cin >> s;
        return;
    }
    for (int i = 1; i <= n; i++) scanf("%d", a + i);
    for (int i = 1; i <= n * 4; i++) t[i].l = t[i].r = t[i].sum = 0;
    build(1, 1, n);
    while (cin >> s && s[0] != 'E') {
        int x, y;
        scanf("%d%d", &x, &y);
        if (s[0] == 'Q')
            printf("%d\n", ask(1, x, y));
        else if (s[0] == 'A')
            change(1, x, y);
        else
            change(1, x, -y);
    }
    // for(int i = 1; i <= 100; i++) if(t[i].l == 5 && t[i].r == 5)
    // printf("!!!%d\n", t[i].sum);
}

int main() {
    // freopen("data.in", "r", stdin);
    // freopen("data.out", "w", stdout);
    int t;
    scanf("%d", &t);
    for (int i = 1; i <= t; i++) {
        printf("Case %d:\n", i);
        solve();
    }

    return 0;
}
单点修改,区间查询和

 

②区间修改+区间查询型(使用懒标记)

由于区间的大小可以限制为1,所以完全可以取代前一种线段树,但是实现起来稍微多了点东西.

下列操作复杂度同上,实现参考(照搬)了<<指南>>,将会完成一个同时维护了区间和,区间最大值,支持区间操作的线段树.

1.建树

struct ST{
    int l, r, big, sum;
    int tag;
    #define l(x) st[x].l
    #define r(x) st[x].r
    #define big(x) st[x].big
    #define sum(x) st[x].sum
    #define tag(x) st[x].tag
}st[400010];        // 4 * N
int n, q;

void build(int p, int l, int r){
    l(p) = l, r(p) = r;
    if(l == r) {sum(p) = 0; big(p) = 0; return;}
    int mid = l + r >> 1;
    build(p * 2, l, mid);
    build(p * 2 + 1, mid + 1, r);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
    big(p) = max(big(p * 2), big(p * 2 + 1)); 
}

 

2.区间修改+传递懒标记

void spread(int p){
    if(!tag(p)) return;
    big(p * 2) += tag(p), big(p * 2 + 1) += tag(p);
    sum(p * 2) += tag(p) * (r(p * 2) - l(p * 2) + 1);
    sum(p * 2 + 1) += tag(p) * (r(p * 2 + 1) - l (p * 2 + 1) + 1);
    tag(p * 2) += tag(p), tag(p * 2 + 1) += tag(p);
    tag(p) = 0;
}
void change(int p, int l, int r, int x){
    if(l <= l(p) && r >= r(p)) {
        big(p) += x;
        sum(p) += x * (r(p) - l(p) + 1);
        tag(p) += x;
        return;
    }
    spread(p);
    int mid = l(p) + r(p) >> 1;
    if(l <= mid) change(p * 2, l, r, x);
    if(r > mid) change(p * 2 + 1, l, r, x);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
    big(p) = max(big(p * 2), big(p * 2 + 1));
}

 

3.区间查询

int ask_big(int p, int l, int r){
    if(l <= l(p) && r >= r(p)) return big(p);
    spread(p);
    int mid = l(p) + r(p) >> 1;
    int ret = 0;
    if(l <= mid) ret = max(ret, ask_big(p * 2, l, r));
    if(r > mid) ret = max(ret, ask_big(p * 2 + 1, l, r));
    return ret;
}
int ask_sum(int p, int l, int r){
    if(l <= l(p) && r >= r(p)) return sum(p);
    spread(p);
    int mid = l(p) + r(p) >> 1;
    int ret = 0;
    if(l <= mid) ret += ask_sum(p * 2, l, r);
    if(r > mid) ret += ask_sum(p * 2 + 1, l, r);
    return ret;
}

 

关于懒标记:在上面的实现里,每当调用spread(p)传递懒标记后,节点p及其子节点p*2,p*2+1的值均成为最新的(即正确的)状态,并且p的懒标记被利用后合理地清除了,而其子节点p*2,p*2+1的懒标记则仍然存在并且根据p的懒标记进行了更新.

这意味着懒标记最终会被传递到叶子节点上,注意叶子节点具不具有懒标记并不能说明叶子节点是否最新(从而正确),而叶子节点的懒标记是不会传递下去的(因为相关的函数在递归到叶子时一定会在spread之前return).

并且,在上面的change和ask函数里,发现有直接使用标记修改值的操作而不是调用spread.这是因为由于递归,当函数对这个节点调用时,可以保证这个节点自身的状态(在change中,是更改前;在aks中,是现在)是最新的,只需要将其标记向下传递即可.


 

最后放一道使用了稍微复杂一点的懒标记的题目.

这里维护了一个可以同时进行区间增加和区间数乘操作的线段树.

P3373 【模板】线段树 2

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include <set>
using namespace std;

struct ST{
    int l, r;
    long long sum, tagA, tagM = 1;
    #define l(x) st[x].l
    #define r(x) st[x].r
    #define sum(x) st[x].sum
    #define tagA(x) st[x].tagA
    #define tagM(x) st[x].tagM
}st[400010];        // 4 * N
int n, q, M;
long long a[100010];

void spread(int p){
    if(!tagA(p) && tagM(p) == 1) return;
    sum(p * 2) = (sum(p * 2) * tagM(p) % M + tagA(p) * (r(p * 2) - l(p * 2) + 1) % M) % M;
    sum(p * 2 + 1) = (sum(p * 2 + 1) * tagM(p) % M + tagA(p) * (r(p * 2 + 1) - l(p * 2 + 1) + 1) % M) % M;
    tagA(p * 2) = (tagA(p * 2) * tagM(p) % M + tagA(p)) % M, tagM(p * 2) = tagM(p * 2) * tagM(p) % M;
    tagA(p * 2 + 1) = (tagA(p * 2 + 1) * tagM(p) % M + tagA(p)) % M, tagM(p * 2 + 1) = tagM(p * 2 + 1) * tagM(p) % M;
    tagA(p) = 0, tagM(p) = 1;
}
void add(int p, int l, int r, int x){
    if(l <= l(p) && r >= r(p)) {
        sum(p) = (sum(p) + x * (r(p) - l(p) + 1)) % M;
        tagA(p) = (tagA(p) + x) % M;
        return;
    }
    spread(p);
    int mid = l(p) + r(p) >> 1;
    if(l <= mid) add(p * 2, l, r, x);
    if(r > mid) add(p * 2 + 1, l, r, x);
    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % M;
}
void mul(int p, int l, int r, int x){
    if(l <= l(p) && r >= r(p)){
        sum(p) = (sum(p) * x) % M;
        tagM(p) = (tagM(p) * x) % M;
        tagA(p) = (tagA(p) * x) % M;
        return;
    }
    spread(p);
    int mid = l(p) + r(p) >> 1;
    if(l <= mid) mul(p * 2, l, r, x);
    if(r > mid) mul(p * 2 + 1, l, r, x);
    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % M;
}
void build(int p, int l, int r){
    l(p) = l, r(p) = r;
    if(l == r) {sum(p) = a[l];  return;}
    int mid = l + r >> 1;
    build(p * 2, l, mid);
    build(p * 2 + 1, mid + 1, r);
    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % M;
}
int ask(int p, int l, int r){
    if(l <= l(p) && r >= r(p)) return sum(p);
    spread(p);
    int mid = l(p) + r(p) >> 1;
    int ret = 0;
    if(l <= mid) ret += ask(p * 2, l, r);
    if(r > mid) ret += ask(p * 2 + 1, l, r);
    return ret % M;
}

int main(){
    scanf("%d%d%d", &n, &q, &M);
    for(int i = 1; i <= n; i++) scanf("%lld", a + i);
    build(1, 1, n);

    while(q--){
        int opr, x, y, k;
        scanf("%d%d%d", &opr, &x, &y);
        if(opr == 1){
            scanf("%d", &k);
            mul(1, x, y, k);
        }else if(opr == 2){
            scanf("%d", &k);
            add(1, x, y, k);
        }else printf("%d\n", ask(1, x, y) % M);
    }

    return 0;
}
P3373

 

posted @ 2021-05-21 20:03  goverclock  阅读(59)  评论(0编辑  收藏  举报