LeetCode 307: Range Sum Query - Mutable / 线段树模板 (segment tree template)
Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.
The update(i, val) function modifies nums by updating the element at index i to val.
Example:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
Note:
- The array is only modifiable by the update function.
- You may assume the number of calls to the update and sumRange functions is distributed evenly.
/**
 * LeetCode 307 (Range Sum Query - Mutable), solved with a pointer-based segment tree.
 *
 * <p>Construction visits every element once; {@link #update} and {@link #sumRange}
 * each walk a single root-to-leaf path (or two, for a split query).
 */
public class NumArray {
    private SegmentTreeNode root = null;
    private int size = 0;

    /**
     * Builds the segment tree over {@code nums}.
     *
     * @param nums initial values; an empty array yields a structure with no root
     */
    public NumArray(int[] nums) {
        root = buildInSegmentTree(nums, 0, nums.length - 1);
        size = nums.length;
    }

    /**
     * Sets element {@code i} to {@code val}. Indices outside {@code [0, size)} are ignored.
     *
     * @param i   index to overwrite
     * @param val new value
     */
    public void update(int i, int val) {
        if (i < 0 || i >= size) {
            return;
        }
        updateInSegmentTree(root, i, val);
    }

    /**
     * Returns the sum of the elements at indices {@code i..j} (inclusive).
     *
     * @param i first index
     * @param j last index
     * @return the range sum, or {@code -1} when {@code i > j} or either index is out of
     *     bounds. Note that -1 can also be a legitimate sum; the sentinel is kept for
     *     compatibility with existing callers.
     */
    public int sumRange(int i, int j) {
        if (i > j || i < 0 || j >= size) {
            return -1;
        }
        return querySum(root, i, j);
    }

    /**
     * Tree node covering the inclusive index range {@code [lc, rc]} with the sum of that range.
     * Declared static because it never needs the enclosing {@code NumArray} instance.
     */
    static class SegmentTreeNode {
        int lc = 0, rc = 0, sum = 0;
        SegmentTreeNode left = null, right = null;

        SegmentTreeNode(int l, int r, int val) {
            lc = l;
            rc = r;
            sum = val;
        }
    }

    /**
     * Recursively builds the subtree covering {@code nums[l..r]}.
     *
     * @return the subtree root, or {@code null} when {@code l > r} (empty range)
     */
    public SegmentTreeNode buildInSegmentTree(int[] nums, int l, int r) {
        if (l > r) {
            return null;
        }
        if (l == r) {
            return new SegmentTreeNode(l, r, nums[l]);
        }
        SegmentTreeNode node = new SegmentTreeNode(l, r, 0);
        // Overflow-safe midpoint; equals (l + r) / 2 for non-negative indices.
        int mid = l + (r - l) / 2;
        node.left = buildInSegmentTree(nums, l, mid);
        node.right = buildInSegmentTree(nums, mid + 1, r);
        node.sum = node.left.sum + node.right.sum;
        return node;
    }

    /**
     * Sets the leaf for index {@code i} to {@code val} and refreshes the sums on the path above it.
     * Does nothing if {@code node} is null or {@code i} lies outside {@code node}'s range.
     */
    public void updateInSegmentTree(SegmentTreeNode node, int i, int val) {
        if (node == null || i < node.lc || i > node.rc) {
            return;
        }
        if (node.lc == node.rc) {
            node.sum = val;
            return;
        }
        int mid = node.lc + (node.rc - node.lc) / 2;
        if (i <= mid) {
            updateInSegmentTree(node.left, i, val);
        } else {
            updateInSegmentTree(node.right, i, val);
        }
        node.sum = node.left.sum + node.right.sum;
    }

    /**
     * Returns the sum over {@code [i, j]}, which must lie within {@code node}'s range.
     */
    public int querySum(SegmentTreeNode node, int i, int j) {
        if (node.lc == i && node.rc == j) {
            return node.sum;
        }
        int mid = node.lc + (node.rc - node.lc) / 2;
        if (j <= mid) {
            return querySum(node.left, i, j);
        }
        if (i > mid) {
            return querySum(node.right, i, j);
        }
        // The query straddles the midpoint: split it between the two children.
        return querySum(node.left, i, mid) + querySum(node.right, mid + 1, j);
    }
}

// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);
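下面附上 Java 版 SegmentTree 模板代码（共有两个文件：SegmentTreeNode.java 和 SegmentTree.java）。 (Below is a Java segment-tree template, split into two files: SegmentTreeNode.java and SegmentTree.java.)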
package cc150; public class SegmentTreeNode { public int lc, rc, sum, add; SegmentTreeNode left, right; public SegmentTreeNode() { this.lc = 0; this.rc = 0; this.sum = 0; this.add = 0; this.left = null; this.right = null; } public SegmentTreeNode(int l, int r, int val) { this.lc = l; this.rc = r; this.sum = val; this.add = 0; this.left = null; this.right = null; } public static void main(String[] args) { // TODO Auto-generated method stub } }
package cc150; public class SegmentTree { public SegmentTreeNode root = null; int lower_bound, upper_bound; public SegmentTree() { this.root = null; this.lower_bound = 0; this.upper_bound = 0; } public SegmentTree(int l, int r, int []nums) { //@ SegmentTreeNode(left_idx, right_idx, sum). this.root = new SegmentTreeNode(l, r, 0); this.lower_bound = l; this.upper_bound = r; buildSegmentTree(l, r, nums, root); } public void buildSegmentTree(int l, int r, int []nums, SegmentTreeNode s) { SegmentTreeNode sroot = s; if(l > r) return; if(l == r) { sroot.sum = nums[l]; return; } int mid = (l + r) / 2; sroot.left = new SegmentTreeNode(l, mid, 0); buildSegmentTree(l, mid, nums, sroot.left); sroot.right = new SegmentTreeNode(mid+1, r, 0); buildSegmentTree(mid+1, r, nums, sroot.right); sroot.sum = sroot.left.sum + sroot.right.sum; } public void updateByPoint(SegmentTreeNode sroot, int idx, int val) { if(idx == sroot.lc && sroot.lc == sroot.rc) { sroot.sum = val; return; } int mid = (sroot.lc + sroot.rc) / 2; if(idx <= mid) updateByPoint(sroot.left, idx, val); else updateByPoint(sroot.right, idx, val); sroot.sum = sroot.left.sum + sroot.right.sum; } public void updateBySegment(SegmentTreeNode sroot, int l, int r, int val) { if(l == sroot.lc && r == sroot.rc) { sroot.add += val; sroot.sum += val * (r - l + 1); return; } if(sroot.lc == sroot.rc) return; int len = sroot.rc - sroot.lc + 1; if(sroot.add > 0) { sroot.left.add += sroot.add; sroot.right.add += sroot.add; sroot.left.sum += sroot.add * (len - (len/2)); sroot.right.sum += sroot.add * (len/2); sroot.add = 0; } int mid = sroot.lc + (sroot.rc - sroot.lc)/2; if(r <= mid) updateBySegment(sroot.left, l, r, val); else if(l > mid) updateBySegment(sroot.right, l, r, val); else { updateBySegment(sroot.left, l, mid, val); updateBySegment(sroot.right, mid+1, r, val); } sroot.sum = sroot.left.sum + sroot.right.sum; } static int querySum(SegmentTreeNode sroot, int i, int j) { if(i > j) { System.out.println("Invalid Query!"); return -1; 
} if(i<sroot.lc || j>sroot.rc) return querySum(sroot, sroot.lc, sroot.rc); if(sroot.lc == i && sroot.rc == j) return sroot.sum;/* int len = sroot.rc - sroot.lc + 1; if(sroot.add > 0) { sroot.left.add += sroot.add; sroot.right.add += sroot.add; sroot.left.sum += sroot.add * (len - len/2); sroot.right.sum += sroot.add * (len/2); sroot.add = 0; } */ int mid = (sroot.lc + sroot.rc) / 2; if(j <= mid) return querySum(sroot.left, i, j); else if(i > mid) return querySum(sroot.right, i, j); else return querySum(sroot.left, i, mid) + querySum(sroot.right, mid+1, j); } public static void main(String[] args) { // TODO Auto-generated method stub int []nums = new int[10]; for(int i=0;i<nums.length;++i) nums[i] = i; SegmentTree st = new SegmentTree(0, nums.length-1, nums); int tmp = querySum(st.root, 0, 9); System.out.println(tmp); st.updateByPoint(st.root, 5, 7); System.out.println(querySum(st.root, 0, 9)); st.updateBySegment(st.root, 3, 4, 2); System.out.println(querySum(st.root, 2, 7)); } }