线段树简单思路
线段树
1、了解存储结构
存放左右范围,和左右子节点的预定义
#define lc p<<1
#define rc p<<1|1
struct node {
int l, r, sum;
}tr[4 * N];//注意要开4倍,感兴趣自己搜
2、递归建树
思路
- 对结点的左右范围赋值
- 判断是不是叶子结点
- 建左右子树
- 更新当前结点的sum值
void build(int p,int l,int r) {
tr[p].l = l, tr[p].r = r;
if (l == r)return;//表明是叶子结点,不需要再创建
int m = l + r >> 1;
build(lc, l, m);
build(rc, m + 1, r);
//建完树后更新当前结点
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
3、区间查询
思路
- 将区间分裂开,如果被覆盖,就直接返回,如果没有,就进行左右子树的查询
- 如果左子节点的区间与查询区间有交集,就查询,右边同理
int query(int p,int x,int y) {
//如果区间被包含,直接return
if (x <= tr[p].l && tr[p].r <= y)
return tr[p].sum;
int m = tr[p].l + tr[p].r >> 1;
int sum = 0;
//由m和xy进行对比,判断是不是与子树由交集
if (x <= m) sum += query(lc, x, y);
if (y > m) sum += query(rc, x, y);
return sum;
}
4、区间修改
运用懒标记
-
简单解释就是,如果修改的区间完全覆盖当前的区间,比如对2~7区间每个值都加3
那么就是先将当前的p的sum进行修改,sum+=(r-l+1)*3,然后将当前结点打上懒标记add+=3
所以存储结构要改为:
struct node {
int l, r, sum, add;
}tr[4 * N]
-
接上,如果不覆盖,先将当前懒标记向下传递,即pushdown(),然后递归更新子区间(类似区间查询的流程)
修改完毕后更新当前节点,此时引入pushup函数
void pushup(int p) { tr[p].sum = tr[lc].sum + tr[rc].sum; }
void pushdown(int p) {
if (!tr[p].add) return;
int add = tr[p].add;
tr[lc].sum += add * (tr[lc].r - tr[lc].l + 1);
tr[rc].sum += add * (tr[rc].r - tr[rc].l + 1);
tr[lc].add += add;
tr[rc].add += add;
tr[p].add = 0;
}
void update(int p,int x,int y,int k) {
//如果完全覆盖就直接更新
if (x <= tr[p].l && tr[p].r <= y) {
tr[p].sum += k * (tr[p].r - tr[p].l + 1);
tr[p].add += k;
return;
}
//如果不是完全覆盖
//先将add向下传
pushdown(p);
int m = tr[p].l + tr[p].r >> 1;
//再根据左右子树的区间重叠情况更新
if (x <= m)update(lc, x, y, k);
if (y > m)update(rc, x, y, k);
//最后更新一下当前节点
pushup(p);
}
5、完整代码
针对例题https://www.luogu.com.cn/problem/P3372
以下是完整代码,与上面略有差别
#include<iostream>
using namespace std;
typedef long long ll;
#define endl "\n"
#define IOS ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
const int N = 1e5 + 10;
#define lc p<<1
#define rc p<<1|1
struct node {
int l, r;
ll sum, add;
}tr[4 * N];
int arr[N];
void pushup(int p) {tr[p].sum = tr[lc].sum + tr[rc].sum; }
void pushdown(int p) {
if (!tr[p].add) return;
ll add = tr[p].add;
tr[lc].sum += add * (tr[lc].r - tr[lc].l + 1);
tr[rc].sum += add * (tr[rc].r - tr[rc].l + 1);
tr[lc].add += add;
tr[rc].add += add;
tr[p].add = 0;
}
void build(int p,int l,int r) {
tr[p] = { l,r,arr[l],0 };
if (l == r)return;
int m = l + r >> 1;
build(lc, l, m);
build(rc, m + 1, r);
//建完树后更新当前结点
pushup(p);//tr[p].sum = tr[lc].sum + tr[rc].sum;
}
ll query(int p,int x,int y) {
if (x <= tr[p].l && tr[p].r <= y)
return tr[p].sum;
pushdown(p);//此处先更新一下再查询
int m = tr[p].l + tr[p].r >> 1;
ll sum = 0;
if (x <= m) sum += query(lc, x, y);
if (y > m) sum += query(rc, x, y);
return sum;
}
void update(int p,int x,int y,int k) {
if (x <= tr[p].l && tr[p].r <= y) {
tr[p].sum += k * (tr[p].r - tr[p].l + 1);
tr[p].add += k;
return;
}
pushdown(p);
int m = tr[p].l + tr[p].r >> 1;
if (x <= m)update(lc, x, y, k);
if (y > m)update(rc, x, y, k);
pushup(p);
}
int n,m,choice,x,y,k;
int main() {
IOS;
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> arr[i];
build(1, 1, n);
while (m--) {
cin >> choice;
if (choice == 1) {
cin >> x >> y >> k;
update(1, x, y, k);
}
else {
cin >> x >> y;
cout << query(1, x, y) << endl;
}
}
return 0;
}