[知识点]线段树

// 此博文为迁移而来,写于2015年3月30日,不代表本人现在的观点与看法。原始地址:http://blog.sina.com.cn/s/blog_6022c4720102vw6j.html

 

UPDATE(20190324):大更新。描述、结构、代码各种调整与修复。

UPDATE(20190131):修复一些小问题。

UPDATE(20180801):修正各种问题。

UPDATE(20151103):线段树所有操作代码更新。

 

一、前言

  一道题目:给定一个序列 a[n] 和 q 个操作,操作可以为在 [l, r] 范围内增加 x,或是询问第 i 个结点当前的值为多少。这不水题嘛!直接模拟就可以了。、

  但是如果 n <= 100000 呢?询问次数超过 100000呢?O(n ^ 2) 并不能扛住。

  今天要引入的内容,线段树,就能解决这类问题。

 

二、概念

  线段树也是一种树结构,而且是二叉树。不一样的是,它的每个节点保存的是一个线段。先看看他的构造:

  我们发现,线段树虽然以树结构展现出来,但是他每个节点索要维护的东西和一般的树结构问题差别还是比较大的。

 

三、构建

struct tree {
    int v, mi, mx, f;
} t[MAXN << 1];

  

  每个节点表示有 v, mi, mx, f 四个变量,分别记录该节点表示线段值的和,最小值,最大值和 lazy 标识。变量设定根据需题目要求而变化。

  线段树的根节点所表示的线段是 [1, n],对于每一个节点均表示一个线段 [l, r] ,其左右儿子节点分别为该节点从中二分的两个子线段 [l, mid], [mid + 1, r],其中 mid = (l + r) / 2。

  如上图所示,当 n = 8 时,mid = (1 + 8) / 2 = 4 (向下取整),所以子节点所代表的线段分别是 [1, 4] 和 [5, 8]。从根节点一次往下建造,直到该节点的线段为一个点,即叶子节点。

  如何给每一个节点标号?根节点为1号,其两个儿子节点分别为2, 3号;2号节点的子节点为4, 5号,3号节点的子节点为6, 7号……易得对于o号节点,其子节点标号为o * 2, o * 2 + 1。由于调用较为频繁,可以直接在将它们分别define为ls, rs(左右儿子)。为了加快运行速度(装逼),代码中采用的是位运算(o >> 1, o >> 1 | 1)。

 
1 void build(int o, int l, int r) {
2     if (l == r) { 
3         t[o] = (tree) {w[l], w[l], w[l]}; 
4         return; 
5     }
6     int m = (l + r) >> 1;
7     build(ls, l, m), build(rs, m + 1, r);
8     t[o] = (Tree){t[ls].v + t[rs].v, max(t[ls].mx, t[rs].mx), min(t[ls].mi, t[rs].mi)};
9 }

 

从根节点到叶子节点递归。叶子节点表示的线段为单独的一个数,权值即为读入的数据。根据题目需要,维护每一条线段的权值和,最大值,最小值等。

 

四、操作类型

  线段树支持各类对单点或区间的修改和询问操作。如:

  1. 单点加减:给定 x, k,将 a[x] 加上 k

  2. 区间加减:给定 l, r, k,将 a[l~r] 全部加上 k

  3. 单点查询:给定 x,查询 a[x] 的值

  4. 区间查询:给定 l, r,查询 a[l~r] 的元素和

 

<1> 单点加减 / 单点修改

 1 void upds(int o, int l, int r) {
 2     if (l == r) {
 3         t[o].v += k, t[o].mi += k, t[o].mx += k;
 4         return;
 5     }
 6     int m = (l + r) >> 1;
 7     if (x <= m) upds(ls, l, m);
 8     else upds(rs, m + 1, r);
 9     t[o] = (Tree){t[ls].v + t[rs].v, max(t[ls].mx, t[rs].mx), min(t[ls].mi, t[rs].mi)};
10 }

 

从根节点向下递归到叶子节点,维护所有包括 x 的线段节点,和构建过程类似。

 

<2> 单点查询

1 int ques(int o, int l, int r) {
2     int m = (l + r) >> 1;
3     if (l == r) return t[o].v;
4     return q <= m ? ques(ls, l, m) : ques(rs, m + 1, r);
5 }

 

<3> 区间加减

 1 void updm(int o, int l, int r) {
 2     if (t[o].f) push(o, r - l + 1); // 这是什么?
 3     if (ql <= l && r <= qr) {
 4         t[o].f += x, t[o].mi += x, t[o].mx += x, t[o].v += x * (r - l + 1);
 5         return;
 6     }
 7     int m = (l + r) >> 1;
 8     if (ql <= m) updm(ls, l, m); 
 9     if (qr >= m + 1) updm(rs, m + 1, r);
10     t[o] = (Tree) {t[ls].v + t[rs].v, max(t[ls].v, t[rs].v), min(t[ls].v, t[rs].v)};
11 }

 

  与单点操作不同的是,所询问区间可能为线段树上多个线段的组合,需要分情况,在递归回溯时再综合处理,进行维护。

  如,现在需要对区间[3, 5]进行修改,从根节点开始递归。

  对于根节点,mid = (1 + 7) / 2 = 4,而3 <= 4, 5 <= 5,则两个子区间均包含所修改的区间,故需分别向下递归。

  对于区间[1, 4],包含所修改区间的[3, 4]。mid = (1 + 4) / 2 = 2, 3 >= 2, 3 <= 4, 则全部包含于右儿子区间[3, 4],只需向右递归。

  而对于区间[5, 7],同理易得仅有区间[5, 5]包含所修改区间。

  故最后需要修改的区间分别为[3, 4],[5, 5]。修改后,向上回溯维护其他区间。

  注意到代码中的push函数,这是线段树本身的一个优化,可以大幅提高区间操作的效率。我们发现,对于每一个节点设置了一个变量f,被称为lazy标识。什么是 lazy 标识?让我们来讲个故事——

  秋天到了,作为庙里的一个和尚,方丈给你一个任务:扫院子里的落叶。落叶在秋天是会不断落下的,所以每次扫完之后,过不一阵子就又会有新的落叶了。这时,你可以在方丈还没有来检查你的工作时偷一点懒——让叶子先堆积在那里,直到方丈要来检查时再一起扫完。

  lazy 标识的作用便是如此。

  

  在对线段 [l, r] 进行修改操作时,可以只修改该线段节点的 lazy 标识,而不向下递归更新,在之后的查询或修改操作覆盖到该线段后再进行更新。

  如,现在对 [1, 4] 增加 5,则 t[2].f = 5。查询 [3, 3] 时,将 [1, 4] 的 lazy 标识下放至 [1, 2] 和 [3, 4];[1, 2] 的 lazy 标识保持不动,[3, 4] 的 lazy 标识继续下放至 [3, 3] 和 [3, 4]。

 
1 void push(int o, int tot) { // tot为该区间内节点个数
2     t[ls].mi += t[o].f, t[rs].mi += t[o].f;
3     t[ls].mx += t[o].f, t[rs].mx += t[o].f;
4     t[ls].v += t[o].f * (tot - tot / 2), t[ls].v += t[o].f * tot / 2;
5     t[ls].f += t[o].f, t[rs].f += t[o].f;
6     t[o].f = 0; 
7 }
 

<4> 查询区间和 / 最大值 / 最小值

1 int quem(int o, int l, int r) {
2     if (t[o].f) push(o, r - l + 1);
3     if (ql <= l && r <= qr) return t[o].v;
4     int m = (l + r) >> 1;
5     return (ql <= m ? quem(ls, l, m) : 0) + (qr >= m + 1 ? quem(rs, m + 1, r) : 0);
6 }

 

        这里只给出了区间和的代码,最大值最小值同理。

 

五、完整代码

(涵盖单点加减 / 单点查询 / 区间加减 / 区间查询权值和) 
 
 1 #include <bits/stdc++.h>
 2 using namespace std;
 3 
 4 #define MAXN 100005
 5 #define ls o << 1
 6 #define rs o << 1 | 1
 7 
 8 int n, T, w[MAXN], q, x, ql, qr, p;
 9 
10 struct tree {
11     int v, mi, mx, f;
12 } t[MAXN << 1];
13 
14 void build(int o, int l, int r) {
15     if (l == r) { 
16         t[o] = (tree){w[l], w[l], w[l]}; 
17         return; 
18     }
19     int m = (l + r) >> 1;
20     build(ls, l, m), build(rs, m + 1, r);
21     t[o] = (tree){t[ls].v + t[rs].v, max(t[ls].mx, t[rs].mx), min(t[ls].mi, t[rs].mi)};
22 }
23 
24 void upds(int o, int l, int r) {
25     if (l == r) {
26         t[o].v += x, t[o].mi += x, t[o].mx += x;
27         return;
28     }
29     int m = (l + r) >> 1;
30     if (q <= m) upds(ls, l, m);
31     else upds(rs, m + 1, r);
32     t[o] = (tree){t[ls].v + t[rs].v, max(t[ls].mx, t[rs].mx), min(t[ls].mi, t[rs].mi)};
33 }
34 
35 void push(int o, int tot) {
36     t[ls].mi += t[o].f, t[rs].mi += t[o].f;
37     t[ls].mx += t[o].f, t[rs].mx += t[o].f;
38     t[ls].v += t[o].f * (tot - tot / 2), t[ls].v += t[o].f * tot / 2;
39     t[ls].f += t[o].f, t[rs].f += t[o].f;
40     t[o].f = 0; 
41 }
42  
43 void updm(int o, int l, int r) {
44     if (t[o].f) push(o, r - l + 1); 
45     if (ql <= l && r <= qr) {
46         t[o].f += x, t[o].mi += x, t[o].mx += x, t[o].v += x * (r - l + 1);
47         return;
48     }
49     int m = (l + r) >> 1;
50     if (ql <= m) updm(ls, l, m); 
51     if (qr >= m + 1) updm(rs, m + 1, r);
52     t[o] = (tree){t[ls].v + t[rs].v, max(t[ls].mx, t[rs].mx), min(t[ls].mi, t[rs].mi)};
53 }
54 
55 int ques(int o, int l, int r) {
56     int m = (l + r) >> 1;
57     if (l == r) return t[o].v;
58     return q <= m ? ques(ls, l, m) : ques(rs, m + 1, r);
59 }
60 
61 int quem(int o, int l, int r) {
62     if (t[o].f) push(o, r - l + 1);
63     if (ql <= l && r <= qr) return t[o].v;
64     int m = (l + r) >> 1;
65     return (ql <= m ? quem(ls, l, m) : 0) + (qr >= m + 1 ? quem(rs, m + 1, r) : 0);
66 }
67 
68 int main() {
69     scanf("%d %d", &n, &T);
70     for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
71     build(1, 1, n);
72     for (int i = 1; i <= T; i++) {
73         scanf("%d", &p);
74         if (p == 1) scanf("%d %d", &q, &x), upds(1, 1, n);
75         else if (p == 2) scanf("%d %d %d", &ql, &qr, &x), updm(1, 1, n);
76         else if (p == 3) scanf("%d", &q), printf("%d\n", ques(1, 1, n));
77         else scanf("%d %d", &ql, &qr), printf("%d\n", quem(1, 1, n));
78     }
79     return 0;
80 }

 

posted @ 2015-07-26 11:48  jinkun113  阅读(289)  评论(6编辑  收藏  举报