线段树的数组实现
参考:https://blog.csdn.net/Yaokai_AssultMaster/article/details/79599809
https://zhuanlan.zhihu.com/p/34150142
http://www.cnblogs.com/TenosDoIt/p/3453089.html
看了google搜索出来的前三个讲解,两个是从上而下,用类实现的;另一个讲的从下往上,但是却没讲延迟更新
自己总结一下,用数组实现的带延迟更新的线段树
线段树
是一种树,负责[1,N]
区间上的区间查询、更新
- 根结点负责
[1,N]
的查询值 - 每个叶结点负责每个整数点的查询值
- 每个非叶结点负责它的子树下的叶结点所代表的区间的查询值
举个例子:
线段树适合 【解决相邻的区间的信息可以被合并成两个区间的并区间】 的信息的问题。
也可以用类实现,记录一下用数组实现的方法。
创建
对含有N
个数字的数组arr
建立线段树。
可以建成一个有2N-1
个结点的完全二叉树。
用数组ST[:2N]
来记录就可以了,把0的位置空出来能减少运算次数。
这样,ST[1:N]
都是非叶结点,ST[N:2N]
都是叶结点。
当N
也不是2的幂的时候也可以,比如下图:
构建线段数的时候要从叶结点向上更新:
def construct(arr):
n = length(arr)
for i in xrange(n,2n):
ST[i] = arr[i - n]
for i in xrange(n-1,0,-1):
ST[i] = query_func(ST[i*2],ST[i*2+1])
其中query_func(a,b)
是需要查询的功能,a
和b
都是结点的查询值。
比如要查询最小值,query_func
就是min
查询
查询的时候,可以从下向上查询。
def query(l,r):
l += n # 到达相应的叶结点
r += n # 到达相应的叶结点
res = 绝对不会被选上的值
while l<=r:
# 如果查询的左边界是它父亲的左孩子,那么说明它父亲的左右孩子都在查询区间内,这样只要看父亲的值就行了。
if l%2!=0: # 否则,如果该点是它父亲的右孩子,说明左孩子不在查询区间内,那么这个点的值要考虑
res = query_func(res, ST[l])
l += 1 # +1之后是个左孩子并且在查询区间内,所以看它父亲的值就行了
# 右边同理
if r%2==0:
res = query_func(res, ST[r])
r -= 1
# 然后往上走一层
l /= 2
r /= 2
return res
更新
单点更新
和堆很像,更新对应叶结点的值,然后向上更新就行,是的
def update_one(x,v):
x += n
ST[x] = v
while x>1: # 如果有父结点
x = x/2
ST[x] = func(ST[x*2],ST[x*2+1])
区间更新
这要涉及延迟更新了。如果非叶结点x
的所代表的叶结点都是同一个值的话,那么该结点下面的结点可以先不用更新,用lazy[x]
来记录这个更新值。
等到需要查询该结点的子结点的时候再更新。
def update(self, l, r, v):
l += n
r += n
l0,r0 = l,r
while l <= r:
if l & 1: # 如果是右孩子,那么要更新。否则如果是左孩子,那么让父亲结点延迟更新就好
_apply(l, v)
l += 1
if r & 1 == 0:
_apply(r, v)
r -= 1
l /= 2; r /= 2
# 上面的循环更新的结点都是在更新区间`[l0,r0]`内的
self._pull(l0)
self._pull(r0)
其中_apply(x,v)
表示,要把x所代表的区间更新为v
def _apply(self, x, v):
ST[x] = update_func(ST[x], val) # 这个结点要更新值
if x < n: # 如果它是个非叶结点,那么要记录延迟更新值
lazy[x] = update_func(lazy[x], val)
举个例子,如该图所示,假如我们要更新区间[3,6]
上面的while循环会把需要更新的最低节点都更新了,即标出来的橙色节点,灰色节点就是被延迟的节点。
但是这次更新还会涉及到所有橙色节点的父节点们,最后两行_pull
就是用来做这个事情的。
_pull
的代码为:
def _pull(self, x):
while x > 1: # 假如还有父节点
x /= 2
ST[x] = query_func(ST[x*2], ST[x*2 + 1]) # 重新计算父节点的准确值
ST[x] = update_func(ST[x], lazy[x]) # 如果有延迟更新值也要计算进去,因为该节点的儿子节点有可能用的是延迟更新前的值,为了得到父节点的准确值,需要把延迟更新值找过去
而由于延迟更新的存在,查询过程也将改变
延迟更新下的查询过程
查询过程多了两行,要先处理延迟更新。
def query(l,r):
l += n # 到达相应的叶结点
r += n # 到达相应的叶结点
res = 绝对不会被选上的值
_push(l)
_pull(r)
while l<=r:
# 如果查询的左边界是它父亲的左孩子,那么说明它父亲的左右孩子都在查询区间内,这样只要看父亲的值就行了。
if l%2!=0: # 否则,如果该点是它父亲的右孩子,说明左孩子不在查询区间内,那么这个点的值要考虑
res = query_func(res, ST[l])
l += 1 # +1之后是个左孩子并且在查询区间内,所以看它父亲的值就行了
# 右边同理
if r%2==0:
res = query_func(res, ST[r])
r -= 1
# 然后往上走一层
l /= 2
r /= 2
return res
其中_push
就是相应地处理延迟更新的过程
def _push(self, x):
for h in xrange(self.H, 0, -1): # 从根节点开始,一个个遍历该节点的祖先
y = x >> h
if lazy[y]: # 如果存在有待更新的,就往下推一层
_apply(y * 2, self.lazy[y])
_apply(y * 2+ 1, self.lazy[y])
lazy[y] = 0
最后完整的线段树代码:
class ST(object):
def __init__(self,n,uf,qf):
self.n = n
self.uf = uf
self.qf = qf
self.t = [0]*(n<<1)
self.lazy = [0]*n
def _apply(self,x,v):
self.t[x] = self.uf(self.t[x],v)
if x<self.n:
self.lazy[x] = self.uf(self.lazy[x],v)
def _pull(self,x):
while x>1:
x /= 2
self.t[x] = self.qf(self.t[x<<1],self.t[(x<<1)+1])
self.t[x] = self.uf(self.t[x],self.lazy[x])
def update(self,l,r,v):
l += self.n
r += self.n
l0,r0 = l,r
while l<=r:
if l&1:
self._apply(l,v)
l += 1
if not r&1:
self._apply(r,v)
r -= 1
l,r = l>>1,r>>1
self._pull(l0)
self._pull(r0)
def _push(self,x):
H = self.n.bit_length()
for h in xrange(H,0,-1):
y = x>>h
if self.lazy[y]:
self._apply(y<<1,self.lazy[y])
self._apply((y<<1)+1,self.lazy[y])
self.lazy[y] = 0
def query(self,l,r):
l += self.n
r += self.n
self._push(l)
self._push(r)
res = 0
while l<=r:
if l&1:
res = self.qf(res,self.t[l])
l += 1
if not r&1:
res = self.qf(res,self.t[r])
r -= 1
l,r = l>>1,r>>1
return res