线段树(SegmentTree)基础模板
线段树模板题来源:https://www.lintcode.com/problem/segment-tree-build/description
201. 线段树的构造
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end) {
* this->start = start, this->end = end;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
    /**
     * Builds a segment tree over the closed interval [start, end].
     *
     * Every node covers [node->start, node->end]. An internal node is split
     * at mid into [start, mid] (left) and [mid + 1, end] (right); leaves
     * have start == end and no children.
     *
     * @param start: start value (inclusive).
     * @param end: end value (inclusive).
     * @return: The root of Segment Tree, or nullptr when start > end.
     */
    SegmentTreeNode * build(int start, int end) {
        if (start > end) return nullptr;
        SegmentTreeNode *root = new SegmentTreeNode(start, end);
        if (start < end) {
            // start + (end - start) / 2 always rounds toward start, so
            // [start, mid] is strictly smaller than [start, end].
            // (start + end) / 2 would truncate toward zero: for negative
            // bounds like [-3, -2] it gives mid == end and recurses forever,
            // and it can overflow for large bounds.
            int mid = start + (end - start) / 2;
            root->left = build(start, mid);
            root->right = build(mid + 1, end);
        }
        return root;
    }
};
202. 线段树的查询
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, max;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int max) {
* this->start = start;
* this->end = end;
* this->max = max;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
    /**
     * Returns the largest value stored at the positions in [start, end].
     *
     * Precondition (not checked): root is non-null and [start, end]
     * overlaps root's interval. Leaf children are NULL, so a range that is
     * disjoint from the tree ends up dereferencing a null pointer.
     *
     * @param root: The root of segment tree.
     * @param start: start value.
     * @param end: end value.
     * @return: The maximum number in the interval [start, end]
     */
    int query(SegmentTreeNode * root, int start, int end) {
        const int split = root->start + (root->end - root->start) / 2;

        // Node interval lies completely inside the query range.
        if (root->start >= start && end >= root->end) {
            return root->max;
        }

        // Query range begins past the split point: only the right half matters.
        if (split < start) {
            return query(root->right, start, end);
        }

        // Query range ends at or before the split point: only the left half matters.
        if (end <= split) {
            return query(root->left, start, end);
        }

        // Query straddles the split point: combine both halves.
        const int leftBest = query(root->left, start, split);
        const int rightBest = query(root->right, split + 1, end);
        return (leftBest < rightBest) ? rightBest : leftBest;
    }
};
203. 线段树的修改
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, max;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int max) {
* this->start = start;
* this->end = end;
* this->max = max;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
/**
* @param root: The root of segment tree.
* @param index: index.
* @param value: value
* @return: nothing
*/
void modify(SegmentTreeNode * root, int index, int value) {
// write your code here
if(root == nullptr || index > root->end || index < root->start) return;
if(root->start == root->end) {
root->max = value;
return ;
}
auto mid = root->start + (root->end - root->start) / 2;
if(index > mid){
modify(root->right, index, value);
} else {
modify(root->left, index, value);
}
root->max = max(root->right->max, root->left->max);
}
};
247. 线段树查询 II
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, count;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int count) {
* this->start = start;
* this->end = end;
* this->count = count;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
    /**
     * Sums the count fields of the nodes covering [start, end].
     *
     * A null subtree contributes 0, so a range reaching past the tree's
     * bounds is handled by falling off a leaf.
     *
     * @param root: The root of segment tree.
     * @param start: start value.
     * @param end: end value.
     * @return: The count number in the interval [start, end]
     */
    int query(SegmentTreeNode * root, int start, int end) {
        if (!root) {
            return 0;
        }

        const bool covered = root->start >= start && root->end <= end;
        if (covered) {
            return root->count;
        }

        const int half = root->start + (root->end - root->start) / 2;
        if (end <= half) {
            return query(root->left, start, end);
        }
        if (half < start) {
            return query(root->right, start, end);
        }

        const int fromLeft = query(root->left, start, half);
        const int fromRight = query(root->right, half + 1, end);
        return fromLeft + fromRight;
    }
};
248. 统计比给定整数小的数的个数
class Solution {
public:
    /// Largest value the counting tree indexes. Elements of A are expected
    /// to lie in [0, kMaxValue]; values outside that range are counted at
    /// the nearest boundary leaf by add().
    static const int kMaxValue = 10000;

    /// Segment-tree node that counts how many inserted values fall in
    /// [start, end].
    struct SegmentTreeNode {
        int start, end, count;
        SegmentTreeNode* left, *right;
        SegmentTreeNode(int start_, int end_) :
            start(start_), end(end_), count(0), left(nullptr), right(nullptr){}
    };

    /// Builds a tree over [start, end] with all counts zero.
    /// Returns nullptr when start > end. The caller owns the nodes and
    /// should free them with release().
    SegmentTreeNode* build(int start, int end){
        if (start > end) return nullptr;
        SegmentTreeNode* root = new SegmentTreeNode(start, end);
        if (start != end) {
            // Overflow-safe midpoint that also rounds toward start for
            // negative bounds, so the recursion always shrinks.
            int mid = start + (end - start) / 2;
            root->left = build(start, mid);
            root->right = build(mid + 1, end);
        }
        return root;
    }

    /// Adds val to the count of the leaf for index and to every ancestor
    /// on the way down.
    void add(SegmentTreeNode* root, int index, int val){
        if (root->start == root->end) {
            root->count += val;
            return;
        }
        int mid = root->start + (root->end - root->start) / 2;
        if (index > mid) {
            add(root->right, index, val);
        } else {
            add(root->left, index, val);
        }
        root->count += val;
    }

    /// Returns the number of inserted values in [start, end].
    /// An empty range (start > end), a null subtree, or a range disjoint
    /// from the node yields 0 instead of descending into a null child.
    int query(SegmentTreeNode* root, int start, int end){
        if (root == nullptr || start > end ||
            end < root->start || root->end < start) {
            return 0;
        }
        if (start <= root->start && root->end <= end) return root->count;
        int mid = root->start + (root->end - root->start) / 2;
        if (start > mid) return query(root->right, start, end);
        if (end <= mid) return query(root->left, start, end);
        return query(root->left, start, mid) + query(root->right, mid + 1, end);
    }

    /// Frees every node of the subtree rooted at root (no-op for nullptr).
    void release(SegmentTreeNode* root){
        if (root == nullptr) return;
        release(root->left);
        release(root->right);
        delete root;
    }

    /**
     * @param A: An integer array
     * @param queries: The query list
     * @return: For each query q, the number of elements in A that are
     *          smaller than q
     */
    std::vector<int> countOfSmallerNumber(std::vector<int> &A, std::vector<int> &queries) {
        SegmentTreeNode* root = build(0, kMaxValue);
        for (int value : A) {
            add(root, value, 1);
        }
        std::vector<int> ret;
        ret.reserve(queries.size());
        for (int q : queries) {
            // Elements smaller than q lie in [0, q - 1]. That range is empty
            // when q <= 0, and checking first avoids overflow in q - 1.
            ret.push_back(q <= 0 ? 0 : query(root, 0, q - 1));
        }
        release(root);
        return ret;
    }
};
249. 统计前面比自己小的数的个数
class Solution {
public:
    /// Largest value the counting tree indexes. Elements of A are expected
    /// to lie in [0, kMaxValue]; values outside that range are counted at
    /// the nearest boundary leaf by add().
    static const int kMaxValue = 10000;

    /// Segment-tree node that counts how many inserted values fall in
    /// [start, end].
    struct SegmentTreeNode {
        int start, end, count;
        SegmentTreeNode* left, *right;
        SegmentTreeNode(int start_, int end_) :
            start(start_), end(end_), count(0), left(nullptr), right(nullptr){}
    };

    /// Builds a tree over [start, end] with all counts zero.
    /// Returns nullptr when start > end. The caller owns the nodes and
    /// should free them with release().
    SegmentTreeNode* build(int start, int end){
        if (start > end) return nullptr;
        SegmentTreeNode* root = new SegmentTreeNode(start, end);
        if (start != end) {
            // Overflow-safe midpoint that also rounds toward start for
            // negative bounds, so the recursion always shrinks.
            int mid = start + (end - start) / 2;
            root->left = build(start, mid);
            root->right = build(mid + 1, end);
        }
        return root;
    }

    /// Adds val to the count of the leaf for index and to every ancestor
    /// on the way down.
    void add(SegmentTreeNode* root, int index, int val){
        if (root->start == root->end) {
            root->count += val;
            return;
        }
        int mid = root->start + (root->end - root->start) / 2;
        if (index > mid) {
            add(root->right, index, val);
        } else {
            add(root->left, index, val);
        }
        root->count += val;
    }

    /// Returns the number of inserted values in [start, end].
    /// An empty range (start > end), a null subtree, or a range disjoint
    /// from the node yields 0 instead of descending into a null child.
    int query(SegmentTreeNode* root, int start, int end){
        if (root == nullptr || start > end ||
            end < root->start || root->end < start) {
            return 0;
        }
        if (start <= root->start && root->end <= end) return root->count;
        int mid = root->start + (root->end - root->start) / 2;
        if (start > mid) return query(root->right, start, end);
        if (end <= mid) return query(root->left, start, end);
        return query(root->left, start, mid) + query(root->right, mid + 1, end);
    }

    /// Frees every node of the subtree rooted at root (no-op for nullptr).
    void release(SegmentTreeNode* root){
        if (root == nullptr) return;
        release(root->left);
        release(root->right);
        delete root;
    }

    /**
     * @param A: an integer array
     * @return: A list the same length as A whose i-th entry is the number
     *          of elements before position i that are smaller than A[i]
     */
    std::vector<int> countOfSmallerNumberII(std::vector<int> &A) {
        SegmentTreeNode* root = build(0, kMaxValue);
        std::vector<int> ret;
        ret.reserve(A.size());
        for (int value : A) {
            // Count earlier values in [0, value - 1] before inserting this
            // one. The range is empty when value <= 0, and checking first
            // avoids overflow in value - 1.
            ret.push_back(value <= 0 ? 0 : query(root, 0, value - 1));
            add(root, value, 1);
        }
        release(root);
        return ret;
    }
};