线段树实现。很多细节值得品味 都在注释里面了
class SegTree:
def __init__(self,N,query_fn,update_fn):
self.tree=[0]*(2*N+2) # 最后一个节点有可能无用 ,但是确保树是完全的
self.lazy=[0]*(N+1)
self.N=N
self.h=0
while ((1<<self.h) < N ):
self.h=self.h+1
self.query_fn=query_fn
self.update_fn=update_fn
#这里实现只有单个节点的更改,对于整个节点来说以上的节点没有杯更新到 所以要使用 pull 做进一步的更改
def _apply(self,x,val):
self.tree[x] = self.update_fn(val,self.tree[x])
if(x<self.N):
self.lazy[x]=self.update_fn(self.lazy[x] , self.tree[x])
'''
pull
从叶部到根部 ,用于更新后的维护
push
从根部到叶部 ,查询前的维护
'''
def pull(self,x):
while x>1 : # 先while 再除 保证了 (x==1 仍然会执行&&第一次执行的时候不是叶子节点)
x = x // 2
self.tree[x] = self.query_fn(self.tree[x*2+1],self.tree[x*2 ] ) #针对区间之间的信息 是 query_fn
self.tree[x] = self.update_fn(self.tree[x],self.lazy[x])
def push(self,x):
for i in range(self.h,0,-1):
y = x >> i
if(self.lazy[ y ]): #巧妙的使用了 lazy 前空出无用的 lazy[0]
self._apply(y*2, self.lazy[y]) #_apply 代码的复用 (修改值 并且置 lazy )
self._apply(y*2+1, self.lazy[y])
self.lazy[y]=0
def update(self,left,right,val):
#print("update "+str(left )+" "+str(right )+" "+str(val))
left =left +self.N
right=right+self.N
init_left = left
init_right = right
while(left<=right): #追溯最邻近父节点的写法
if( left & 1 ):
self._apply(left , val )
left = left +1
if((right & 1) ==0):
self._apply(right, val)
right = right -1
left = left // 2
right = right // 2
#以上只更新到左右节点的最邻近父节点 所以需要进行pull 更新上面的所有节点
self.pull( init_left )
self.pull( init_right )
def query(self,left,right):
#print("query "+str(left )+" "+str(right ))
left = left +self.N
right = right + self.N
self.push( left )
self.push( right )
ans = 0
while ( left <= right ):#追溯最邻近父节点
#print("left "+str(left)+" right"+str(right))
if(left & 1):
ans= self.query_fn(ans , self.tree[left] )
left = left + 1 #这个+1 是为了对称 没有实际意义?
if((right & 1) ==0 ):
#print(right)
ans = self.query_fn(ans , self.tree[right ])
right = right -1
left = left // 2
right = right // 2
return ans
class Solution:
def fallingSquares(self, positions):
"""
:type positions: List[List[int]]
:rtype: List[int]
"""
position = [i[0] for i in positions]
position . extend([i[0]+i[1]-1 for i in positions])
position=sorted(set(position))
pos2idx ={pos:i+1 for i,pos in enumerate(position)}
N=len(pos2idx)
#print("N is "+str(N))
segtree=SegTree(N,max,max)
best = -1
ans_list = []
#print(pos2idx)
for block in positions:
l,r=pos2idx[block[0]],pos2idx[block[0]+block[1]-1]
height = segtree.query(l,r) + block[1]
#print("query height"+str(height-block[1])+" blcok size:"+str(block[1]))
segtree.update(l,r,height)
best=max(best,height)
ans_list.append(best)
return ans_list