【算法微解读】浅谈线段树
浅谈线段树
(来自TRTTG大佬的供图)
线段树个人理解和运用时,认为这个是一个比较实用的优化算法。
这个东西和区间树有点相似,是一棵二叉搜索树,也就是查找节点和节点所带值的一种算法。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN),这个时间复杂度非常的理想,但是空间复杂度在应用时是开4N的。
所以这个算法有优势,也有劣势。
我们提出一个问题
如果当前有一个区间,需要你在给定区间内做以下操作:
- l,z 在l上加上z
- l 查询l的值
- l,r,z 在[l,r]区间所有数都+z
- l,r, 查询l到r之间的和
你是不是在想,暴力解决一切问题,但是如果给你的数据是极大的,暴力完全做不了。
那么我们就需要使用线段树了。
我们就以这个问题为例来对线段树进行讲解。
先提供一下这个题目的AC代码
#include <bits/stdc++.h>
using namespace std;
const int maxn=10010;
struct segment_tree{
int l,r,sum,lazy;
}tree[maxn<<2];
int a[maxn];
int n,m;
void pushup(int nod) {
tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum;
}
void pushdown(int nod,int l,int r) {
int mid=(l+r)>>1;
tree[nod<<1].sum+=(mid-l+1)*tree[nod].lazy;
tree[(nod<<1)+1].sum+=(r-mid)*tree[nod].lazy;
tree[nod<<1].lazy+=tree[nod].lazy;
tree[(nod<<1)+1].lazy+=tree[nod].lazy;
tree[nod].lazy=0;
}
void build(int l,int r,int nod) {
if (l==r) {
tree[nod].sum=a[l];
tree[nod].l=l;
tree[nod].r=r;
tree[nod].lazy=0;
return;
}
int mid=(l+r)>>1;
build(l,mid,nod<<1);
build(mid+1,r,(nod<<1)+1);
pushup(nod);
}
void update1(int l,int r,int k,int value,int nod) {
if (l==r) {
tree[nod].sum+=value;
return ;
}
int mid=(l+r)>>1;
pushdown(nod,l,r);
if (k<=mid) update1(l,mid,k,value,nod<<1);
else update1(mid+1,r,k,value,(nod<<1)+1);
pushup(nod);
}
int query1(int l,int r,int nod,int k) {
if (l==r) return tree[nod].sum;
int mid=(l+r)>>1;
pushdown(nod,l,r);
if (k<=mid) return query1(l,mid,nod<<1,k);
else return query1(mid+1,r,(nod<<1)+1,k);
}
void update2(int l,int r,int ll,int rr,int nod,int value) {
if (l==ll&&r==rr) {
tree[nod].sum+=(r-l+1)*value;
tree[nod].lazy+=value;
return;
}
pushdown(nod,l,r);
int mid=(l+r)>>1;
if (rr<=mid) update2(l,mid,ll,rr,nod<<1,value);
else if (ll>mid) update2(mid+1,r,ll,rr,(nod<<1)+1,value);
else {
update2(l,mid,ll,mid,nod<<1,value);
update2(mid+1,r,mid+1,rr,(nod<<1)+1,value);
}
pushup(nod);
}
int query2(int l,int r,int ll,int rr,int nod) {
if (l==ll&r==rr) {
return tree[nod].sum;
}
pushdown(nod,l,r);
int mid=(l+r)>>1;
if (rr<=mid) return query2(l,mid,ll,rr,nod<<1);
else if (ll>mid) return query2(mid+1,r,ll,rr,(nod<<1)+1);
else return query2(l,mid,ll,mid,nod<<1)+query2(mid+1,r,mid+1,rr,(nod<<1)+1);
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
build(1,n,1);
while (m--) {
int c,x,y,z;
scanf("%d",&c);
if (c==1) {
scanf("%d%d",&x,&y);
update1(1,n,x,y,1);
}
if (c==2) {
scanf("%d",&x);
printf("%d\n",query1(1,n,1,x));
}
if (c==3) {
scanf("%d%d%d",&x,&y,&z);
update2(1,n,x,y,1,z);
}
if (c==4) {
scanf("%d%d",&x,&y);
printf("%d\n",query2(1,n,x,y,1));
}
}
return 0;
}
线段树的一些基本操作
- 建树
- 单点修改
- 单点查找
- 区间修改
- 区间查找
- pushup(儿子把信息传给父亲)
- pushdown(父亲把信息传给儿子)
(其他的应该都是这些基本操作的变形)
以下我们来逐一讲解一下
结构体
作为一课非常正经的树,我们还是要给它开一个结构体。
struct segment_tree{
int l,r,sum;
}tree[maxn];
关于线段树的一些小提醒
我们写线段树,应该先知道当前节点nod的左右儿子的编号是多少,答案是(nod2)和(nod2+1)
为什么?我们写的线段树应该是一棵满二叉树,所以根据满二叉树节点的特点,我们就可以知道了他的儿子就是以上的答案。
建树
由于是二叉搜索树,也就是一个二叉树,需要做搜索操作。那么我们就是以树状结构来存储数据。
我们来了解一下线段树:
我们设当前的线段树的节点是$$ tree.l\ tree.r $$,也就是当前这段区间的左右l和r。(其实我们在写代码的时候一般是不写这个l和r的)
其次我们还需要当前节点$$ tree.sum $$,表示当前节点所带的值。
在后面我们会讲到$$ tree.lazy $$,表示当前节点的懒标记,来方便我们进行区间修改的一个东西,我们现在先不讲
线段树的基本思想:二分。
那么就可以得到线段树的建树的程序
void build(int l,int r,int nod) {
if (l==r) {
tree[nod].sum=a[l];
tree[nod].l=l;
tree[nod].r=r;
return;
}
int mid=(l+r)>>1;
build(l,mid,nod<<1);
build(mid+1,r,(nod<<1)+1);
pushup(nod);
}
有人在问这个pushup是什么东西?
pushup
pushup就是把儿子的信息上传给自己的父亲节点
以当前问题为例,那么这个pushup的过程就是以下程序
void pushup(int nod) {
tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum;
}
其实也就是把和上传给父亲,非常简单,其他的pushup都是这个道理
单点修改
我们单点修改只需要直接在原节点上修改就可以了。
那么我们废话不多说,直接上代码更好理解
void update(int l,int r,int k,int value,int nod){
if(l==r) {
tree[nod].sum+=value;
return;
}
int mid=(l+r)/2;
if(k<=mid)update(l,mid,k,value,nod*2);
else update(mid+1,r,k,value,nod*2+1);
pushup(nod);
return;
}
这段程序也就是左右查找当前节点,k是我们需要寻找的节点,如果在左区间,那么就在左区间查找,有区间也是这个意思。
单点查找
方法与二分查询基本一致,如果当前枚举的点左右端点相等,即叶子节点,就是目标节点。如果不是,因为这是二分法,所以设查询位置为x,当前结点区间范围为了l,r,中点为mid,则如果x<=mid,则递归它的左孩子,否则递归它的右孩子。
直接上代码
int query(int l,int r,int ll,int rr,int nod){
if(l==ll&&r==rr)return tree[nod].sum;
int mid=(l+r)/2;
if(rr<=mid)return query(l,mid,ll,rr,nod*2);
else if(ll>mid)return query(mid+1,r,ll,rr,nod*2+1);
else return query(l,mid,ll,mid,nod*2)+query(mid+1,r,mid+1,rr,nod*2+1);
}
非常的简单我们就不多说了
区间修改
我们思考一个问题,如果我们只是像单点修改那样子,用一个循环语句,把要修改区间内的所有点都进行单点修改,那么这个的复杂度应该是O(NlogN),那么这就无法发挥出线段树的优势了。
那么我们应该怎么做呢?
这个时候我们就需要引入一个叫做懒标记的东西。
顾名思义,这个就是一个非常懒的标记,这个就是在我们要的区间内的节点上所加的标记,这个标记也就只有我们要对父亲区间内的数进行修改或者附其他值的时候才会用到的一个东西。
这个标记比较难理解,所以我们稍微讲的详细一点?
首先如果要对一个区间内的节点进行修改,那么就只需要在所需的区间内进行修改,也就只是放在那里,让他不要动。
当你要对接下来的区间内的数进行询问时,我们就需要进行pushdown的操作,这个操作就是要把父亲的懒标记上所拥有的全部信息全部给自己的儿子。
再传给儿子后,我们的父亲就要删除自己的懒标记,因为自己的懒标记已经传给了自己的儿子了,为了不产生错误,我们就要删除父亲的懒标记。
还是与我们这个例题为例,我们的区间修改的应该是这样写的:
void update2(int l,int r,int ll,int rr,int nod,int value) {
if (l==ll&&r==rr) {
tree[nod].sum+=(r-l+1)*value;
tree[nod].lazy+=value;
return;
}
pushdown(nod,l,r);
int mid=(l+r)>>1;
if (rr<=mid) update2(l,mid,ll,rr,nod<<1,value);
else if (l>mid) update2(mid+1,r,ll,rr,(nod<<1)+1,value);
else {
update2(l,mid,ll,mid,nod<<1,value);
update2(mid+1,r,mid+1,rr,(nod<<1)+1,value);
}
pushup(nod);
}
我们再回到这个问题,为什么会有这么多的if语句,我们现在来讲解一下
ll,rr是需要修改的区间。
当你的区间的rr也就是最右边在mid的左边,那么说明我们整个区间就在l和mid之间,就是以下的情况
好了右区间也是一样,其他的情况就是当前的区间分布在mid的左右,那么就分成两部分修改就可以了
那么最后因为儿子可能被改变了,所以我们就要pushup一下。
小提醒
如果你实在不知道什么时候要pushup或者是pushdown,那么多多益善,这样只是会增高你的时间复杂度,而不会影响正确率。
pushdown
这个操作在上文已经讲过是把父亲的lazy下传给儿子的过程。
直接上代码
void pushdown(int nod,int l,int r) {
int mid=(l+r)>>1;
tree[nod<<1].sum+=(mid-l+1)*tree[nod].lazy;
tree[(nod<<1)+1].sum+=(r-mid)*tree[nod].lazy;
tree[nod<<1].lazy+=tree[nod].lazy;
tree[(nod<<1)+1].lazy+=tree[nod].lazy;
tree[nod].lazy=0;
}
区间查询
这个道理和区间修改差不多,还更简单一点。
也不多讲了,直接上代码
int query2(int l,int r,int ll,int rr,int nod) {
if (l==ll&r==rr) {
return tree[nod].sum;
}
pushdown(nod,l,r);
int mid=(l+r)>>1;
if (rr<=mid) return query2(l,mid,ll,rr,nod<<1);
else if (ll>mid) return query2(mid+1,r,ll,rr,(nod<<1)+1);
else return query2(l,mid,ll,mid,nod<<1)+query2(mid+1,r,mid+1,rr,(nod<<1)+1);
}
一些模板题
codevs线段树练习
#include<bits/stdc++.h>
using namespace std;
const int N=100000;
int tree[N*4+10],s[N];
void build(int l,int r,int nod)
{
if(l==r){tree[nod]=s[l];return;}
int mid=(l+r)/2;
build(l,mid,2*nod);
build(mid+1,r,nod*2+1);
tree[nod]=tree[nod*2]+tree[nod*2+1];
return;
}
void update(int l,int r,int k,int value,int nod){
if(l==r){tree[nod]+=value;return;}
int mid=(l+r)/2;
if(k<=mid)update(l,mid,k,value,nod*2);
else update(mid+1,r,k,value,nod*2+1);
tree[nod]=tree[nod*2]+tree[nod*2+1];
return;
}
int query(int l,int r,int ll,int rr,int nod){
if(l==ll&&r==rr)return tree[nod];
int mid=(l+r)/2;
if(rr<=mid)return query(l,mid,ll,rr,nod*2);
else if(ll>mid)return query(mid+1,r,ll,rr,nod*2+1);
else return query(l,mid,ll,mid,nod*2)+query(mid+1,r,mid+1,rr,nod*2+1);
}
int main()
{
int n,m;
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%d",&s[i]);
build(1,n,1);
scanf("%d",&m);
while(m--){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
if(x==1)update(1,n,y,z,1);
else printf("%d\n",query(1,n,y,z,1));
}
return 0;
}
codevs线段树练习2
#include<bits/stdc++.h>
using namespace std;
const int N=1000000;
int tree[N*4+10],a[N];
void update(int nod,int l,int r,int ll,int rr,int value){
if(l==ll&&r==rr){tree[nod]+=value;return;}
int mid=(l+r)/2;
if(rr<=mid)update(2*nod,l,mid,ll,rr,value);
else if(ll>mid)update(nod*2+1,mid+1,r,ll,rr,value);
else{
update(2*nod,l,mid,ll,mid,value);
update(2*nod+1,mid+1,r,mid+1,rr,value);
}
return;
}
void pushdown(int nod){
tree[nod*2+1]+=tree[nod];
tree[nod*2]+=tree[nod];
tree[nod]=0;
return;
}
int query(int nod,int l,int r,int k){
if(l==r)return a[l]+tree[nod];
int mid=(l+r)/2;
pushdown(nod);
if(k<=mid)return query(2*nod,l,mid,k);
else return query(2*nod+1,mid+1,r,k);
}
int main()
{
int n,m;
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
scanf("%d",&m);
while(m--){
int x,y,z,k;
scanf("%d",&x);
if(x==1){
scanf("%d%d%d",&y,&z,&k);
update(1,1,n,y,z,k);
}
else{
scanf("%d",&y);
printf("%d\n",query(1,1,n,y));
}
}
return 0;
}
codevs线段树练习4
#include<bits/stdc++.h>
using namespace std;
const int N(200000);
struct node{
long long sum,add;
}tree[4*N+10];
int a[N+10];
inline void pushdown(long long nod,long long l,long long r){
long long mid((l+r)>>1);
tree[nod<<1].sum+=(mid-l+1)*tree[nod].add;
tree[(nod<<1)+1].sum+=(r-mid)*tree[nod].add;
tree[nod<<1].add+=tree[nod].add;
tree[(nod<<1)+1].add+=tree[nod].add;
tree[nod].add=0;
return;
}
inline long long read(){
long long x(0);
char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x;
}
void pushup(long long nod){
tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum;
return;
}
void build(long long l,long long r,long long nod){
tree[nod].add=0;
if(l==r){
tree[nod].sum=a[l];
return;
}
long long mid((l+r)>>1);
build(l,mid,nod<<1);
build(mid+1,r,(nod<<1)+1);
tree[nod].sum=tree[nod<<1].sum+tree[(nod<<1)+1].sum;
return;
}
void update(long long l,long long r,long long ll,long long rr,long long value,long long nod){
if(l==ll&&r==rr){
tree[nod].sum+=(r-l+1)*value;
tree[nod].add+=value;
return;
}
pushdown(nod,l,r);
long long mid((l+r)>>1);
if(rr<=mid)update(l,mid,ll,rr,value,nod<<1);
else if(ll>mid)update(mid+1,r,ll,rr,value,(nod<<1)+1);
else{
update(l,mid,ll,mid,value,nod<<1);
update(mid+1,r,mid+1,rr,value,(nod<<1)+1);
}
pushup(nod);
return;
}
long long query(long long l,long long r,long long ll,long long rr,long long nod){
if(l==ll&&r==rr)return tree[nod].sum;
pushdown(nod,l,r);
long long mid=(l+r)>>1;
if(rr<=mid)return query(l,mid,ll,rr,nod<<1);
else if(ll>mid)return query(mid+1,r,ll,rr,(nod<<1)+1);
else return query(l,mid,ll,mid,nod*2)+query(mid+1,r,mid+1,rr,(nod<<1)+1);
}
int main()
{
long long m;
register long long n;
m=read();
for(long long i=1;i<=m;++i)a[i]=read();
build(1,m,1);
n=read();
while(n--){
long long t,x,y,z;
t=read();
if(t==1){
x=read(); y=read(); z=read();
update(1,m,x,y,z,1);
}
else{
x=read(); y=read();
printf("%lld\n",query(1,m,x,y,1));
}
}
return 0;
}
codevs线段树练习4
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=1000000;
int add[N],sum[N*4+10][7],a[N];
inline int read(){
int x(0);
char ch=getchar();
while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x;
}
void pushup(int nod){
for(int i=0;i<7;i++)
sum[nod][i]=sum[nod<<1][i]+sum[(nod<<1)+1][i];
return;
}
void build(int l,int r,int nod){
if(l==r){
sum[nod][a[l]%7]++;
return;
}
int mid((l+r)>>1);
build(l,mid,nod<<1);
build(mid+1,r,(nod<<1)+1);
pushup(nod);
return;
}
void modify(int nod,int v){
int t[7];
for(int i=0;i<7;i++)
t[(i+v)%7]=sum[nod][i];
for(int i=0;i<7;i++)
sum[nod][i]=t[i];
add[nod]=(add[nod]+v)%7;
return;
}
void pushdown(int nod){
modify(nod<<1,add[nod]);
modify((nod<<1)+1,add[nod]);
add[nod]=0;
return;
}
int query(int l,int r,int ll,int rr,int nod){
if(l==ll&&r==rr)
return sum[nod][0];
int mid((l+r)>>1);
pushdown(nod);
if(rr<=mid)query(l,mid,ll,rr,nod<<1);
else if(ll>mid)query(mid+1,r,ll,rr,(nod<<1)+1);
else return query(l,mid,ll,mid,nod<<1)+query(mid+1,r,mid+1,rr,(nod<<1)+1);
}
void update(int l,int r,int ll,int rr,int value,int nod){
if(l==ll&&r==rr){
modify(nod,value);
return;
}
int mid((l+r)>>1);
pushdown(nod);
if(rr<=mid)update(l,mid,ll,rr,value,nod<<1);
else if(ll>mid)update(mid+1,r,ll,rr,value,(nod<<1)+1);
else{
update(l,mid,ll,mid,value,nod<<1);
update(mid+1,r,mid+1,rr,value,(nod<<1)+1);
}
pushup(nod);
return;
}
int main()
{
int n;
n=read();
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
build(1,n,1);
int q;
q=read();
while(q--){
char s[10];
scanf("%s",s);
if(s[0]=='c'){
int x,y;
x=read();
y=read();
printf("%d\n",query(1,n,x,y,1));
}
else{
int x,y,z;
x=read();
y=read();
z=read();
update(1,n,x,y,z,1);
}
}
return 0;
}