# 线段树 python 实现 (segment tree — Python implementation)
class TreeNode:
    """One node of the segment tree.

    Attributes:
        left: first position (1-based) covered by this node; -1 while unused.
        right: last position (1-based) covered by this node; -1 while unused.
        mx: maximum of the values at positions left..right.
    """

    def __init__(self, left=-1, right=-1, mx=float("-inf")):
        # Defaults let Tree pre-allocate nodes with a bare TreeNode().
        # left/right == -1 marks a slot that build() never filled (the
        # printing walk relies on this), and -inf is the identity for max().
        self.left = left
        self.right = right
        self.mx = mx


# Segment tree class (range maximum, point update).
# Methods whose names start with "_" are the recursive implementations.
class Tree:
    """Range-maximum segment tree over positions 1..n.

    Positions are 1-based; position k holds arr[k - 1], i.e. *arr* is an
    ordinary 0-based sequence of (at least) n comparable values.
    NOTE(review): the original leaf code never read *arr*; 0-based *arr*
    is an assumption — confirm against callers.
    """

    def __init__(self, n, arr):
        self.n = n
        # 4 * n node slots always suffice for a recursively built tree.
        self.max_size = 4 * n
        # Node array; the root lives at index 1 (index 0 is unused).
        self.tree = [TreeNode() for _ in range(self.max_size)]
        self.arr = arr

    # Node indices start at 1: the children of node i are 2*i and 2*i + 1.
    def _build(self, index, left, right):
        """Make node *index* cover positions left..right and build its subtree."""
        node = self.tree[index]
        node.left = left
        node.right = right
        if left == right:
            # Leaf: store the element itself (position is 1-based, arr 0-based).
            node.mx = self.arr[left - 1]
            return
        mid = (left + right) // 2
        self._build(index * 2, left, mid)
        self._build(index * 2 + 1, mid + 1, right)
        node.mx = max(self.tree[index * 2].mx, self.tree[index * 2 + 1].mx)

    def build(self):
        """Build the segment tree from self.arr (no-op when n < 1)."""
        if self.n < 1:
            return
        self._build(1, 1, self.n)

    def _update(self, ind, k, v):
        """Point update: set position k to v inside the subtree rooted at *ind*.

        Precondition: node *ind* covers position k. Only the tree nodes are
        changed; self.arr is left untouched.
        """
        node = self.tree[ind]
        if node.left == node.right and node.left == k:
            node.mx = v
            return
        mid = (node.left + node.right) // 2
        if k <= mid:
            self._update(ind * 2, k, v)
        else:
            self._update(ind * 2 + 1, k, v)
        # Recompute this node's maximum on the way back up.
        node.mx = max(self.tree[ind * 2].mx, self.tree[ind * 2 + 1].mx)

    def update(self, k, v):
        """Set position k (1-based) to v.

        Raises:
            IndexError: if k is outside 1..n.
        """
        if not 1 <= k <= self.n:
            raise IndexError(f"position {k} out of range 1..{self.n}")
        self._update(1, k, v)

    # Query by coverage: return a node's value as soon as it lies inside [l, r].
    def _query(self, ind, l, r):
        """Maximum over the part of [l, r] covered by node *ind*.

        Precondition: [l, r] overlaps the node's range (query() guarantees
        this for the root; the recursion preserves it for the children).
        """
        node = self.tree[ind]
        if node.left >= l and node.right <= r:
            return node.mx
        mid = (node.left + node.right) // 2
        res = float("-inf")
        if l <= mid:
            res = max(res, self._query(ind * 2, l, r))
        if r > mid:
            res = max(res, self._query(ind * 2 + 1, l, r))
        return res

    # Query by exact match: split [l, r] until it equals a node's range.
    def _query2(self, ind, l, r):
        """Maximum over [l, r]; requires l <= r inside node *ind*'s range."""
        node = self.tree[ind]
        if node.left == l and node.right == r:
            return node.mx
        mid = (node.left + node.right) // 2
        # r == mid belongs entirely to the left child (it covers left..mid).
        if r <= mid:
            return self._query2(ind * 2, l, r)
        if l > mid:
            return self._query2(ind * 2 + 1, l, r)
        return max(
            self._query2(ind * 2, l, mid),
            self._query2(ind * 2 + 1, mid + 1, r),
        )

    def query(self, ql, qr):
        """Return the maximum over positions ql..qr (1-based, inclusive).

        The range is clamped to 1..n; -inf is returned when it is empty.
        """
        ql = max(ql, 1)
        qr = min(qr, self.n)
        if ql > qr:
            return float("-inf")
        return self._query(1, ql, qr)

    # Depth-first walk that prints the leaf values in position order.
    def _show_arr(self, i):
        if i >= len(self.tree):
            return
        node = self.tree[i]
        if node.left == -1:
            return  # slot never filled by build(); its subtree is empty too
        if node.left == node.right:
            print(node.mx, end=" ")
            return
        self._show_arr(i * 2)
        self._show_arr(i * 2 + 1)

    def show_arr(self):
        """Print the current leaf values (the array after updates), space-separated."""
        self._show_arr(1)