线段树详解与实现

此篇文章用于记录《玩转数据结构》课程的学习笔记

什么是线段树

线段树也被称为区间树,英文名为Segment Tree或者Interval tree,是一种高级的数据结构。这种数据结构更多出现在竞赛中,在常见的本科数据结构教材里没有介绍这种数据结构。但是,在面试中却有可能碰到和线段树相关的问题。那么为什么会产生线段树这种数据结构,线段树到底是为了解决什么样的一种问题呢?

其实这里的线段可以理解为区间,线段树就是为了解决区间问题的。

有一个很经典的线段树问题是:区间染色。

假设有一面墙,长度为 n,每次选择一段墙进行染色。


在区间染色的过程中,每次选择一段区间进行染色,这时新的颜色可能会覆盖之前的颜色。

最后的问题是:

  • 在经过 m 次染色操作后,我们可以在整个区间看见多少种颜色?

更加普遍的说法是:

  • 在经过 m 次染色操作后,我们可以在区间 [i, j]内看见多少种颜色?

由于第一个问题是第二个问题的一个特例,我们采用第二种问题来思考解决方法。

从上面可以看出,我们对于区间,有 2 种操作,分别是染色操作查询区间的颜色,使用更加一般的说法,染色操作就是更新区间查询区间的颜色就是查询区间

这类问题里面,更加常见的的是区间查询:一个数组存放的不再是颜色,而是具体的数字,查询某个区间[i, j]统计值。这里的统计值是指:区间内最大值、最小值、或者这个区间的数字和。


比如:

  • 查询 2018 年注册的用户中消费最高的用户
  • 查询 2018 年注册的用户中消费最低的用户

注意上面两种情况都是动态查询,我们查询的消费数据不只是 2018 的消费数据

如果我们想查询 2018 年中消费最高的用户,那么 2018 年的数据已经固定了,我们直接在这一年的数据中进行统计分析就行了。

但是一个 2018 年注册的用户,在 2019 年、2020 年都可能会有消费。我们实际上查询的是:2018 年注册的用户中,到现在为止,消费最高的用户。

这种情况下,数据是在动态变化的, 也就是说:2017 年注册的用户中,每个用户的消费额是会更新的,这就对应到更新区间的操作。

此时线段树就是一种好的选择。

按照通常的思路,使用数组存储上述的元素是比较好的,思考上面两个操作的时间复杂度:

  • 更新区间:每次根据需要更新的区间的首尾索引,逐个遍历区间中的元素进行更新,时间复杂度为O(n)
  • 查询区间:每次根据需要更新的区间的首尾索引,逐个遍历区间种的元素进行查询,时间复杂度为O(n)

两个操作的时间复杂度均为O(n),对于需要多次动态使用的场景来说,性能可能是不够好的。

在这类问题中,我们关注的是一个个区间内的元素的情况,线段树就有用武之地了,线段树的优点就是把两个操作的时间复杂度降到了O(logn)

操作 使用数组 使用线段树
更新 O(n) O(logn)
查询 O(n) O(logn)

这里提一点,如果你看到一个算法的时间复杂度是O(logn),那这个算法大多与二叉树分治算法有关。

这里也不例外,线段树就是使用二叉树来实现的。

那么一个区间是如何被构建成为一个二叉树的?

对于一个数组 A,如下所示:


对应的线段树就是:


二叉树中每个非叶子节点表示的是区间内元素的统计值叶子节点存储的就是元素本身。上面说了统计值是指:区间内最大值、最小值、或者这个区间的数字和。比如你要求区间的最大值,每个每个节点存储的就是这个区间内元素的最大值。像下面这样:


假设你要查询[4,7]区间内的最大值,那么不用查到叶子节点,而是查到A[4, 7]这个节点就行了。


当然,并不是所有的区间都恰好落在一个节点,比如你要求[2, 5]区间内的最大值。那么就要分别找到A[2, 3]A[4,5]的最大值,再进行比较。


可以看出,线段树的查询区间操作不需要遍历区间中的每一个元素,只要找到对应的树节点就可以返回,时间复杂度为O(logn)

总结

从更加抽象的角度来讲,线段树的使用场景就是,对于给定区间,进行更新区间和查询区间操作:

  • 更新区间:更新区间中的一个元素或者一个区间的值。
  • 查询区间:查询一个区间[i, j]的最大值、最小值、或者区间的数字和。

注意,在大多数情况下,我们是不考虑区间里添加元素和删除元素的,我们假设区间的大小是固定的。

线段树的表示

树的一般表示方法是链式存储,每个节点有两个指针,一个指向左孩子,一个指向右孩子。但是满二叉树,和完全二叉树,除了使用链表法来存储,还可以使用数组来表示。

在满二叉树和完全二叉树的数组表示中,假设一个节点的所以是i,那么左孩子的索引就是$2 \times i +1$,右孩子的索引就是$2 \times i +2$。

那么线段树是不是满二叉树或者完全二叉树呢? 能不能使用数组来表示呢?

在上面的例子中,我们的二叉树恰好是一棵满二叉树,这是因为我们的数组大小恰好是 8,也就是$2^3$,只有数组大小恰好是 2 的 n 次幂,所对应的线段树才会是一个满二叉树。在大部分情况下,线段树并不是一个满二叉树。如果一个数组的大小是 10 ,对应的线段树如下图所示。


所以线段树不是满二叉树,也不是完全二叉树。但实际上:线段树是平衡二叉树,是可以保证O(logn)的时间复杂度的,这里就不证明了。

其实平衡二叉树可以看作特殊的满二叉树,进而使用数组来表示。


下一步就是确定:对于大小为 n 的数组,需要多大的空间来存储线段树。

首先说一个结论,对于高为 h 层的满二叉树,一共有$2{h}-1$个节点,而最后一层有$2$个节点。那么:$除了最后一层的前面所有节点数 = 总的节点数 - 最后一层的节点数 = (2^{h}-1) - 2^{h-1} \approx 2^{h-1}-1$。也就是:前面所有层的节点数之和等于最后一层的节点数减 1。


那么线段树所需要的节点数量,分两种情况来讨论:

  • 如果 n 恰好是 2 的 k 次幂,由于线段树最后一层的叶子节点存储的是数组元素本身,最后一层的节点数就是 n,而根据上面的结论,前面所有层的节点数之和是$n-1$,那么总节点数就是$2 \times n-1$。为了方便起见,分配$ 2 \times n$的空间。

  • 如果 n 不是 2 的 k 次幂,最坏的情况就是$n=2^k+1$,那么有一个元素需要开辟新的一层来存储,需要$4 \times n-5$的大小。为了方便起见,我们分配$4 \times n$的空间,已经足够了。


综上,首先需要判断数组的大小是否为 $2^k$,是则使用 $2 \times n$ 的空间,否则使用的 $4 \times n$ 空间。下面线段树的实现是基于数组来实现的,不过为了简便起见,下面的实现统一使用 $4 \times n$ 空间来存储线段树。

使用数组来存储线段树,会有一定的空间浪费,但是换来的时间复杂度的降低是可以接受的。同时,我也在最后会介绍链式存储的实现方式。

线段树的实现

首先需要两个数组,其中data存放原来的数据,tree就是存放线段树。

基本的 API 有getSize():返回数组元素个数;get(int index):根据索引获取数据。

其中每个元素使用泛型E表示,这是为了考虑可扩展性:如果你的数组元素不是数字,而是自定义的类,那么使用泛型就是比较好的选择。

public class SegmentTree<E> {
    private E[] tree; //线段树
    private E[] data; //数据

    public SegmentTree(E[] arr) {
        data = (E[]) new Object[arr.length];
        tree = (E[]) new Object[arr.length * 4]; //大小为 4 * n
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
    }

    // 返回数组元素个数
    public int getSize() {
        return data.length;
    }

    // 根据索引获取数据
    public E get(int index) {
        if (index < 0 || index > data.length)
            throw new IllegalArgumentException("Index is illegal");
        return data[index];
    }
}

由于把线段树看作一棵完全二叉树,应该定义两个 API,根据一个节点获取到它的左孩子和右孩子。

// 根据一个节点的索引 index,返回这个节点的左孩子的索引
private int leftChild(int index) {
	return 2 * index + 1;
}

// 根据一个节点的索引 index,返回这个节点的右孩子的索引
private int rightChild(int index) {
	return 2 * index + 2;
}

线段树的构建

下面考虑的就是构造线段树的每个节点。这里以求区间的最大值为例。


根节点存储的是整个[0,7]区间的最大值,左孩子存储的是[0,3]区间内的最大值,右孩子存储的是[4,7]区间内的最大值。

我们要求根节点的值,首先求得左右两个孩子的值,再从左右两个孩子中取出较大的值作为根节点的值。要求父节点区间的值,需要先求孩子节点区间的值,这是递归的性质。所以可以通过递归来创建线段树。

那么这个递归的终止条件,也就是 base case,是什么呢?

  • base case:如果一个节点的区间长度为1,不能再划分,也就是递归到底了,就返回这个元素本身。

明确了思路,那么我们的递归函数需要几个参数?

首先,既然是创建节点,那么需要节点在tree数组中的索引;其次,这个节点对应的区间的左边界和右边界。

总共需要 3 个参数,写出的代码如下:

// 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
    // base case:递归到叶子节点了
    if (l == r) {
        tree[treeIndex] = data[l];
        return;
    }

    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    //划分区间
    int mid = l + (r - l) / 2;
    // 求(左孩子)左边区间的最大值
    buildSegmentTree(leftTreeIndex, l, mid);
    // 求(右孩子)右区间的最大值
    buildSegmentTree(rightTreeIndex, mid + 1, r);
    //合并左右区间,求左区间和右区间点的最大值
    tree[treeIndex] = Math.max(tree[leftTreeIndex], tree[rightTreeIndex]);
}

当然这里最后一句是会报错的。因为tree的元素类型是泛型,不支持Math.max()函数。

还有一个问题是:如果你在这里把区间合并的逻辑写成了只取最大值,那么这个线段树就只能求某个区间的最大值,不能用于求取区间的最小值、或者区间的和,限制了线段树的应用场景。

一个更好的方法是:用户可以根据自己的业务场景,自由选择合并区间的逻辑。

要达成这个目的,我们需要创建一个接口,用户需要实现这个接口来实现自己的区间合并逻辑。

//融合器,表示如何合并两个区间的统计值
public interface Merger<E> {
    // a 表示左区间的统计值,b 表示有区间的统计值
    //返回整个[左区间+右区间] 的统计值
    E merge(E a, E b);
}

在线段树的构造函数中,添加一个Merger参数,并且调用buildSegmentTree()构建线段树。

public class SegmentTree<E> {
    private E[] tree; //线段树
    private E[] data; //数据
    private Merger<E> merger;//融合器

    public SegmentTree(E[] arr, Merger<E> merger) {
        this.merger = merger;
        data = (E[]) new Object[arr.length];
        tree = (E[]) new Object[arr.length * 4]; //大小为 4 * n
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        //构建线段树
        buildSegmentTree(0, 0, data.length - 1);
    }
    .
    .
    .
}

然后,修改buildSegmentTree()方法的最后一行。

// 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
private void buildSegmentTree(int treeIndex, int l, int r) {
    // base case:递归到叶子节点了
    if (l == r) {
        tree[treeIndex] = data[l];
        return;
    }

    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    //划分区间
    int mid = l + (r - l) / 2;
    // 求(左孩子)左区间的统计值
    buildSegmentTree(leftTreeIndex, l, mid);
    // 求(右孩子)右区间的统计值
    buildSegmentTree(rightTreeIndex, mid + 1, r);
    //求当前节点 [左区间+右区间] 的统计值
    tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

线段树的查询

我们用下面的数组构建一棵线段树,查询区间[2,5]为例。


  1. 我们首先从根节点开始查询:在区间[0,7]中查询[2,5]

将根节点的区间分为左孩子[0,3]和右孩子[4,7],而我们的查询的区间[2,5]不能其中一个孩子的区间完全包括

于是上述问题转化为两个子问题:

  • 在区间[0,3]中查询[2,3]
  • 在区间[4,7]中查询[4,5]
  • 最后将[2,3][4,5]结果合并,得到区间[2,5]的结果。

[0,3]划分为左孩子[0,1]和右孩子[2,3],此时孩子刚好和查询的区间重合,返回结果。

[4,7]划分为左孩子[4,5]和右孩子[6,7],此时孩子刚好和查询的区间重合,返回结果。


从上面可以看出,在线段树的区间查找过程中,不需要遍历区间中的每个元素,只需要在线段树中找到对应的所有子区间,再将这些子区间的结果合并即可。而查询所经过的节点数最多是树的高度,时间复杂度为O(logn)

而且上面的过程也是递归的过程。

  • 递归终止条件就是:查询区间的边界和节点的边界完全重合,就返回该节点的统计值。

如果不重合,那怎么办?

这时应该分 3 种情况:

  • 如果查询区间的左边界大于中间节点,那么就查询右区间

  • 如果查询区间的右边界小于等于中间节点,那么就查询左区间


  • 如果不属于上述两种情况,那么查询的区间就要根据中间节点拆分


递归函数的参数应该有 4 个,当前节点所在的区间的左边界和右边界用户要查询的的区间的左边界和右边界

代码如下:

//在以 treeIndex 为根的线段树中 [l,r] 的范围里,搜索区间 [queryL, queryR]
private E query(int treeIndex, int l, int r, int queryL, int queryR) {
    if (l == queryL && r == queryR) {
        return tree[treeIndex];
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    // 如果左边界大于中间节点,则查询右区间
    if (queryL > mid)
        return query(rightTreeIndex, mid + 1, r, queryL, queryR);
    // 如果右边界小于等于中间节点,则查询左区间
    if (queryR <= mid)
        return query(leftTreeIndex, l, mid, queryL, queryR);
    // 如果上述两种情况都不是,则根据中间节点,拆分为两个查询区间
    E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
    E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
    //合并左右区间的查询结果
    return merger.merge(leftResult, rightResult);
}

线段树的更新

如下图所示,更新A[i]=100,那么将需要更新的索引 i和区间的终点mid,分为两种情况。

  • 如果i>mid,那么索引i落在右区间,更新右区间;
  • 如果i<=mid,那么索引i落在左区间,更新左区间;

那么递归的终止条件是什么呢?

  • 当递归到叶子节点的时候,就值更新这个节点:叶子节点就是区间长度为 1 的节点。

当更新完叶子节点后,还需要回溯,更新父节点区间的统计值。


代码如下:

//将 index 位置的值,更新为 e
public void update(int index, E e) {
    if (index < 0 || index >= data.length)
        throw new IllegalArgumentException("Index is illegal");
    data[index] = e;
    //更新线段树相应的节点
    updateTree(0, 0, data.length - 1, index, e);
}

// 在以 treeIndex 为根的线段树中,更新 index 的值为 e
private void updateTree(int treeIndex, int l, int r, int index, E e) {
    //递归终止条件
    if (l == r) {
        tree[treeIndex] = e;
        return;
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (index > mid)
        updateTree(rightTreeIndex, mid + 1, r, index, e);
    else //index <= mid
        updateTree(leftTreeIndex, l, mid, index, e);
    //更新当前节点
    tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}

完整代码

public class SegmentTree<E> {
    private E[] tree; //线段树
    private E[] data; //数据
    private Merger<E> merger;//融合器

    public SegmentTree(E[] arr, Merger<E> merger) {
        this.merger = merger;
        data = (E[]) new Object[arr.length];
        tree = (E[]) new Object[arr.length * 4]; //大小为 4 * n
        for (int i = 0; i < arr.length; i++) {
            data[i] = arr[i];
        }
        //构建线段树
        buildSegmentTree(0, 0, data.length - 1);
    }

    // 返回数组元素个数
    public int getSize() {
        return data.length;
    }

    // 根据索引获取数据
    public E get(int index) {
        if (index < 0 || index > data.length)
            throw new IllegalArgumentException("Index is illegal");
        return data[index];
    }

    //根据一个节点的索引 index,返回这个节点的左孩子的索引
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    //根据一个节点的索引 index,返回这个节点的右孩子的索引
    private int rightChild(int index) {
        return 2 * index + 2;
    }

    // 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
    private void buildSegmentTree(int treeIndex, int l, int r) {
        // base case:递归到叶子节点了
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }

        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        //划分区间
        int mid = l + (r - l) / 2;
        // 求(左孩子)左区间的统计值
        buildSegmentTree(leftTreeIndex, l, mid);
        // 求(右孩子)右区间的统计值
        buildSegmentTree(rightTreeIndex, mid + 1, r);
        //求当前节点 [左区间+右区间] 的统计值
        tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

    //查询区间,返回区间 [queryL, queryR] 的统计值
    public E query(int queryL, int queryR) {
        //首先进行边界检查
        if (queryL < 0 || queryL > data.length || queryR < 0 || queryR > data.length || queryL > queryR) {
            throw new IllegalArgumentException("Index is illegal");
        }
        return query(0, 0, data.length - 1, queryL, queryR);
    }

    //在以 treeIndex 为根的线段树中 [l,r] 的范围里,搜索区间 [queryL, queryR]
    private E query(int treeIndex, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[treeIndex];
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        // 如果左边界大于中间节点,则查询右区间
        if (queryL > mid)
            return query(rightTreeIndex, mid + 1, r, queryL, queryR);
        // 如果右边界小于等于中间节点,则查询左区间
        if (queryR <= mid)
            return query(leftTreeIndex, l, mid, queryL, queryR);
        // 如果上述两种情况都不是,则根据中间节点,拆分为两个查询区间
        E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
        E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
        //合并左右区间的查询结果
        return merger.merge(leftResult, rightResult);
    }
    //将 index 位置的值,更新为 e
    public void update(int index, E e) {
        if (index < 0 || index >= data.length)
            throw new IllegalArgumentException("Index is illegal");
        data[index] = e;
        //更新线段树相应的节点
        updateTree(0, 0, data.length - 1, index, e);
    }

    // 在以 treeIndex 为根的线段树中,更新 index 的值为 e
    private void updateTree(int treeIndex, int l, int r, int index, E e) {
        //递归终止条件
        if (l == r) {
            tree[treeIndex] = e;
            return;
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (index > mid)
            updateTree(rightTreeIndex, mid + 1, r, index, e);
        else //index <= mid
            updateTree(leftTreeIndex, l, mid, index, e);
        //更新当前节点
        tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
    }
    
    public String toString() {
        StringBuffer res = new StringBuffer();
        res.append('[');
        for (int i = 0; i < tree.length; i++) {
            if (tree[i] != null)
                res.append(tree[i]);
            else res.append("null");
            if (i != tree.length - 1)
                res.append(", ");
        }
        res.append(']');
        return res.toString();
    }
}

使用例子

定义一个求区间的最大值的线段树,代码如下:

public class Main {
    public static void main(String[] args) {
        Integer[] nums = new Integer[]{34, 65, 8, 10, 21, 86, 79, 30};
        SegmentTree<Integer> segTree = new SegmentTree<>(nums, new Merger<Integer>() {
            @Override
            public Integer merge(Integer a, Integer b) {
                //返回 a 和 b 的最大值
                return Math.max(a, b);
            }
        });
        // 查询区间 [2,5] 的最大值
        System.out.println(segTree.query(4, 7));
    }
}

当然,你也可以定义一个求区间内元素的和的线段树,只需要修改merge()方法的实现即可:

public class Main {
    public static void main(String[] args) {
        Integer[] nums = new Integer[]{34, 65, 8, 10, 21, 86, 79, 30};
        SegmentTree<Integer> segTree = new SegmentTree<>(nums, new Merger<Integer>() {
            @Override
            public Integer merge(Integer a, Integer b) {
                //返回 a 和 b 的和
                return a + b;
            }
        });
        // 查询区间 [2,5] 的和
        System.out.println(segTree.query(4, 7));
    }
}

LeetCode 上相关的题目

303. 区域检索和-数组不可变

题目链接:303. 区域和检索 - 数组不可变


线段树求解

这道题是求取区间和,可以使用线段树来实现。时间复杂度为O(logn),空间复杂度为O(n)

class NumArray {
    private int[] tree;
    private int[] data;

    public NumArray(int[] nums) {
        data = nums;
        tree = new int[nums.length * 4];
        //当数组长度大于 0 时,才创建线段树
        if (nums.length > 0)
            //创建线段树
            buildSegmentTree(0, 0, nums.length - 1);
    }

    //根据一个节点的索引 index,返回这个节点的左孩子的索引
    private int leftChild(int index) {
        return 2 * index + 1;
    }

    //根据一个节点的索引 index,返回这个节点的右孩子的索引
    private int rightChild(int index) {
        return 2 * index + 2;
    }

    // 在 treeIndex 的位置创建表示区间 [l,r] 的线段树
    private void buildSegmentTree(int treeIndex, int l, int r) {
        //递归终止条件:区间长度为 1
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (r - l) / 2;
        //创建左区间(左孩子)的和
        buildSegmentTree(leftTreeIndex, l, mid);
        //创建右区间(右孩子)的和
        buildSegmentTree(rightTreeIndex, mid + 1, r);
        //合并做有区间的和
        tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
    }

    public int sumRange(int i, int j) {
    	//tree.length == 1 表示数组没有元素,直接返回 0
        if (tree.length == 1)
            return 0;
        return queryRange(0, 0, data.length - 1, i, j);
    }

    //在以 treeIndex 为根的线段树中 [l,r] 的范围里,搜索区间 [queryL, queryR]
    private int queryRange(int treeIndex, int l, int r, int queryL, int queryR) {
        if (queryL == l && queryR == r)
            return tree[treeIndex];

        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        // 如果左边界大于中间节点,则查询右区间
        if (queryL > mid)
            return queryRange(rightTreeIndex, mid + 1, r, queryL, queryR);
        // 如果右边界小于等于中间节点,则查询左区间
        if (queryR <= mid)
            return queryRange(leftTreeIndex, l, mid, queryL, queryR);
        // 如果上述两种情况都不是,则根据中间节点,拆分为两个查询区间
        int leftResult = queryRange(leftTreeIndex, l, mid, queryL, mid);
        int rightResult = queryRange(rightTreeIndex, mid + 1, r, mid + 1, queryR);
        //合并左右区间的查询结果
        return leftResult + rightResult;
    }
}

前缀和求解

其实这道题有更加高效的解法,那就是前缀和

前缀和的定义是:定义一个前缀和数组sum,每个元素sum[i]表示的是nums[0...i]区间中的元素的和。


那么我们要求[i,j]区间的和,就可以使用sum[j]-sum[i-1]得到。


注意当i=0时,i-1=-1会溢出。因此sums数组应该整体向后移动一位。

sum[0]=0表示前面没有元素,和应该是 0。

此时[i,j]区间的和应该是sum[j+1]-sum[i]


代码如下:

class NumArray {
    //前缀和数组
    private int[] sums;

    public NumArray(int[] nums) {
        //边界条件判断
        if (nums == null || nums.length == 0) {
            sums = new int[]{};
        }
        int n = nums.length;
        //由于整体后移了一位,长度应该为 n+1
        sums = new int[n + 1];
        //构建前缀和
        for (int i = 0; i < n; i++) {
            sums[i + 1] = sums[i] + nums[i];
        }
    }

    public int sumRange(int i, int j) {
        if (sums.length == 0)
            return 0;
        //直接返回前缀和相减的结果
        return sums[j + 1] - sums[i];
    }
}

使用前缀和数组的空间复杂度依然是O(n),但时间复杂度是O(1),优于线段树。

那既然这样,区间问题为什么还要用线段树呢?

因为这道题目加了一个限制:数组不可变,也就是说数组里的元素是固定的。

如果数组的内容是可变的,那么每次更新索引[i]的数据,相应的[i...n]区间的前缀和都需要更新。前缀和数组更新的时间时间复杂度是O(n),而线段树的更新复杂度是O(logn)

因此,在数组内容可变的情况下,线段树依然是更优的选择。

303. 区域检索和-数组可修改

题目链接:307. 区域和检索 - 数组可修改


前缀和求解

根据上面前缀和的做法,我们只需要添加更新数据和对应的前缀和的逻辑即可。

class NumArray {
    int[] sums;
    int[] data;

    public NumArray(int[] nums) {
        //边界条件判断
        if (nums == null || nums.length == 0) {
            sums = new int[]{};
        }
        data = nums;
        int n = nums.length;
        //由于整体后移了一位,长度应该为 n+1
        sums = new int[n + 1];
        //构建前缀和
        for (int i = 0; i < n; i++) {
            sums[i + 1] = sums[i] + nums[i];
        }
    }

    public void update(int i, int val) {
        // 更新数组
        data[i] = val;
        //更新从 i 到 n 的前缀和
        for (int j = i; j < data.length; j++) {
            sums[j + 1] = sums[j] + data[j];
        }
    }

    public int sumRange(int i, int j) {
        if (sums.length == 0)
            return 0;
        //直接返回前缀和相减的结果
        return sums[j + 1] - sums[i];
    }
}

update()方法的时间复杂度是O(n)sumRange()方法的时间复杂度是O(1)

线段树求解

同理,这里只需要在上面 303. 区域和检索 - 数组不可变 的线段树解法上,添加更新的线段树的逻辑即可。

class NumArray {
    int[] tree;
    int[] data;

    public NumArray(int[] nums) {


        data = nums;
        int n = nums.length;
        if (nums == null || nums.length == 0) {
            tree = new int[]{};
            return;
        }
        tree = new int[n * 4];
        buildSegmentTree(0, 0, data.length - 1);
    }

    private void buildSegmentTree(int treeIndex, int l, int r) {
        // base case:递归到叶子节点了
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        //划分区间
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        // 求(左孩子)左区间的统计值
        buildSegmentTree(leftTreeIndex, l, mid);
        // 求(右孩子)右区间的统计值
        buildSegmentTree(rightTreeIndex, mid + 1, r);
        //求当前节点 [左区间+右区间] 的统计值
        tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];
    }

    private int leftChild(int treeIndex) {
        return 2 * treeIndex + 1;
    }

    private int rightChild(int treeIndex) {
        return 2 * treeIndex + 2;
    }

    //将 index 位置的值,更新为 e
    public void update(int i, int val) {
        // 更新数组
        data[i] = val;
        //更新线段树
        updateTree(0, i, 0, data.length - 1);

    }

    // 在以 treeIndex 为根的线段树中,更新 index 的值为 e
    private void updateTree(int treeIndex, int index, int l, int r) {
        //递归终止条件
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (index <= mid)
            updateTree(leftTreeIndex, index, l, mid);
        else //index <= mid
            updateTree(rightTreeIndex, index, mid + 1, r);
        //更新当前节点
        tree[treeIndex] = tree[leftTreeIndex] + tree[rightTreeIndex];

    }

    public int sumRange(int i, int j) {
        if (data.length == 0)
            return 0;
        return queryRange(0, 0, data.length - 1, i, j);
    }

    private int queryRange(int treeIndex, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[treeIndex];
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (queryL > mid)
            return queryRange(rightTreeIndex, mid + 1, r, queryL, queryR);
        if (queryR <= mid)
            return queryRange(leftTreeIndex, l, mid, queryL, queryR);
        int leftResult = queryRange(leftTreeIndex, l, mid, queryL, mid);
        int rightResult = queryRange(rightTreeIndex, mid + 1, r, mid + 1, queryR);
        return leftResult + rightResult;
    }
}

update()方法和sumRange()方法的时间复杂度都是O(n)

总结与扩展

线段树,虽然不是满二叉树。但是我们却可以把它看作一棵满二叉树,进而使用数组来存储这棵树。

其次,在线段树中,我们定义了每个节点,存储的数据是这个节点对应区间的统计值(最大值、最小值、区间和)。通过这一点,你可以体会到,对于树中的节点,你可以赋予它独特的定义,进而可以高效地解决各种各样的问题。因此,树的使用范围是非常广泛的。

对于,线段树的构建、更新区间、和查询区间,这 3 个操作,都是先递归访问叶子节点,然后再回溯访问父节点,合并左右孩子区间的结果,这本质上是一种后序遍历的思想。

对一个区间进行更新

在本文的例子中,每次更新都是对一个元素进行更新。现在考虑另一种更新:对某个区间内的所有元素进行更新。

比如:将[2,5]区间中的所有元素都加 3,就需要遍历这个区间中的每个元素,区间更新的时间复杂度就变为了O(n)。为了降低区间更新的复杂度,有一种专门的方法:懒惰更新

懒惰更新的思想是:在每次更新区间时,我们实际上先不更新实际的数据,而是使用另一个lazy数组,来标记这些未更新的内容。

那么什么时候才会更新这些节点呢?

当我们下一次更新或者查询到这些数据时,先查一下lazy数组中是否有数据未更新,然后将未更新的内容进行更新,再访问对应的数据。

例如:

  • 第一次:将[2,5]区间中的所有元素都加 3,实际上lazy数组就会标记[2,5]区间的数据未更新;
  • 第二次:将[4,7]区间中的所有元素都减 5。这时,查询lazy数组中,发现[4,5]区间的数据未更新,那么先更新[4,5]区间的内容,[2,3]区间的标记不变。然后标记[4,7]区间中的内容未更新。

通过懒惰更新,时间复杂度降为了O(logn)

二维线段树

在这篇文章中,我们处理的都是一维的线段树,实际中还可以产生二位线段树。

一维线段树,就是数据都是一维数组,每个节点记录的区间只有左右两个边界。


在二维线段树中,数据是二维数组,也就是一个矩阵,每个节点记录的区间有上下左右两个边界。那么每个节点就有 4 个孩子。


进一步扩展,你也可以设计出 3 维的线段树,甚至更高维的线段树。

线段树是一种设计思想,利用树这种数据结构,如何把一个大的数据单元,递归地拆分为小的数据单元,同时,利用树这种数据结构,可以高效地进行查询、更新等操作。

动态线段树

在这篇文章中,我使用了数组这种数据结构来存储树,造成了空间的浪费。实际上,我们可以也使用链式存储,可以更好地利用空间。

实际上,动态线段树有一个更加重要的应用,。例如,我们要存储 10000000 大小的线段树,但是不会对这么大一个区间中的每一个部分都进行访问,可能只会对一种某个小部分进行访问。

那么我们可以不用一开始就创建这么巨大的线段树,而是只创建一个根节点,等到实际访问某个区间时,在根据区间的边界,动态地创建线段树。


树状数组

这篇文章讲的线段树,实际上是对区间操作的一种数据结构。与区间操作相关的数据结构,还有另外一个:树状数组,英文名为Binary Index Tree。感兴趣可以自行查阅。



如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。


我的文章会首发在公众号上,欢迎扫码关注我的公众号张贤同学


posted @ 2020-04-12 14:48  张贤同学  阅读(917)  评论(0编辑  收藏  举报