线段树模板讲解
洛谷题目链接:线段树
线段树是一种用于区间修改查询的数据结构,可以支持的操作有单点修改区间查询,区间修改单点查询,区间修改区间查询等.
线段树有递归版和结构体版,递归版在处理一开始没有赋初始值的问题时可以不用建树,而结构体版的则显得比较条理清晰.
线段树比树状数组的代码复杂的多,但是树状数组多支持一个区间修改区间查询的操作,并且可以与其他的一些数据结构相适应(像是树链剖分等),也有很多的在此基础上的提高的算法,线面开始讲解一下思路.
线段树操作如下:
- 输入点权,建树.
- 进行修改和查询.
建树:
在线段树中建树一般采用的是递归的方式建树,在建树的过程中保留每个点的信息(区间左端点,右端点,区间和等),在线段树中的每个节点就是一个区间.
代码如下:
void build(lol root,lol left,lol right){//节点,节点左端点,节点右端点 if(left==right){//当节点左端点等于右端点时,即为叶子节点 sum[root]=w[left];//叶子节点的区间和即为点权 return;//这里记得要退出回溯 } build(root*2,left,mid);//左子树 build(root*2+1,mid+1,right);//右子树递归建树 sum[root]=sum[ll(root)]+sum[rr(root)];//回溯时收集区间和 }
那么这样建好的树就会有这样的性质:
- 当前节点的左右儿子分别是root*2和root*2+1
- 当前节点包括的范围刚好是左右儿子的区间的并集
所以之后在进行修改,查询等操作时会很方便.
下移懒惰标记:
这里提及了一个重要的操作:lazy[]数组,懒惰标记.
lazy数组用于保存修改在某个区间上的值,但不即时修改赋值到节点上,而是在之后查询时需要查询它子树的值的时候再将懒惰标记下移.
那么下移lazy标记就是将当前区间的修改细化到每个小区间上.具体细节见代码注释:
void pushdown(int root,int left,int right){ lazy[root*2]+=lazy[root];//左右子树加上上面节点的标记 lazy[root*2+1]+=lazy[root]; sum[root*2]+=lazy[root]*(mid-left+1);//将区间和加上子树的个数乘标记的大小 sum[root*2+1]+=lazy[root]*(right-mid); lazy[root]=0;//消去节点的懒惰标记 }
为什么修改区间和时要将左子树乘上(mid-left+1)呢?因为左子树的左端点是left,右端点是mid,那么正好(mid-left+1)就是左子树的节点数,乘上lazy[root]的值也就是将它的子树每个加上lazy[root]的值.
修改:
修改时是寻找一个能全部包含于修改范围的区间,并将它打上lazy标记.
一个需要修改的范围可以看做是一个个小的范围的集合,如:
1到8的区间可以看做是[1,5]和[6,8];5到6的区间可以看做是[5,5]和[6,6];
如果是将1~8加5,则将区间[1,5]和区间[6,8]节点的懒惰标记加上5就可以了.
代码如下:
1 void updata(int root,int left,int right,int l,int r,int val){ 2 if(l<=left&&right<=r){//如果找到一个能全部被包含于修改范围的区间 3 lazy[root]+=val;//则加上懒惰标记 4 sum[root]+=val*(right-left+1);//同时也要修改区间和,与pushdown的修改同理 5 return; 6 } 7 if(lazy[root]) pushdown(root,left,right);//修改时也要将之前的懒惰标记下移 8 if(l<=mid) updata(ll(root),left,mid,l,r,val); 9 if(mid<r) updata(rr(root),mid+1,right,l,r,val);//递归寻找能全被包含的区间 10 sum[root]=sum[ll(root)]+sum[rr(root)]; 11 }
查询:
查询和修改操作比较像,也是将一个大区间细化为一个个小区间,找能被完全包含的区间进行查询.
下面直接上代码:
1 lol query(int root,int left,int right,int l,int r){ 2 if(l<=left&&right<=r) return sum[root];//完全被包含的区间直接返回区间和 3 if(r<left||right<l) return 0;//如果区间和查找区间没有交集,则直接返回0 4 if(lazy[root]) pushdown(root,left,right);//这里要写在判断是否与查找区间有交集后面 5 return query(ll(root),left,mid,l,r)+query(rr(root),mid+1,right,l,r);//递归查询 6 }
为什么要把懒惰标记下移的操作放在判断后呢?我们举个例子,假设已经递归到了叶子节点,如果先下移标记,就有可能导致数组的越界(向下移标记时左右子树的节点标号都是节点的两倍),所以要先进行判断区间的操作.
几个简单的操作讲完了,下面放一个完整模板:
1 #include<bits/stdc++.h> 2 #define mid (left+right>>1) 3 #define ll(x) (x<<1) 4 #define rr(x) (x<<1|1) 5 using namespace std; 6 typedef long long lol; 7 const int N=500000; 8 9 lol n,m; 10 lol w[N+10]; 11 lol sum[N+10]; 12 lol lazy[N+10]; 13 14 int gi(){ 15 int ans=0,f=1;char i=getchar(); 16 while(i<'0'||i>'9'){if(i=='-')f=-1;i=getchar();} 17 while(i>='0'&&i<='9'){ans=ans*10+i-'0';i=getchar();} 18 return ans*f; 19 } 20 21 void build(lol root,lol left,lol right){ 22 if(left==right){ 23 sum[root]=w[left]; 24 return; 25 } 26 build(ll(root),left,mid); 27 build(rr(root),mid+1,right); 28 sum[root]=sum[ll(root)]+sum[rr(root)]; 29 } 30 31 void pushdown(int root,int left,int right){ 32 lazy[ll(root)]+=lazy[root]; 33 lazy[rr(root)]+=lazy[root]; 34 sum[ll(root)]+=lazy[root]*(mid-left+1); 35 sum[rr(root)]+=lazy[root]*(right-mid); 36 lazy[root]=0; 37 } 38 39 void updata(int root,int left,int right,int l,int r,int val){ 40 if(l<=left&&right<=r){ 41 lazy[root]+=val; 42 sum[root]+=val*(right-left+1); 43 return; 44 } 45 if(lazy[root]) pushdown(root,left,right); 46 if(l<=mid) updata(ll(root),left,mid,l,r,val); 47 if(mid<r) updata(rr(root),mid+1,right,l,r,val); 48 sum[root]=sum[ll(root)]+sum[rr(root)]; 49 } 50 51 lol query(int root,int left,int right,int l,int r){ 52 if(l<=left&&right<=r) return sum[root]; 53 if(r<left||right<l) return 0; 54 if(lazy[root]) pushdown(root,left,right); 55 return query(ll(root),left,mid,l,r)+query(rr(root),mid+1,right,l,r); 56 } 57 58 int main(){ 59 int x,y,val,flag; 60 n=gi();m=gi(); 61 for(int i=1;i<=n;i++) w[i]=gi(); 62 build(1,1,n); 63 for(int i=1;i<=m;i++){ 64 flag=gi(); 65 if(flag==1){ 66 x=gi();y=gi();val=gi(); 67 updata(1,1,n,x,y,val); 68 } 69 if(flag==2){ 70 x=gi();y=gi(); 71 printf("%lld\n",query(1,1,n,x,y)); 72 } 73 } 74 return 0; 75 }
另外再贴一个结构体版的:
#include<bits/stdc++.h> #define ll(x) (x<<1) #define rr(x) (x<<1|1) using namespace std; const int N=400000+5; typedef long long lol; lol n, m; lol w[N]; struct seg_tree{ lol sum, l, r, lazy; }t[N]; lol gi(){ lol ans = 0 , f = 1; char i=getchar(); while(i<'0'||i>'9'){if(i=='-')f=-1;i=getchar();} while(i>='0'&&i<='9'){ans=ans*10+i-'0';i=getchar();} return ans * f; } void up(lol root){ t[root].sum = t[ll(root)].sum + t[rr(root)].sum; } void build(lol root,lol l,lol r){ int mid = l+r>>1; t[root].l = l , t[root].r = r; if(l == r){ t[root].sum = w[l]; return; } build(ll(root),l,mid); build(rr(root),mid+1,r); up(root); } void pushdown(lol root){ lol mid = t[root].l + t[root].r >> 1; t[ll(root)].lazy += t[root].lazy; t[rr(root)].lazy += t[root].lazy; t[ll(root)].sum += t[root].lazy*(mid-t[root].l+1); t[rr(root)].sum += t[root].lazy*(t[root].r-mid); t[root].lazy = 0; } void updata(lol root,lol l,lol r,lol val){ lol mid = t[root].l+t[root].r>>1; if(l<=t[root].l && t[root].r<=r){ t[root].sum += val * (t[root].r-t[root].l+1); t[root].lazy += val; return; } if(t[root].lazy) pushdown(root); if(l <= mid) updata(ll(root),l,r,val); if(mid < r) updata(rr(root),l,r,val); up(root); } lol query(lol root,lol l,lol r){ if(l<=t[root].l && t[root].r<=r) return t[root].sum; if(r<t[root].l || t[root].r<l) return 0; if(t[root].lazy) pushdown(root); return query(ll(root),l,r)+query(rr(root),l,r); } int main(){ lol f, x, y, val; n = gi(); m = gi(); for(lol i=1;i<=n;i++) w[i] = gi(); build(1,1,n); for(lol i=1;i<=m;i++){ f = gi(); x = gi(); y = gi(); if(f == 1) val = gi() , updata(1,x,y,val); else printf("%lld\n",query(1,x,y)); } return 0; }