线段树详解与实现
此篇文章用于记录《玩转数据结构》课程的学习笔记
什么是线段树
线段树也被称为区间树,英文名为Segment Tree
或者Interval tree
,是一种高级的数据结构。这种数据结构更多出现在竞赛中,在常见的本科数据结构教材里没有介绍这种数据结构。但是,在面试中却有可能碰到和线段树相关的问题。那么为什么会产生线段树这种数据结构,线段树到底是为了解决什么样的一种问题呢?
其实这里的线段
可以理解为区间
,线段树就是为了解决区间问题的。
有一个很经典的线段树问题是:区间染色。
假设有一面墙,长度为 n,每次选择一段墙进行染色。
在区间染色的过程中,每次选择一段区间进行染色,这时新的颜色可能会覆盖之前的颜色。
最后的问题是:
- 在经过 m 次染色操作后,我们可以在整个区间看见多少种颜色?
更加普遍的说法是:
- 在经过 m 次染色操作后,我们可以在区间
[i, j]
内看见多少种颜色?
由于第一个问题是第二个问题的一个特例,我们采用第二种问题来思考解决方法。
从上面可以看出,我们对于区间,有 2 种操作,分别是染色操作和查询区间的颜色,使用更加一般的说法,染色操作就是更新区间,查询区间的颜色就是查询区间。
这类问题里面,更加常见的的是区间查询:一个数组存放的不再是颜色,而是具体的数字,查询某个区间[i, j]
的统计值。这里的统计值是指:区间内最大值、最小值、或者这个区间的数字和。
比如:
- 查询 2018 年注册的用户中消费最高的用户
- 查询 2018 年注册的用户中消费最低的用户
注意上面两种情况都是动态查询,我们查询的消费数据不只是 2018 的消费数据。
如果我们想查询 2018 年中消费最高的用户,那么 2018 年的数据已经固定了,我们直接在这一年的数据中进行统计分析就行了。
但是一个 2018 年注册的用户,在 2019 年、2020 年都可能会有消费。我们实际上查询的是:2018 年注册的用户中,到现在为止,消费最高的用户。
这种情况下,数据是在动态变化的, 也就是说:2017 年注册的用户中,每个用户的消费额是会更新的,这就对应到更新区间的操作。
此时线段树就是一种好的选择。
按照通常的思路,使用数组存储上述的元素是比较好的,思考上面两个操作的时间复杂度:
- 更新区间:每次根据需要更新的区间的首尾索引,逐个遍历区间中的元素进行更新,时间复杂度为
O(n)
。 - 查询区间:每次根据需要更新的区间的首尾索引,逐个遍历区间种的元素进行查询,时间复杂度为
O(n)
。
两个操作的时间复杂度均为O(n)
,对于需要多次动态使用的场景来说,性能可能是不够好的。
在这类问题中,我们关注的是一个个区间内的元素的情况,线段树就有用武之地了,线段树的优点就是把两个操作的时间复杂度降到了O(logn)
。
操作 | 使用数组 | 使用线段树 |
---|---|---|
更新 | O(n) | O(logn) |
查询 | O(n) | O(logn) |
这里提一点,如果你看到一个算法的时间复杂度是O(logn)
,那这个算法大多与二叉树、分治算法有关。
这里也不例外,线段树就是使用二叉树来实现的。
那么一个区间是如何被构建成为一个二叉树的?
对于一个数组 A,如下所示:
对应的线段树就是:
二叉树中每个非叶子节点表示的是区间内元素的统计值,叶子节点存储的就是元素本身。上面说了统计值是指:区间内最大值、最小值、或者这个区间的数字和。比如你要求区间的最大值,每个每个节点存储的就是这个区间内元素的最大值。像下面这样:
假设你要查询[4,7]
区间内的最大值,那么不用查到叶子节点,而是查到A[4, 7]
这个节点就行了。
当然,并不是所有的区间都恰好落在一个节点,比如你要求[2, 5]
区间内的最大值。那么就要分别找到A[2, 3]
和A[4,5]
的最大值,再进行比较。
可以看出,线段树的查询区间操作不需要遍历区间中的每一个元素,只要找到对应的树节点就可以返回,时间复杂度为O(logn)
。
总结
从更加抽象的角度来讲,线段树的使用场景就是,对于给定区间,进行更新区间和查询区间操作:
- 更新区间:更新区间中的一个元素或者一个区间的值。
- 查询区间:查询一个区间
[i, j]
的最大值、最小值、或者区间的数字和。
注意,在大多数情况下,我们是不考虑区间里添加元素和删除元素的,我们假设区间的大小是固定的。
线段树的表示
树的一般表示方法是链式存储,每个节点有两个指针,一个指向左孩子,一个指向右孩子。但是满二叉树,和完全二叉树,除了使用链表法来存储,还可以使用数组来表示。
在满二叉树和完全二叉树的数组表示中,假设一个节点的所以是i
,那么左孩子的索引就是$2 \times i +1$,右孩子的索引就是$2 \times i +2$。
那么线段树是不是满二叉树或者完全二叉树呢? 能不能使用数组来表示呢?
在上面的例子中,我们的二叉树恰好是一棵满二叉树,这是因为我们的数组大小恰好是 8,也就是$2^3$,只有数组大小恰好是 2 的 n 次幂,所对应的线段树才会是一个满二叉树。在大部分情况下,线段树并不是一个满二叉树。如果一个数组的大小是 10 ,对应的线段树如下图所示。
所以线段树不是满二叉树,也不是完全二叉树。但实际上:线段树是平衡二叉树,是可以保证O(logn)
的时间复杂度的,这里就不证明了。
其实平衡二叉树可以看作特殊的满二叉树,进而使用数组来表示。
下一步就是确定:对于大小为 n 的数组,需要多大的空间来存储线段树。
首先说一个结论,对于高为 h 层的满二叉树,一共有$2{h}-1$个节点,而最后一层有$2$个节点。那么:$除了最后一层的前面所有节点数 = 总的节点数 - 最后一层的节点数 = (2^{h}-1) - 2^{h-1} \approx 2^{h-1}-1$。也就是:前面所有层的节点数之和等于最后一层的节点数减 1。
那么线段树所需要的节点数量,分两种情况来讨论:
-
如果 n 恰好是 2 的 k 次幂,由于线段树最后一层的叶子节点存储的是数组元素本身,最后一层的节点数就是 n,而根据上面的结论,前面所有层的节点数之和是$n-1$,那么总节点数就是$2 \times n-1$。为了方便起见,分配$ 2 \times n$的空间。
-
如果 n 不是 2 的 k 次幂,最坏的情况就是$n=2^k+1$,那么有一个元素需要开辟新的一层来存储,需要$4 \times n-5$的大小。为了方便起见,我们分配$4 \times n$的空间,已经足够了。
综上,首先需要判断数组的大小是否为 $2^k$,是则使用 $2 \times n$ 的空间,否则使用的 $4 \times n$ 空间。下面线段树的实现是基于数组来实现的,不过为了简便起见,下面的实现统一使用 $4 \times n$ 空间来存储线段树。
使用数组来存储线段树,会有一定的空间浪费,但是换来的时间复杂度的降低是可以接受的。同时,我也在最后会介绍链式存储的实现方式。
线段树的实现
首先需要两个数组,其中data
存放原来的数据,tree
就是存放线段树。
基本的 API 有getSize()
:返回数组元素个数;get(int index)
:根据索引获取数据。
其中每个元素使用泛型E
表示,这是为了考虑可扩展性:如果你的数组元素不是数字,而是自定义的类,那么使用泛型就是比较好的选择。
public class SegmentTree<E> {
private E[] tree; //线段树
private E[] data; //数据
public SegmentTree(E[] arr) {
data = (E[]) new Object[arr.length];
tree = (E[]) new Object[arr.length * 4]; //大小为 4 * n
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
}
// 返回数组元素个数
public int getSize() {
return data.length;
}
// 根据索引获取数据
public E get(int index) {
if (index < 0 || index > data.length)
throw new IllegalArgumentException("Index is illegal");
return data[index];
}
}
由于把线段树看作一棵完全二叉树,应该定义两个 API,根据一个节点获取到它的左孩子和右孩子。
// 根据一个节点的索引 index,返回这个节点的左孩子的索引
private int leftChild(int index) {
return 2 * index + 1;
}
// 根据一个节点的索引 index,返回这个节点的右孩子的索引
private int rightChild(int index) {
return 2 * index + 2;
}
线段树的构建
下面考虑的就是构造线段树的每个节点。这里以求区间的最大值为例。
根节点存储的是整个[0,7]
区间的最大值,左孩子存储的是[0,3]
区间内的最大值,右孩子存储的是[4,7]
区间内的最大值。
我们要求根节点的值,首先求得左右两个孩子的值,再从左右两个孩子中取出较大的值作为根节点的值。要求父节点区间的值,需要先求孩子节点区间的值,这是递归的性质。所以可以通过递归来创建线段树。
那么这个递归的终止条件,也就是 base case,是什么呢?
- base case:如果一个节点的区间长度为1,不能再划分,也就是递归到底了,就返回这个元素本身。
明确了思路,那么我们的递归函数需要几个参数?
首先,既然是创建节点,那么需要节点在tree
数组中的索引;其次,这个节点对应的区间的左边界和右边界。
总共需要 3 个参数,写出的代码如下:
// 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
// base case:递归到叶子节点了
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
//划分区间
int mid = l + (r - l) / 2;
// 求(左孩子)左边区间的最大值
buildSegmentTree(leftTreeIndex, l, mid);
// 求(右孩子)右区间的最大值
buildSegmentTree(rightTreeIndex, mid + 1, r);
//合并左右区间,求左区间和右区间点的最大值
tree[treeIndex] = Math.max(tree[leftTreeIndex], tree[rightTreeIndex]);
}
当然这里最后一句是会报错的。因为tree
的元素类型是泛型,不支持Math.max()
函数。
还有一个问题是:如果你在这里把区间合并的逻辑写成了只取最大值,那么这个线段树就只能求某个区间的最大值,不能用于求取区间的最小值、或者区间的和,限制了线段树的应用场景。
一个更好的方法是:用户可以根据自己的业务场景,自由选择合并区间的逻辑。
要达成这个目的,我们需要创建一个接口,用户需要实现这个接口来实现自己的区间合并逻辑。
//融合器,表示如何合并两个区间的统计值
public interface Merger<E> {
// a 表示左区间的统计值,b 表示有区间的统计值
//返回整个[左区间+右区间] 的统计值
E merge(E a, E b);
}
在线段树的构造函数中,添加一个Merger
参数,并且调用buildSegmentTree()
构建线段树。
public class SegmentTree<E> {
private E[] tree; //线段树
private E[] data; //数据
private Merger<E> merger;//融合器
public SegmentTree(E[] arr, Merger<E> merger) {
this.merger = merger;
data = (E[]) new Object[arr.length];
tree = (E[]) new Object[arr.length * 4]; //大小为 4 * n
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
//构建线段树
buildSegmentTree(0, 0, data.length - 1);
}
.
.
.
}
然后,修改buildSegmentTree()
方法的最后一行。
// 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
// base case:递归到叶子节点了
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
//划分区间
int mid = l + (r - l) / 2;
// 求(左孩子)左区间的统计值
buildSegmentTree(leftTreeIndex, l, mid);
// 求(右孩子)右区间的统计值
buildSegmentTree(rightTreeIndex, mid + 1, r);
//求当前节点 [左区间+右区间] 的统计值
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
线段树的查询
我们用下面的数组构建一棵线段树,查询区间[2,5]
为例。
- 我们首先从根节点开始查询:在区间
[0,7]
中查询[2,5]
。
将根节点的区间分为左孩子[0,3]
和右孩子[4,7]
,而我们的查询的区间[2,5]
,不能其中一个孩子的区间完全包括。
于是上述问题转化为两个子问题:
- 在区间
[0,3]
中查询[2,3]
; - 在区间
[4,7]
中查询[4,5]
; - 最后将
[2,3]
和[4,5]
结果合并,得到区间[2,5]
的结果。
将[0,3]
划分为左孩子[0,1]
和右孩子[2,3]
,此时右孩子刚好和查询的区间重合,返回结果。
将[4,7]
划分为左孩子[4,5]
和右孩子[6,7]
,此时左孩子刚好和查询的区间重合,返回结果。
从上面可以看出,在线段树的区间查找过程中,不需要遍历区间中的每个元素,只需要在线段树中找到对应的所有子区间,再将这些子区间的结果合并即可。而查询所经过的节点数最多是树的高度,时间复杂度为O(logn)
。
而且上面的过程也是递归的过程。
- 递归终止条件就是:查询区间的边界和节点的边界完全重合,就返回该节点的统计值。
如果不重合,那怎么办?
这时应该分 3 种情况:
- 如果查询区间的左边界大于中间节点,那么就查询右区间
-
如果查询区间的右边界小于等于中间节点,那么就查询左区间
-
如果不属于上述两种情况,那么查询的区间就要根据中间节点拆分
递归函数的参数应该有 4 个,当前节点所在的区间的左边界和右边界,用户要查询的的区间的左边界和右边界。
代码如下:
//在以 treeIndex 为根的线段树中 [l,r] 的范围里,搜索区间 [queryL, queryR]
private E query(int treeIndex, int l, int r, int queryL, int queryR) {
if (l == queryL && r == queryR) {
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 如果左边界大于中间节点,则查询右区间
if (queryL > mid)
return query(rightTreeIndex, mid + 1, r, queryL, queryR);
// 如果右边界小于等于中间节点,则查询左区间
if (queryR <= mid)
return query(leftTreeIndex, l, mid, queryL, queryR);
// 如果上述两种情况都不是,则根据中间节点,拆分为两个查询区间
E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
//合并左右区间的查询结果
return merger.merge(leftResult, rightResult);
}
线段树的更新
如下图所示,更新A[i]=100
,那么将需要更新的索引 i
和区间的终点mid
,分为两种情况。
- 如果
i>mid
,那么索引i
落在右区间,更新右区间; - 如果
i<=mid
,那么索引i
落在左区间,更新左区间;
那么递归的终止条件是什么呢?
- 当递归到叶子节点的时候,就值更新这个节点:叶子节点就是区间长度为 1 的节点。
当更新完叶子节点后,还需要回溯,更新父节点区间的统计值。
代码如下:
//将 index 位置的值,更新为 e
public void update(int index, E e) {
if (index < 0 || index >= data.length)
throw new IllegalArgumentException("Index is illegal");
data[index] = e;
//更新线段树相应的节点
updateTree(0, 0, data.length - 1, index, e);
}
// 在以 treeIndex 为根的线段树中,更新 index 的值为 e
private void updateTree(int treeIndex, int l, int r, int index, E e) {
//递归终止条件
if (l == r) {
tree[treeIndex] = e;
return;
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (index > mid)
updateTree(rightTreeIndex, mid + 1, r, index, e);
else //index <= mid
updateTree(leftTreeIndex, l, mid, index, e);
//更新当前节点
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
完整代码
public class SegmentTree<E> {
private E[] tree; //线段树
private E[] data; //数据
private Merger<E> merger;//融合器
public SegmentTree(E[] arr, Merger<E> merger) {
this.merger = merger;
data = (E[]) new Object[arr.length];
tree = (E[]) new Object[arr.length * 4]; //大小为 4 * n
for (int i = 0; i < arr.length; i++) {
data[i] = arr[i];
}
//构建线段树
buildSegmentTree(0, 0, data.length - 1);
}
// 返回数组元素个数
public int getSize() {
return data.length;
}
// 根据索引获取数据
public E get(int index) {
if (index < 0 || index > data.length)
throw new IllegalArgumentException("Index is illegal");
return data[index];
}
//根据一个节点的索引 index,返回这个节点的左孩子的索引
private int leftChild(int index) {
return 2 * index + 1;
}
//根据一个节点的索引 index,返回这个节点的右孩子的索引
private int rightChild(int index) {
return 2 * index + 2;
}
// 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
// base case:递归到叶子节点了
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
//划分区间
int mid = l + (r - l) / 2;
// 求(左孩子)左区间的统计值
buildSegmentTree(leftTreeIndex, l, mid);
// 求(右孩子)右区间的统计值
buildSegmentTree(rightTreeIndex, mid + 1, r);
//求当前节点 [左区间+右区间] 的统计值
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
//查询区间,返回区间 [queryL, queryR] 的统计值
public E query(int queryL, int queryR) {
//首先进行边界检查
if (queryL < 0 || queryL > data.length || queryR < 0 || queryR > data.length || queryL > queryR) {
throw new IllegalArgumentException("Index is illegal");
}
return query(0, 0, data.length - 1, queryL, queryR);
}
//在以 treeIndex 为根的线段树中 [l,r] 的范围里,搜索区间 [queryL, queryR]
private E query(int treeIndex, int l, int r, int queryL, int queryR) {
if (l == queryL && r == queryR) {
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 如果左边界大于中间节点,则查询右区间
if (queryL > mid)
return query(rightTreeIndex, mid + 1, r, queryL, queryR);
// 如果右边界小于等于中间节点,则查询左区间
if (queryR <= mid)
return query(leftTreeIndex, l, mid, queryL, queryR);
// 如果上述两种情况都不是,则根据中间节点,拆分为两个查询区间
E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
//合并左右区间的查询结果
return merger.merge(leftResult, rightResult);
}
//将 index 位置的值,更新为 e
public void update(int index, E e) {
if (index < 0 || index >= data.length)
throw new IllegalArgumentException("Index is illegal");
data[index] = e;
//更新线段树相应的节点
updateTree(0, 0, data.length - 1, index, e);
}
// 在以 treeIndex 为根的线段树中,更新 index 的值为 e
private void updateTree(int treeIndex, int l, int r, int index, E e) {
//递归终止条件
if (l == r) {
tree[treeIndex] = e;
return;
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (index > mid)
updateTree(rightTreeIndex, mid + 1, r, index, e);
else //index <= mid
updateTree(leftTreeIndex, l, mid, index, e);
//更新当前节点
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
public String toString() {
StringBuffer res = new StringBuffer();
res.append('[');
for (int i = 0; i < tree.length; i++) {
if (tree[i] != null)
res.append(tree[i]);
else res.append("null");
if (i != tree.length - 1)
res.append(", ");
}
res.append(']');
return res.toString();
}
}
使用例子
定义一个求区间的最大值的线段树,代码如下:
public class Main {
public static void main(String[] args) {
Integer[] nums = new Integer[]{34, 65, 8, 10, 21, 86, 79, 30};
SegmentTree<Integer> segTree = new SegmentTree<>(nums, new Merger<Integer>() {
@Override
public Integer merge(Integer a, Integer b) {
//返回 a 和 b 的最大值
return Math.max(a, b);
}
});
// 查询区间 [2,5] 的最大值
System.out.println(segTree.query(4, 7));
}
}
当然,你也可以定义一个求区间内元素的和的线段树,只需要修改merge()
方法的实现即可:
public class Main {
public static void main(String[] args) {
Integer[] nums = new Integer[]{34, 65, 8, 10, 21, 86, 79, 30};
SegmentTree<Integer> segTree = new SegmentTree<>(nums, new Merger<Integer>() {
@Override
public Integer merge(Integer a, Integer b) {
//返回 a 和 b 的和
return a + b;
}
});
// 查询区间 [2,5] 的和
System.out.println(segTree.query(4, 7));
}
}
LeetCode 上相关的题目
303. 区域检索和-数组不可变
题目链接:303. 区域和检索 - 数组不可变
线段树求解
这道题是求取区间和,可以使用线段树来实现。时间复杂度为O(logn)
,空间复杂度为O(n)
。
class NumArray {
private int[] tree;
private int[] data;
public NumArray(int[] nums) {
data = nums;
tree = new int[nums.length * 4];
//当数组长度大于 0 时,才创建线段树
if (nums.length > 0)
//创建线段树
buildSegmentTree(0, 0, nums.length - 1);
}
//根据一个节点的索引 index,返回这个节点的左孩子的索引
private int leftChild(int index) {
return 2 * index + 1;
}
//根据一个节点的索引 index,返回这个节点的右孩子的索引
private int rightChild(int index) {
return 2 * index + 2;
}
// 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
//递归终止条件:区间长度为 1
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
int mid = l + (r - l) / 2;
//创建左区间(左孩子)的和
buildSegmentTree(leftTreeIndex, l, mid);
//创建右区间(右孩子)的和
buildSegmentTree(rightTreeIndex, mid + 1, r);
//合并做有区间的和
tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
}
public int sumRange(int i, int j) {
//tree.length == 1 表示数组没有元素,直接返回 0
if (tree.length == 1)
return 0;
return queryRange(0, 0, data.length - 1, i, j);
}
//在以 treeIndex 为根的线段树中 [l,r] 的范围里,搜索区间 [queryL, queryR]
private int queryRange(int treeIndex, int l, int r, int queryL, int queryR) {
if (queryL == l && queryR == r)
return tree[treeIndex];
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 如果左边界大于中间节点,则查询右区间
if (queryL > mid)
return queryRange(rightTreeIndex, mid + 1, r, queryL, queryR);
// 如果右边界小于等于中间节点,则查询左区间
if (queryR <= mid)
return queryRange(leftTreeIndex, l, mid, queryL, queryR);
// 如果上述两种情况都不是,则根据中间节点,拆分为两个查询区间
int leftResult = queryRange(leftTreeIndex, l, mid, queryL, mid);
int rightResult = queryRange(rightTreeIndex, mid + 1, r, mid + 1, queryR);
//合并左右区间的查询结果
return leftResult + rightResult;
}
}
前缀和求解
其实这道题有更加高效的解法,那就是前缀和。
前缀和的定义是:定义一个前缀和数组sum
,每个元素sum[i]
表示的是nums[0...i]
区间中的元素的和。
那么我们要求[i,j]
区间的和,就可以使用sum[j]-sum[i-1]
得到。
注意当i=0
时,i-1=-1
会溢出。因此sums
数组应该整体向后移动一位。
sum[0]=0
表示前面没有元素,和应该是 0。
此时[i,j]
区间的和应该是sum[j+1]-sum[i]
。
代码如下:
class NumArray {
//前缀和数组
private int[] sums;
public NumArray(int[] nums) {
//边界条件判断
if (nums == null || nums.length == 0) {
sums = new int[]{};
}
int n = nums.length;
//由于整体后移了一位,长度应该为 n+1
sums = new int[n + 1];
//构建前缀和
for (int i = 0; i < n; i++) {
sums[i + 1] = sums[i] + nums[i];
}
}
public int sumRange(int i, int j) {
if (sums.length == 0)
return 0;
//直接返回前缀和相减的结果
return sums[j + 1] - sums[i];
}
}
使用前缀和数组的空间复杂度依然是O(n)
,但时间复杂度是O(1)
,优于线段树。
那既然这样,区间问题为什么还要用线段树呢?
因为这道题目加了一个限制:数组不可变,也就是说数组里的元素是固定的。
如果数组的内容是可变的,那么每次更新索引[i]
的数据,相应的[i...n]
区间的前缀和都需要更新。前缀和数组更新的时间时间复杂度是O(n)
,而线段树的更新复杂度是O(logn)
。
因此,在数组内容可变的情况下,线段树依然是更优的选择。
303. 区域检索和-数组可修改
题目链接:307. 区域和检索 - 数组可修改
前缀和求解
根据上面前缀和的做法,我们只需要添加更新数据和对应的前缀和的逻辑即可。
class NumArray {
int[] sums;
int[] data;
public NumArray(int[] nums) {
//边界条件判断
if (nums == null || nums.length == 0) {
sums = new int[]{};
}
data = nums;
int n = nums.length;
//由于整体后移了一位,长度应该为 n+1
sums = new int[n + 1];
//构建前缀和
for (int i = 0; i < n; i++) {
sums[i + 1] = sums[i] + nums[i];
}
}
public void update(int i, int val) {
// 更新数组
data[i] = val;
//更新从 i 到 n 的前缀和
for (int j = i; j < data.length; j++) {
sums[j + 1] = sums[j] + data[j];
}
}
public int sumRange(int i, int j) {
if (sums.length == 0)
return 0;
//直接返回前缀和相减的结果
return sums[j + 1] - sums[i];
}
}
update()
方法的时间复杂度是O(n)
,sumRange()
方法的时间复杂度是O(1)
。
线段树求解
同理,这里只需要在上面 303. 区域和检索 - 数组不可变 的线段树解法上,添加更新的线段树的逻辑即可。
class NumArray {
int[] tree;
int[] data;
public NumArray(int[] nums) {
data = nums;
int n = nums.length;
if (nums == null || nums.length == 0) {
tree = new int[]{};
return;
}
tree = new int[n * 4];
buildSegmentTree(0, 0, data.length - 1);
}
private void buildSegmentTree(int treeIndex, int l, int r) {
// base case:递归到叶子节点了
if (l == r) {
tree[treeIndex] = data[l];
return;
}
//划分区间
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// 求(左孩子)左区间的统计值
buildSegmentTree(leftTreeIndex, l, mid);
// 求(右孩子)右区间的统计值
buildSegmentTree(rightTreeIndex, mid + 1, r);
//求当前节点 [左区间+右区间] 的统计值
tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
}
private int leftChild(int treeIndex) {
return 2 * treeIndex + 1;
}
private int rightChild(int treeIndex) {
return 2 * treeIndex + 2;
}
//将 index 位置的值,更新为 e
public void update(int i, int val) {
// 更新数组
data[i] = val;
//更新线段树
updateTree(0, i, 0, data.length - 1);
}
// 在以 treeIndex 为根的线段树中,更新 index 的值为 e
private void updateTree(int treeIndex, int index, int l, int r) {
//递归终止条件
if (l == r) {
tree[treeIndex] = data[l];
return;
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (index <= mid)
updateTree(leftTreeIndex, index, l, mid);
else //index <= mid
updateTree(rightTreeIndex, index, mid + 1, r);
//更新当前节点
tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
}
public int sumRange(int i, int j) {
if (data.length == 0)
return 0;
return queryRange(0, 0, data.length - 1, i, j);
}
private int queryRange(int treeIndex, int l, int r, int queryL, int queryR) {
if (l == queryL && r == queryR) {
return tree[treeIndex];
}
int mid = l + (r - l) / 2;
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if (queryL > mid)
return queryRange(rightTreeIndex, mid + 1, r, queryL, queryR);
if (queryR <= mid)
return queryRange(leftTreeIndex, l, mid, queryL, queryR);
int leftResult = queryRange(leftTreeIndex, l, mid, queryL, mid);
int rightResult = queryRange(rightTreeIndex, mid + 1, r, mid + 1, queryR);
return leftResult + rightResult;
}
}
update()
方法和sumRange()
方法的时间复杂度都是O(n)
。
总结与扩展
线段树,虽然不是满二叉树。但是我们却可以把它看作一棵满二叉树,进而使用数组来存储这棵树。
其次,在线段树中,我们定义了每个节点,存储的数据是这个节点对应区间的统计值(最大值、最小值、区间和)。通过这一点,你可以体会到,对于树中的节点,你可以赋予它独特的定义,进而可以高效地解决各种各样的问题。因此,树的使用范围是非常广泛的。
对于,线段树的构建、更新区间、和查询区间,这 3 个操作,都是先递归访问叶子节点,然后再回溯访问父节点,合并左右孩子区间的结果,这本质上是一种后序遍历的思想。
对一个区间进行更新
在本文的例子中,每次更新都是对一个元素进行更新。现在考虑另一种更新:对某个区间内的所有元素进行更新。
比如:将[2,5]
区间中的所有元素都加 3,就需要遍历这个区间中的每个元素,区间更新的时间复杂度就变为了O(n)
。为了降低区间更新的复杂度,有一种专门的方法:懒惰更新。
懒惰更新的思想是:在每次更新区间时,我们实际上先不更新实际的数据,而是使用另一个lazy
数组,来标记这些未更新的内容。
那么什么时候才会更新这些节点呢?
当我们下一次更新或者查询到这些数据时,先查一下lazy
数组中是否有数据未更新,然后将未更新的内容进行更新,再访问对应的数据。
例如:
- 第一次:将
[2,5]
区间中的所有元素都加 3,实际上lazy
数组就会标记[2,5]
区间的数据未更新; - 第二次:将
[4,7]
区间中的所有元素都减 5。这时,查询lazy
数组中,发现[4,5]
区间的数据未更新,那么先更新[4,5]
区间的内容,[2,3]
区间的标记不变。然后标记[4,7]
区间中的内容未更新。
通过懒惰更新,时间复杂度降为了O(logn)
。
二维线段树
在这篇文章中,我们处理的都是一维的线段树,实际中还可以产生二位线段树。
一维线段树,就是数据都是一维数组,每个节点记录的区间只有左右两个边界。
在二维线段树中,数据是二维数组,也就是一个矩阵,每个节点记录的区间有上下左右两个边界。那么每个节点就有 4 个孩子。
进一步扩展,你也可以设计出 3 维的线段树,甚至更高维的线段树。
线段树是一种设计思想,利用树这种数据结构,如何把一个大的数据单元,递归地拆分为小的数据单元,同时,利用树这种数据结构,可以高效地进行查询、更新等操作。
动态线段树
在这篇文章中,我使用了数组这种数据结构来存储树,造成了空间的浪费。实际上,我们可以也使用链式存储,可以更好地利用空间。
实际上,动态线段树有一个更加重要的应用,。例如,我们要存储 10000000 大小的线段树,但是不会对这么大一个区间中的每一个部分都进行访问,可能只会对一种某个小部分进行访问。
那么我们可以不用一开始就创建这么巨大的线段树,而是只创建一个根节点,等到实际访问某个区间时,在根据区间的边界,动态地创建线段树。
树状数组
这篇文章讲的线段树,实际上是对区间操作的一种数据结构。与区间操作相关的数据结构,还有另外一个:树状数组,英文名为Binary Index Tree
。感兴趣可以自行查阅。
如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。
我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学。