线段树的 Python 实现

 

 1 class TreeNode:
 2     def __init__(self, left, right, mx):
 3         self.left = left
 4         self.right = right
 5         self.mx = mx
 6 
 7 
 8 # 线段树类
 9 # 以_开头的是递归实现
10 class Tree(object):
11     def __init__(self, n, arr):
12         self.n = n
13         self.max_size = 4 * n
14         self.tree = [TreeNode() for _ in range(self.max_size)]  # 维护一个TreeNode数组
15         self.arr = arr
16 
17     # index从1开始
18     def _build(self, index, left, right):
19         self.tree[index].left = left
20         self.tree[index].right = right
21         if left == right:
22             self.tree[index].mx = left
23         else:
24             mid = (left + right) // 2
25             self._build(index * 2, left, mid)
26             self._build(index * 2, mid + 1, right)
27             self.tree[index].mx = max(self.tree[index * 2].mx, self.tree[index * 2 + 1].mx)
28 
29     # 构建线段树
30     def build(self):
31         self._build(1, 1, self.n)
32 
33     def _update(self, ind, k, v):  # 点更新,将arr[k]的值改成v
34         if self.tree[ind].left == self.tree[ind].right and self.tree[ind].left == k:
35             self.tree[ind].mx = v
36             return
37         mid = (self.tree[ind].left + self.tree[ind].right) // 2
38         if k <= mid:
39             self._update(ind * 2, k, v)
40         else:
41             self._update(ind * 2 + 1, k, v)
42 
43         # 回归时更新
44         self.tree[ind].mx = max(self.tree[ind * 2].mx, self.tree[ind * 2 + 1].mx)
45 
46     # 区间覆盖
47     def _query(self, ind, l, r):
48         if self.tree[ind].left >= l and self.tree[ind].right <= r:
49             return self.tree[ind].mx
50         mid = (self.tree[ind].left + self.tree[ind].right) // 2
51         res = float("-inf")  # 局部变量
52         if l <= mid:
53             res = max(res, self._query(ind * 2, l, r))
54         if r > mid:
55             res = max(res, self._query(ind * 2 + 1, l, r))
56         return res
57 
58     # 区间相等
59     def _query2(self, ind, l, r):
60         if self.tree[ind].left == l and self.tree[ind].right == r:
61             return self.tree[ind].mx
62         mid = (self.tree[ind].left + self.tree[ind].right) // 2
63         if r < mid:
64             return self._query2(ind * 2, l, r)
65         elif l > mid:
66             return self._query2(ind * 2 + 1, l, r)
67         else:
68             return max(self._query2(ind * 2, l, mid), self._query2(ind * 2 + 1, mid + 1, r))
69 
70     def query(self, ql, qr):
71         return self._query(1, ql, qr)
72 
73     # 深度遍历打印数组
74     def _show_arr(self, i):
75         if self.tree[i].left == self.tree[i].right and self.tree[i].left != -1:
76             print(self.tree[i].mx, end=" ")
77         if 2 * i < len(self.tree):
78             self._show_arr(i * 2)
79             self._show_arr(i * 2 + 1)
80 
81     # 显示更新后的数组的样子
82     def show_arr(self, ):
83             self._show_arr(1)

 

posted @ 2022-04-11 23:02  r1-12king  阅读(342)  评论(0)  编辑  收藏  举报