线段树基础详解
转载自:点击打开链接(基础版)
进阶版
基本概念:
简单的记法: 足够的空间 = 数组大小n的四倍。
实际上足够的空间 = (n向上扩充到最近的2的某个次方)的两倍。
举例子:假设数组长度为5,就需要5先扩充成8,8*2=16.线段树需要16个元素。如果数组元素为8,那么也需要16个元素。
所以线段树需要的空间是n的两倍到四倍之间的某个数,一般就开4*n的空间就好,如果空间不够,可以自己算好最大值来省点空间。
以下以维护数列区间和的线段树为例,演示最基本的线段树代码。
(0)定义:
- #define maxn 100007 //元素总个数
- #define ls l,m,rt<<1
- #define rs m+1,r,rt<<1|1
- int Sum[maxn<<2],Add[maxn<<2];//Sum求和,Add为懒惰标记
- int A[maxn],n;//存原数组数据下标[1,n]
(1) 建树:
(2) //PushUp函数更新节点信息 ,这里是求和
(3) void PushUp(int rt){Sum[rt]=Sum[rt<<1]+Sum[rt<<1|1];}
(4) //Build函数建树
(5) void Build(int l,int r,int rt){ //l,r表示当前节点区间,rt表示当前节点编号
(6) if(l==r) {//若到达叶节点
(7) Sum[rt]=A[l];//储存数组值
(8) return;
(9) }
(10) int m=(l+r)>>1;
(11) //左右递归
(12) Build(l,m,rt<<1);
(13) Build(m+1,r,rt<<1|1);
(14) //更新信息
(15) PushUp(rt);
(16) }
(2)点修改:
假设A[L]+=C:
1. void Update(int L,int C,int l,int r,int rt){//l,r表示当前节点区间,rt表示当前节点编号
2. if(l==r){//到叶节点,修改
3. Sum[rt]+=C;
4. return;
5. }
6. int m=(l+r)>>1;
7. //根据条件判断往左子树调用还是往右
8. if(L <= m) Update(L,C,l,m,rt<<1);
9. else Update(L,C,m+1,r,rt<<1|1);
10. PushUp(rt);//子节点更新了,所以本节点也需要更新信息
11. }
(3)区间修改:
- void Update(int L,int R,int C,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号
- if(L <= l && r <= R){//如果本区间完全在操作区间[L,R]以内
- Sum[rt]+=C*(r-l+1);//更新数字和,向上保持正确
- Add[rt]+=C;//增加Add标记,表示本区间的Sum正确,子区间的Sum仍需要根据Add的值来调整
- return ;
- }
- int m=(l+r)>>1;
- PushDown(rt,m-l+1,r-m);//下推标记
- //这里判断左右子树跟[L,R]有无交集,有交集才递归
- if(L <= m) Update(L,R,C,l,m,rt<<1);
- if(R > m) Update(L,R,C,m+1,r,rt<<1|1);
- PushUp(rt);//更新本节点信息
- }
(4)区间查询:
询问A[L,R]的和
首先是下推标记的函数:
1. void PushDown(int rt,int ln,int rn){
2. //ln,rn为左子树,右子树的数字数量。
3. if(Add[rt]){
4. //下推标记
5. Add[rt<<1]+=Add[rt];
6. Add[rt<<1|1]+=Add[rt];
7. //修改子节点的Sum使之与对应的Add相对应
8. Sum[rt<<1]+=Add[rt]*ln;
9. Sum[rt<<1|1]+=Add[rt]*rn;
10. //清除本节点标记
11. Add[rt]=0;
12. }
13. }
然后是区间查询的函数:
- int Query(int L,int R,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号
- if(L <= l && r <= R){
- //在区间内,直接返回
- return Sum[rt];
- }
- int m=(l+r)>>1;
- //下推标记,否则Sum可能不正确
- PushDown(rt,m-l+1,r-m);
- //累计答案
- int ANS=0;
- if(L <= m) ANS+=Query(L,R,l,m,rt<<1);
- if(R > m) ANS+=Query(L,R,m+1,r,rt<<1|1);
- return ANS;
- }
(5)函数调用:
- //建树
- Build(1,n,1);
- //点修改
- Update(L,C,1,n,1);
- //区间修改
- Update(L,R,C,1,n,1);
- //区间查询
- int ANS=Query(L,R,1,n,1);