线段树实现。很多细节值得品味 都在注释里面了

class SegTree:
    def __init__(self,N,query_fn,update_fn):
        self.tree=[0]*(2*N+2) # 最后一个节点有可能无用 ,但是确保树是完全的
        self.lazy=[0]*(N+1)
        self.N=N
    
        self.h=0
        while ((1<<self.h) < N ):
            self.h=self.h+1 
        
        self.query_fn=query_fn
        self.update_fn=update_fn 
        
        
    #这里实现只有单个节点的更改,对于整个节点来说以上的节点没有杯更新到 所以要使用 pull 做进一步的更改
    def _apply(self,x,val):        
        self.tree[x] = self.update_fn(val,self.tree[x])
        
        if(x<self.N):
            self.lazy[x]=self.update_fn(self.lazy[x] , self.tree[x])
            
    '''
    pull 
    从叶部到根部 ,用于更新后的维护
    
    push 
    从根部到叶部 ,查询前的维护 
    
    '''
    
            
    def pull(self,x):
     
        while x>1 :  # 先while 再除 保证了 (x==1 仍然会执行&&第一次执行的时候不是叶子节点)
            x = x // 2
            self.tree[x] = self.query_fn(self.tree[x*2+1],self.tree[x*2 ] ) #针对区间之间的信息 是 query_fn 
            
            
            self.tree[x] = self.update_fn(self.tree[x],self.lazy[x])
            
    def push(self,x):
        for i in range(self.h,0,-1):
            y = x >> i
            if(self.lazy[ y ]):                  #巧妙的使用了 lazy 前空出无用的 lazy[0]
                self._apply(y*2, self.lazy[y])   #_apply 代码的复用 (修改值 并且置 lazy )
                self._apply(y*2+1, self.lazy[y])
                self.lazy[y]=0
                
        
        
        
    
    def update(self,left,right,val):    
        #print("update "+str(left )+" "+str(right )+" "+str(val))
        left =left +self.N
        right=right+self.N
        
        init_left = left 
        init_right = right 
        while(left<=right): #追溯最邻近父节点的写法 
            
            if( left & 1 ):
                self._apply(left , val )
                left = left +1
            
            if((right & 1) ==0):
                self._apply(right, val)
                right = right -1 
                
            left = left // 2 
            right = right // 2  
            
        #以上只更新到左右节点的最邻近父节点 所以需要进行pull 更新上面的所有节点
        self.pull( init_left )
        self.pull( init_right )
        
    
        
    def query(self,left,right):
        #print("query "+str(left )+" "+str(right ))
        left = left +self.N
        right = right + self.N
        self.push( left )
        self.push( right )
        
        ans = 0 
        while ( left <= right ):#追溯最邻近父节点
            #print("left "+str(left)+" right"+str(right))
            if(left & 1):
            
                ans= self.query_fn(ans , self.tree[left] )
                left = left + 1 #这个+1 是为了对称 没有实际意义?
            
            if((right & 1) ==0 ):
                #print(right)
                ans = self.query_fn(ans , self.tree[right ])
                right = right -1                 
              
            left = left   // 2 
            right = right // 2 
        return ans 

class Solution:
    def fallingSquares(self, positions):
        """
        :type positions: List[List[int]]
        :rtype: List[int]
        """
        position = [i[0] for i in positions]
        position . extend([i[0]+i[1]-1 for i in positions])
        
        position=sorted(set(position))

        pos2idx ={pos:i+1 for i,pos in enumerate(position)}
        
        N=len(pos2idx)
        
        #print("N is "+str(N))
        segtree=SegTree(N,max,max)
        
        best = -1 
        ans_list = []
        #print(pos2idx)
        for block in positions:
            l,r=pos2idx[block[0]],pos2idx[block[0]+block[1]-1]
            
            height = segtree.query(l,r) + block[1]
            #print("query height"+str(height-block[1])+"   blcok size:"+str(block[1]))
            
            
            segtree.update(l,r,height)
            
            best=max(best,height)
            
            ans_list.append(best)
        
        return ans_list