『学习笔记』线段树
线段树和树状数组都是用来优化序列操作的数据结构。
线段树理解容易,常数大,解决问题范围广;树状数组理解比较困难,常数非常小,能解决的问题就没有线段树多了,可以说树状数组能解决的问题是线段树能解决的问题的子集。
线段树基本概念
线段树是一个二叉树,每个节点表示一个区间。
对于任意节点,要么是叶子节点,要么两个儿子都存在。
它可以快速在序列上修改及查询元素,可以是区间修改或查询。每次修改或查询的时间复杂度为 \(\mathcal{O}(\log n)\)。在使用之前,还需花费 \(\mathcal{O}(n)\) 的时间建树。
那么每个节点存什么?
- 如果是叶子节点,就存需要执行操作的序列的对应项。具体是哪项下面再说。
- 否则,就存这个节点的左右儿子之和或最小值、最大值、乘积等等。以求和为例,计算公式为 \(t_x=t_{\operatorname{left\_son}(x)}+t_{\operatorname{right\_son}(x)}\)。这个节点的值就是这个节点表示的区间之和。
为方便表示,本文中的 ls(x)
均代表 \(x\) 节点的左儿子节点,rs(x)
同理。
每个节点的儿子表示的区间都是当前节点区间的一半,左儿子表示的是 \(\left[l,\left\lfloor \dfrac{l+r}{2} \right\rfloor\right]\),右儿子表示的是 \(\left[\left\lfloor \dfrac{l+r}{2} \right\rfloor+1,r\right]\)。
例如,要使用一个长度为 \(8\) 的序列 \(a=[1,1,4,5,1,4,1,9]\) 构造一棵线段树,那么这棵线段树长这样:
各个非叶子节点的值都是根据 \(t_x=t_{\operatorname{ls}(x)}+t_{\operatorname{rs}(x)}\) 来计算的。
图中每个节点上面写的是这个节点包含的区间,下面写的是各个节点的值。
非叶子节点的值的计算过程也写了上去。
可以发现,表示第 \(i\) 个数的叶子节点的值就是 \(a_i\)。
线段树差不多就长这样子,下面来看详细的操作过程。
普通线段树(不支持区间修改)
如何存储
我们可以使用二叉堆的方式存储:根节点的位置为 \(1\),每个节点的左右儿子的位置分别为 \(i \times 2\) 和 \(i \times 2+1\)。
也就是说,你遍历存储这棵树的数组,和层次遍历这棵树一样。
若有空缺的位置,需要留着。
因为存储的节点除了最后一层还有许多个节点,所以数组长度要比 \(n\) 大。
有人计算过,存树的数组需要开到 \(4n\) 才行。
树的节点结构体定义如下:
struct node{
int l,r; // 表示区间
T v; // 当前值
}t[N<<2]; // 线段树存储数组
其中的 T
表示线段树需要维护的值的类型,下文也一样,就不多说了。
别问为啥这么写,我写数据结构都喜欢封装成一个类用,用起来爽,里面就比不用类要多一点东西了。
位运算优化
应该很容易猜到,在操作过程中对 \(i \times 2\) 和 \(i \times 2+1\) 的计算非常多。然而直接用乘号就有点慢了。
所以就要使用我们的卡常神器位运算了!
首先看 \(i \times 2\),没什么好说的。
众所周知,想让一个数乘 \(2^{x}\),只需使其左移 \(x\) 位即可。所以 \(i \times 2\) 就是 \(i\) 左移 \(1\) 位。
右儿子的运算又多了个加一,咋整?直接加?那是绝对不可能的!
一个数左移一位后,最后一位一定为 \(0\)。要加一,就是让它变成 \(1\)。
于是就可以——将这个数或上 \(1\)。
这样就让计算速度快起来了一点点嘛!
于是,可以定义两个函数:
inline int ls(int rt){return rt<<1;} // 左儿子
inline int rs(int rt){return rt<<1|1;} // 右儿子
前面最好加上 inline
,防止你算一下儿子在哪都要再丢一个东西到栈里。
还有几个小优化,例如计算一个区间的中间分界点 \(mid\) 时,也可以用位运算来实现除二操作。
还有一个微不足道的,就是计算 \(4n\) 时可以直接左移 \(2\) 位。
废话多不多?
建树
呵呵,正文终于开始了。
上面提到过,建树要用 \(\mathcal{O}(n)\) 的时间复杂度进行。因为有 \(n\) 个元素。
使用深度优先搜索的方式来遍历整棵树。遍历的中间,为各个节点的 \(l\) 和 \(r\) 赋值。
遍历到叶子节点(当前的 \(l=r\))时,这个叶子节点的值就应该是 \(a_l\)。
两个儿子都遍历过后,需要通过已经处理好的儿子实时计算当前节点的值。
流程大概是这样的:
为方便起见,我们定义一个函数 pushup
用来计算当前节点的值。
inline void pushup(int rt){
t[rt].v=t[ls(rt)].v+t[rs(rt)].v; // 计算当前节点的值
}
通过修改 pushup
函数,可以直接修改线段树维护内容。
例如,将其修改为维护最大值的线段树:
inline void pushup(int rt){
t[rt].v=max(t[ls(rt)].v,t[rs(rt)].v);
}
看代码吧!还是代码形象一点:
void build(int rt,int l,int r){ // rt 表示当前节点,l 和 r 表示当前节点表示的区间
t[rt].l=l,t[rt].r=r; // 首先的一步就是指定当前节点表示的区间范围
if(l==r){ // 叶子节点情况
t[rt].v=a[l]; // 为叶子节点赋值
return; // 碰到叶子节点了就要回溯了
}
int mid=l+r>>1; // 计算区间分界点
build(ls(rt),l,mid); // 递归遍历左儿子
build(rs(rt),mid+1,r); // 递归遍历右儿子
pushup(rt); // 计算当前节点值
}
应该很好理解,就是常数...
单点修改
废话了那么多,终于开始说操作了...
单点修改,就是要修改数列中的一个数。
那么在线段树中,就是修改其中一个叶子节点,我们需要修改某个叶子节点后维护整棵线段树,使其还是保持原来的特性(非叶子节点等于两个儿子的和等特性)。
例如,要修改下标为 \(6\) 的数为 \(8\)。
从根节点一直向下找,查找要修改的叶子节点。
若当前搜索的的节点不是叶子节点,那么就需要判断需要修改的叶子节点在左儿子里还是右儿子里:
- 令 \(mid \gets \left\lfloor \dfrac{l+r}{2} \right\rfloor\)。
- 若下标 \(idx \leq mid\),则说明在左儿子里,向左儿子中搜索。
- 否则,就去右儿子。
代码如下:
int mid=t[rt].l+t[rt].r>>1; // 找中间点
if(idx<=mid) update(ls(rt),idx,v); // 进左儿子
else update(rs(rt),idx,v); // 右儿子
最终一定会找到一个叶子节点,它就是我们需要修改的。
单纯修改叶子节点会破坏整棵线段树的平衡,所以回溯时需要更新查找需要更改的叶子节点时经过的节点。
在函数末尾加上一句 pushup(rt)
即可。
完整代码:
// rt 是当前节点,idx 是需要修改的数的下标,v 是要替换的数(或累加的数)
void update(int rt,int idx,T v){
if(t[rt].l==t[rt].r){ // 找到叶子节点的情况
t[rt].v=v; // 修改
return; // 回溯
}
int mid=t[rt].l+t[rt].r>>1;
if(idx<=mid) update(ls(rt),idx,v);
else update(rs(rt),idx,v);
pushup(rt); // 找到叶子节点后需要将路径上的所有节点都更新一下,从下向上更新
}
很容易看出来,时间复杂度是 \(\mathcal{O}(\log n)\)。别看比暴力还差,区间查询可是 \(\mathcal{O}(\log n)\) 的。
单点查询
没什么好说的,就是从一棵树上找到叶子节点,return
就是了。
这个应该看代码就够了。
T query(int rt,int idx){ // 参数就不多说了
if(t[rt].l==t[rt].r){ // 找到目标
return t[rt].v;
}
int mid=t[rt].l+t[rt].r>>1;
if(idx<=mid) return query(ls(rt),idx); // 在左儿子中
else return query(rs(rt),idx); // 右儿子
// 这里不需要 pushup,因为没有任何修改
}
区间查询
我们之所以维护整棵线段树就是为了使这个操作的时间复杂度变为 \(\mathcal{O}(\log n)\)。
暴力查询时是一个一个累加,但有了线段树就不一样了。
线段树的节点除了叶子节点都存储的是一个区间的和,若某个节点表示的区间在查询区间之内,那么就可以 \(\mathcal{O}(1)\) 地累加出这个节点表示的区间的和。
那如果当前节点表示的区间和查询区间有交集,但并不是查询区间的子集,咋办?
直接看看左右儿子表示的区间是否与查询区间有交集,如果有,则进入相应的儿子查询(两个儿子随便去,但没有都不去的情况,那样当前节点表示的区间要么是查询区间的子集,要么就与查询区间没关系)。
好像有点不好理解...看图吧。
应该步骤写的很清楚了,可以通过代码进一步理解。
T query(int rt,int l,int r){ // l 和 r 表示查询区间!不是当前节点表示区间!
if(l<=t[rt].l && t[rt].r<=r){ // 刚好是查询区间的子集
return t[rt].v; // 直接返回
}
T res=0; // 因为左右儿子都可能去,所以定义一个变量累加
int mid=t[rt].l+t[rt].r>>1;
if(l<=mid) res+=query(ls(rt),l,r); // 若查询区间左端点在左儿子右端点之前,则表示左儿子包含
if(r>mid) res+=query(rs(rt),l,r); // 查询区间右端点在右儿子左端点之后,同上
return res;
}
代码也不长,应该挺好理解吧?
下面看一道例题。
P3374 【模板】树状数组 1
题目大意
给定一个长度为 \(n\) 的序列 \(a\),\(m\) 个操作,每次操作包含 \(3\) 个整数:
1 x k
:将第 \(x\) 个数加上 \(k\)。2 x y
:查询区间 \([x,y]\) 内每个数的和。
思路
虽然是树状数组题,但拿来写线段树也是不错的选择。
这题要写的模板就是单点查询区间修改的模板,我这给出的代码是我封装好的线段树类。
其实不用看类中那些东西,就看那三个函数和私有的几个函数就行了。
应该不用写注释吧(
代码
#include <iostream>
using namespace std;
template<typename T=int>
inline T read(){
T X=0; bool flag=1; char ch=getchar();
while(ch<'0' || ch>'9'){if(ch=='-') flag=0; ch=getchar();}
while(ch>='0' && ch<='9') X=(X<<1)+(X<<3)+ch-'0',ch=getchar();
if(flag) return X;
return ~(X-1);
}
template<typename T=int>
inline void write(T X){
if(X<0) putchar('-'),X=~(X-1);
T s[20],top=0;
while(X) s[++top]=X%10,X/=10;
if(!top) s[++top]=0;
while(top) putchar(s[top--]+'0');
putchar('\n');
}
const int N=5e5+5;
int n,m,a[N],op,x,y;
template<class T=long long>
class SgT{
public:
SgT(){
a_res=new int[N];
for(int i=0; i<N; i++){
a_res[i]=0;
}
a=a_res;
}
SgT(int rt,int l,int r,int *_a=nullptr):a(_a==nullptr ? a_res : _a){
build(rt,l,r);
}
~SgT(){
delete[] a_res;
}
void build(int rt,int l,int r){
t[rt].l=l,t[rt].r=r;
if(l==r){
t[rt].v=a[l];
return;
}
int mid=l+r>>1;
build(ls(rt),l,mid);
build(rs(rt),mid+1,r);
pushup(rt);
}
void update(int rt,int idx,T v){
if(t[rt].l==t[rt].r){
t[rt].v+=v;
return;
}
int mid=t[rt].l+t[rt].r>>1;
if(t[rt].l<=mid) update(ls(rt),idx,v);
else update(rs(rt),idx,v);
pushup(rt);
}
T query(int rt,int l,int r){
if(l<=t[rt].l && t[rt].r<=r){
return t[rt].v;
}
T res=0;
int mid=t[rt].l+t[rt].r>>1;
if(l<=mid) res+=query(ls(rt),l,r);
if(r<mid) res+=query(rs(rt),l,r);
return res;
}
private:
int *a,*a_res;
struct node{
int l,r;
T v;
}t[N<<2];
inline int ls(int rt){return rt<<1;}
inline int rs(int rt){return rt<<1|1;}
inline void pushup(int rt){
t[rt].v=t[ls(rt)].v+t[rs(rt)].v;
}
};
int main(){
n=read(),m=read();
for(int i=1; i<=n; i++){
a[i]=read();
}
SgT t(1,1,n,a);
while(m--){
op=read(),x=read(),y=read();
if(op==1){
t.update(1,x,y);
}else{
write(t.query(1,x,y));
}
}
return 0;
}
区间修改
说了那么多,就差你一个区间修改了。
如果你直接用单点修改的方法一个一个改,那么时间复杂度就变成 \(\mathcal{O}(n \log n)\) 了,比暴力还差。
那我们可不可以参考区间查询的思想呢?一次修改一个区间?那就需要一个叫懒标记的东西了。
懒标记
我们给节点的结构体加一个变量,叫 \(tag\),懒标记的意思。它表示这个节点之下的所有节点的 \(v\) 都需要加上这个 \(tag\)。
这样的话,一次修改一个区间就能实现了:若需修改区间包含某个节点表示的区间,直接将这个节点的 \(tag\) 加上需要增加的值。
可以这样理解懒标记:放寒假了,老师每过一段时间给你布置一次作业(修改一次),你却只是记住有哪些作业(修改懒标记),在开学时(查询)才写(将标记下传)。
除了查询,修改时也需要下传懒标记,节点后代修改(或查询)时需要。
接下来说说如何下传懒标记:
首先一步,就是将懒标记给左右儿子都加上。
还需要修改两个儿子的值,都是修改成儿子表示的区间长度乘上父亲节点的懒标记。因为儿子包含的每一个数都要加上父节点的懒标记,所以要将懒标记乘上长度。
我们将下传懒标记的函数定义为 pushdown()
:
inline void pushdown(int rt){
t[ls(rt)].tag+=t[rt].tag; // 懒标记传下去
t[ls(rt)].v+=t[rt].tag*(t[ls(rt)].r-t[ls(rt)].l+1); // 修改值
// 右儿子同上
t[rs(rt)].tag+=t[rt].tag;
t[rs(rt)].v+=t[rt].tag*(t[rs(rt)].r-t[rs(rt)].l+1);
t[rt].tag=0; // 记得将父节点的懒标记置 0
}
P3372 【模板】线段树 1
题目大意
需要维护一个序列,支持区间改查。
思路
没别的,看代码就行了。
顺便熟悉一下区间改查线段树。
代码
#include <iostream>
using namespace std;
template<typename T=int>
inline T read(){
T X=0; bool flag=1; char ch=getchar();
while(ch<'0' || ch>'9'){if(ch=='-') flag=0; ch=getchar();}
while(ch>='0' && ch<='9') X=(X<<1)+(X<<3)+ch-'0',ch=getchar();
if(flag) return X;
return ~(X-1);
}
template<typename T=int>
inline void write(T X){
if(X<0) putchar('-'),X=~(X-1);
T s[20],top=0;
while(X) s[++top]=X%10,X/=10;
if(!top) s[++top]=0;
while(top) putchar(s[top--]+'0');
putchar('\n');
}
const int N=1e5+5;
int n,m,a[N],op,x,y,k;
template<class T=long long>
class SgT{
public:
SgT(int rt=-1,int l=0,int r=0,int *_a=nullptr):a_res(new int[N]),a(_a==nullptr ? a_res : _a){
for(int i=0; i<N; i++){
a_res[i]=0;
}
if(rt!=-1) build(rt,l,r);
}
~SgT(){
delete[] a_res;
}
void build(int rt,int l,int r){
t[rt].l=l,t[rt].r=r;
t[rt].tag=0;
if(l==r){
t[rt].v=a[l];
return;
}
int mid=l+r>>1;
build(ls(rt),l,mid);
build(rs(rt),mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,T v){
if(l<=t[rt].l && t[rt].r<=r){
t[rt].v+=v*(t[rt].r-t[rt].l+1);
t[rt].tag+=v;
return;
}
pushdown(rt);
int mid=t[rt].l+t[rt].r>>1;
if(l<=mid) update(ls(rt),l,r,v);
if(r>mid) update(rs(rt),l,r,v);
pushup(rt);
}
T query(int rt,int l,int r){
if(l<=t[rt].l && t[rt].r<=r){
return t[rt].v;
}
pushdown(rt);
T res=0;
int mid=t[rt].l+t[rt].r>>1;
if(l<=mid) res+=query(ls(rt),l,r);
if(r>mid) res+=query(rs(rt),l,r);
return res;
}
private:
int *a,*a_res;
struct node{
int l,r;
T v,tag;
}t[N<<2];
inline int ls(int rt){return rt<<1;}
inline int rs(int rt){return rt<<1|1;}
inline void pushup(int rt){
t[rt].v=t[ls(rt)].v+t[rs(rt)].v;
}
inline void pushdown(int rt){
t[ls(rt)].tag+=t[rt].tag;
t[ls(rt)].v+=t[rt].tag*(t[ls(rt)].r-t[ls(rt)].l+1);
t[rs(rt)].tag+=t[rt].tag;
t[rs(rt)].v+=t[rt].tag*(t[rs(rt)].r-t[rs(rt)].l+1);
t[rt].tag=0;
}
};
int main(){
n=read(),m=read();
for(int i=1; i<=n; i++){
a[i]=read();
}
SgT t(1,1,n,a);
while(m--){
op=read();
if(op==1){
x=read(),y=read(),k=read();
t.update(1,x,y,k);
}else{
x=read(),y=read();
write(t.query(1,x,y));
}
}
return 0;
}
推荐题单
暂时就学了这么点线段树。
从易到难排序。
要是能全刷完,那一定是线段树大神了。反正我是刷不完的。