线段树的数组实现

参考:https://blog.csdn.net/Yaokai_AssultMaster/article/details/79599809
https://zhuanlan.zhihu.com/p/34150142
http://www.cnblogs.com/TenosDoIt/p/3453089.html

看了google搜索出来的前三个讲解,两个是从上而下,用类实现的;另一个讲的从下往上,但是却没讲延迟更新
自己总结一下,用数组实现的带延迟更新的线段树

线段树

是一种树,负责[1,N]区间上的区间查询、更新

  1. 根结点负责[1,N]的查询值
  2. 每个叶结点负责每个整数点的查询值
  3. 每个非叶结点负责它的子树下的叶结点所代表的区间的查询值

举个例子:
1
线段树适合 【解决相邻的区间的信息可以被合并成两个区间的并区间】 的信息的问题。

也可以用类实现,记录一下用数组实现的方法。

创建

对含有N个数字的数组arr建立线段树。
可以建成一个有2N-1个结点的完全二叉树。
用数组ST[:2N]来记录就可以了,把0的位置空出来能减少运算次数。
这样,ST[1:N]都是非叶结点,ST[N:2N]都是叶结点。
N也不是2的幂的时候也可以,比如下图:
2

构建线段数的时候要从叶结点向上更新:

def construct(arr):
  n = length(arr)
  for i in xrange(n,2n):
    ST[i] = arr[i - n]
  for i in xrange(n-1,0,-1):
    ST[i] = query_func(ST[i*2],ST[i*2+1])

其中query_func(a,b)是需要查询的功能,ab都是结点的查询值。
比如要查询最小值,query_func就是min

查询

查询的时候,可以从下向上查询。

def query(l,r):
  l += n # 到达相应的叶结点
  r += n # 到达相应的叶结点
  res = 绝对不会被选上的值
  while l<=r:
    # 如果查询的左边界是它父亲的左孩子,那么说明它父亲的左右孩子都在查询区间内,这样只要看父亲的值就行了。
    if l%2!=0:  # 否则,如果该点是它父亲的右孩子,说明左孩子不在查询区间内,那么这个点的值要考虑
      res = query_func(res, ST[l])
      l += 1  # +1之后是个左孩子并且在查询区间内,所以看它父亲的值就行了
    # 右边同理
    if r%2==0:
      res = query_func(res, ST[r])
      r -= 1
    # 然后往上走一层
    l /= 2
    r /= 2
  return res

更新

单点更新

和堆很像,更新对应叶结点的值,然后向上更新就行,是O(logn)O(logn)

def update_one(x,v):
  x += n
  ST[x] = v
  while x>1:  # 如果有父结点
    x = x/2
    ST[x] = func(ST[x*2],ST[x*2+1])

区间更新

这要涉及延迟更新了。如果非叶结点x的所代表的叶结点都是同一个值的话,那么该结点下面的结点可以先不用更新,用lazy[x]来记录这个更新值。
等到需要查询该结点的子结点的时候再更新。

def update(self, l, r, v):
  l += n
  r += n
  l0,r0 = l,r
  while l <= r:
      if l & 1: # 如果是右孩子,那么要更新。否则如果是左孩子,那么让父亲结点延迟更新就好
          _apply(l, v)
          l += 1
      if r & 1 == 0:
          _apply(r, v)
          r -= 1
      l /= 2; r /= 2
  # 上面的循环更新的结点都是在更新区间`[l0,r0]`内的
  self._pull(l0)  
  self._pull(r0)

其中_apply(x,v)表示,要把x所代表的区间更新为v

def _apply(self, x, v):
  ST[x] = update_func(ST[x], val) # 这个结点要更新值
  if x < n: # 如果它是个非叶结点,那么要记录延迟更新值
      lazy[x] = update_func(lazy[x], val)

举个例子,如该图所示,假如我们要更新区间[3,6]
3
上面的while循环会把需要更新的最低节点都更新了,即标出来的橙色节点,灰色节点就是被延迟的节点。
但是这次更新还会涉及到所有橙色节点的父节点们,最后两行_pull就是用来做这个事情的。
_pull的代码为:

def _pull(self, x):
  while x > 1:  # 假如还有父节点
      x /= 2
      ST[x] = query_func(ST[x*2], ST[x*2 + 1])  # 重新计算父节点的准确值
      ST[x] = update_func(ST[x], lazy[x]) # 如果有延迟更新值也要计算进去,因为该节点的儿子节点有可能用的是延迟更新前的值,为了得到父节点的准确值,需要把延迟更新值找过去

而由于延迟更新的存在,查询过程也将改变

延迟更新下的查询过程

查询过程多了两行,要先处理延迟更新。

def query(l,r):
  l += n # 到达相应的叶结点
  r += n # 到达相应的叶结点
  res = 绝对不会被选上的值
  _push(l)
  _pull(r)
  while l<=r:
    # 如果查询的左边界是它父亲的左孩子,那么说明它父亲的左右孩子都在查询区间内,这样只要看父亲的值就行了。
    if l%2!=0:  # 否则,如果该点是它父亲的右孩子,说明左孩子不在查询区间内,那么这个点的值要考虑
      res = query_func(res, ST[l])
      l += 1  # +1之后是个左孩子并且在查询区间内,所以看它父亲的值就行了
    # 右边同理
    if r%2==0:
      res = query_func(res, ST[r])
      r -= 1
    # 然后往上走一层
    l /= 2
    r /= 2
  return res

其中_push就是相应地处理延迟更新的过程

def _push(self, x):
  for h in xrange(self.H, 0, -1): # 从根节点开始,一个个遍历该节点的祖先
      y = x >> h
      if lazy[y]:  # 如果存在有待更新的,就往下推一层
          _apply(y * 2, self.lazy[y])
          _apply(y * 2+ 1, self.lazy[y])
          lazy[y] = 0

最后完整的线段树代码:

class ST(object):
    def __init__(self,n,uf,qf):
        self.n = n
        self.uf = uf
        self.qf = qf
        self.t = [0]*(n<<1)
        self.lazy = [0]*n
    
    def _apply(self,x,v):
        self.t[x] = self.uf(self.t[x],v)
        if x<self.n:
            self.lazy[x] = self.uf(self.lazy[x],v)
    
    def _pull(self,x):
        while x>1:
            x /= 2
            self.t[x] = self.qf(self.t[x<<1],self.t[(x<<1)+1])
            self.t[x] = self.uf(self.t[x],self.lazy[x])
    
    def update(self,l,r,v):
        l += self.n
        r += self.n
        l0,r0 = l,r
        while l<=r:
            if l&1:
                self._apply(l,v)
                l += 1
            if not r&1:
                self._apply(r,v)
                r -= 1
            l,r = l>>1,r>>1
        self._pull(l0)
        self._pull(r0)
    
    def _push(self,x):
        H = self.n.bit_length()
        for h in xrange(H,0,-1):
            y = x>>h
            if self.lazy[y]:
                self._apply(y<<1,self.lazy[y])
                self._apply((y<<1)+1,self.lazy[y])
                self.lazy[y] = 0
    
    def query(self,l,r):
        l += self.n
        r += self.n
        self._push(l)
        self._push(r)
        res = 0
        while l<=r:
            if l&1:
                res = self.qf(res,self.t[l])
                l += 1
            if not r&1:
                res = self.qf(res,self.t[r])
                r -= 1
            l,r = l>>1,r>>1
        return res

例题

Leetcode 699. Falling Squares解题报告

posted @ 2019-04-06 00:50  milliele  阅读(139)  评论(0编辑  收藏  举报