线段树 | 第1讲 (给定区间求和)(转)

让我们通过考虑下面的问题来理解线段树。

给定一个数组arr[0 . . . n-1],我们要对数组执行这样的操作:

1 计算从下标l到r的元素之和,其中 0 <= l <= r <= n-1
​2 修改数组指定元素的值arr[i] = x,其中 0 <= i <= n-1

一个简单的方案是从lr执行循环,计算给定区间的元素之和。更新值的时候,简单地令arr[i] = x。第一个操作花费O(n)的时间,第二个操作花费O(1)的时间。

第二个方案是创建另外一个数组来存储从下标i开始的元素之和。这样一来,给定区间之和可以用O(1)的时间计算,但是更新需要花费O(n)的时间。这种方法适用于需要大量查询而更新操作较少的场景。

如果查询和更新的次数一样多呢?我们可以在O(log n)的时间内完成上述两种操作吗?

我们可以使用线段树来实现在O(log n)时间内完成上述两种操作。

线段树的表示:

1. 叶子节点存储输入的数组元素
2. 每一个内部节点表示某些叶子节点的合并(merge)。合并的方法可能会因问题而异。对于这个问题,合并指的是某个节点之下的所有叶子节点的和。

此处使用树的数组形式来表示线段树。对于下标i的节点,其左孩子为 2*i+1,右孩子为2*i+2,父节点为floor( (i - 1)/2 )

segment-tree

根据给定的数组构建线段树:

我们从线段数组arr[0 . . . n-1]开始。每一次将现有的线段拆分成两半(如果当前线段的长度还不为1),然后在两半线段分别执行同样的过程,并且对于每一个这样的线段,我们存储相应节点的和。

构建出的线段树除最后一层外每一层都会填满。线段树是满二叉树(此处的满二叉树是指树中任意节点的度为0或2的二叉树,与国内计算机教材关于满二叉树的定义不同),因为我们在每一层都将线段拆分为两半。由于构建出的树总是拥有n个叶子节点的满二叉树,因此内部节点会有 n-1 个。因而节点总数为 2*n - 1。

线段树的高度为ceil( log n )。由于树使用数组表示,并且需要维护父子索引之间的关系,为线段树分配的内存需要:2 * 2 ^ ceil( logn ) - 1

查询给定区间元素之和:

线段树构建完成之后,怎样获取给定区间的和呢?下面的伪代码展示了获取区间元素之和的过程。

int getSum(node, l, r) 
{
   if range of node is within l and r
        return value in node
   else if range of node is completely outside l and r
        return 0
   else
    return getSum(node's left child, l, r) + 
           getSum(node's right child, l, r)
}

元素值的更新:

与树的构建与查询操作类似,更新操作也可以递归地完成。给定需要更新的数组下标。令diff为需要添加的值。我们从线段树的根节点开始,对所有位于给定区间之内的节点添加diff。如果节点不在区间范围之内,则不做任何修改。

线段树的实现:

下面展示了线段树的实现。程序实现了从任意数组构建线段树,以及查询与更新操作。

// 演示线段树构建、查询、更新等操作的示例程序
#include <stdio.h>
#include <math.h>
 
// 获取起止下标中点的工具函数
int getMid(int s, int e) {  return s + (e - s)/2;  }
 
/*  获取数组给定区间之和的递归函数
    下面是函数的参数列表
 
    st    --> 线段树的指针
    index --> 线段树当前节点的下标。 初始传入根节点的下标为0
             根节点的下标值不会变更
    ss & se  --> 线段树当前节点表示的原数组起止下标
                 亦即,st[index]的起止下标
    qs & qe  --> 查询区间的起止下标 */
int getSumUtil(int *st, int ss, int se, int qs, int qe, int index)
{
    // 如果当前节点存储的线段是区间的一部分,
    // 返回当前线段的和
    if (qs <= ss && qe >= se)
        return st[index];
 
    // 如果节点存储的线段不在给定区间之内
    if (se < qs || ss > qe)
        return 0;
 
    // 如果节点的线段与区间的一部分有交集
    int mid = getMid(ss, se);
    return getSumUtil(st, ss, mid, qs, qe, 2*index+1) +
           getSumUtil(st, mid+1, se, qs, qe, 2*index+2);
}
 
/* 更新下标位于给定区间内节点值的递归函数,
    下面是参数列表
    st, index, ss and se 与getSumUtil() 一致
    i    --> 待更新元素的下标,指的是输入数组的下标。
   diff --> 区间需要增加的值 */
void updateValueUtil(int *st, int ss, int se, int i, int diff, int index)
{
    // Base Case: 如果输入下标在线段树范围之外
    if (i < ss || i > se)
        return;
 
    // 如果输入下标在节点范围之内,
    // 则更新节点及其孩子的值
    st[index] = st[index] + diff;
    if (se != ss)
    {
        int mid = getMid(ss, se);
        updateValueUtil(st, ss, mid, i, diff, 2*index + 1);
        updateValueUtil(st, mid+1, se, i, diff, 2*index + 2);
    }
}
 
// 更新输入数组与线段树值的函数。
// 使用了函数 updateValueUtil() 来更新线段树的值
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
    // 检查错误的输入下标
    if (i < 0 || i > n-1)
    {
        printf("Invalid Input");
        return;
    }
 
    // 计算新值与老值之间的差值
    int diff = new_val - arr[i];
 
    // 更新数组的值
    arr[i] = new_val;
 
    // 更新线段树节点的值
    updateValueUtil(st, 0, n-1, i, diff, 0);
}
 
// 返回下标qs(查询起点)到qe(查询终点)的元素之和。
// 主要使用了函数getSumUtil()
int getSum(int *st, int n, int qs, int qe)
{
    // 检查错误的输入
    if (qs < 0 || qe > n-1 || qs > qe)
    {
        printf("Invalid Input");
        return -1;
    }
 
    return getSumUtil(st, 0, n-1, qs, qe, 0);
}
 
// 递归函数,为数组[ss..se]构建线段树
// si 是线段树st内当前节点的下标
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
    // 如果数组只包含一个元素
    // 将其存储与线段树的当前节点并返回
    if (ss == se)
    {
        st[si] = arr[ss];
        return arr[ss];
    }
 
    // 如果有不止一个元素,
    // 则递归计算左右子树,并将两者之和存储与节点内,并返回
    int mid = getMid(ss, se);
    st[si] =  constructSTUtil(arr, ss, mid, st, si*2+1) +
              constructSTUtil(arr, mid+1, se, st, si*2+2);
    return st[si];
}
 
/* 从给定数组构建线段树的函数。
   函数为线段树分配内存空间,并调用函数constructSTUtil()
   来填充分配的内存 */
int *constructST(int arr[], int n)
{
    // 为线段树分配内存空间
    int x = (int)(ceil(log2(n))); //线段树的高度
    int max_size = 2*(int)pow(2, x) - 1; //线段树的最大容量
    int *st = new int[max_size];
 
    // 填充线段树st
    constructSTUtil(arr, 0, n-1, st, 0);
 
    // 返回构建的线段树
    return st;
}
 
// 上述函数的测试程序
int main()
{
    int arr[] = {1, 3, 5, 7, 9, 11};
    int n = sizeof(arr)/sizeof(arr[0]);
 
    // 从给定数组构建线段树
    int *st = constructST(arr, n);
 
    // 输出下标1 到 3的元素之和
    printf("Sum of values in given range = %d\n", getSum(st, n, 1, 3));
 
    // 更新: 令 arr[1] = 10
    //  并更新相应的线段树节点
    updateValue(arr, st, n, 1, 10);
 
    // 输出更新后的和值
    printf("Updated sum of values in given range = %d\n",
                                                  getSum(st, n, 1, 3));
 
    return 0;
}

程序输出:

Sum of values in given range = 15
Updated sum of values in given range = 22

时间复杂度:

线段树构建的时间复杂度为O(n)。总计有2n-1个节点,每一个节点在树构建过程中只被运算一次。

查询的时间复杂度为O(log n)。要查询区间和,我们在每一层至多处理4个节点,并且层的总数为O(log n)。

更新的时间复杂度也是O(log n)。要更新一个叶子节点,我们每一层处理一个节点,并且层的总数为O(log n)。

原文链接:http://www.geeksforgeeks.org/segment-tree-set-1-sum-of-given-range/

本文链接:http://bookshadow.com/weblog/2015/08/13/segment-tree-set-1-sum-of-given-range/

 

较为清晰的代码:

/**
 * Definition of Interval:
 * public classs Interval {
 *     int start, end;
 *     Interval(int start, int end) {
 *         this.start = start;
 *         this.end = end;
 *     }
 */


public class Solution {
    /*
     * @param A: An integer list
     * @param queries: An query list
     * @return: The result list
     */
    class SegmentTreeNode {
        public int start;
        public int end;
        public long sum;
        SegmentTreeNode left;
        SegmentTreeNode right;
        public SegmentTreeNode(int start, int end, long sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }

    public SegmentTreeNode build(int start, int end, int[] A) {
        if (start > end) {
            return null;
        }
    
        if (start == end) {
            return new SegmentTreeNode(start, end, A[start]);
        }
    
        SegmentTreeNode root = new SegmentTreeNode(start, end, 0);
        int mid = start + (end - start) / 2;
        root.left = build(start, mid, A);
        root.right = build(mid + 1, end, A);
        if (root.left != null) {
            root.sum += root.left.sum;
        }
        if (root.right != null) {
            root.sum += root.right.sum;
        }
        return root;
    }

    public long query(SegmentTreeNode root, int start, int end) {
        if (start <= root.start && end >= root.end) {
            return root.sum;
        }
    
        int mid = root.start + (root.end - root.start) / 2;
        long ans = 0;
        if (start <= mid) {
            ans += query(root.left, start, end);
        } 
        if (end > mid) {
            ans += query(root.right, start, end);
        }
    
        return ans;
    }
 
    SegmentTreeNode root;
    public List<Long> intervalSum(int[] A, List<Interval> queries) {
        root = build(0, A.length - 1, A);
        List<Long> list = new ArrayList<>();
        
        for (Interval num : queries) {
            long res = query(root, num.start, num.end);
            list.add(res);
        }
        
        return list;
    }
}

 

posted @ 2020-07-03 11:36  鸭子船长  阅读(322)  评论(0编辑  收藏  举报