线段树笔记√
线段树上的每一点表示一段区间和。
建树
首先,递归进去做,递归的参数是pos,l,r,分别表示,线段树上节点的编号(即当前编号),以及这个点表示的区间的左端点和右端点。那么终止的条件就是l=r,这个时候,node[pos].sum=a[l]
我们考虑一下l不等于r的时候,那么这个区间的左儿子就是[l,mid],右儿子就是[mid+1,r]
然后考虑一下左边儿子的节点编号是什么
我们之间令编号为x的节点的左儿子为x*2,右儿子为x*2+1
那么只有对左右儿子递归进去做就行了,做完了之后,我们更新一下sum
【注意】线段树的节点个数要开成序列长度的四倍
1 struct data{ 2 int sum; 3 }node[400001]; 4 void build(int pos,int l,int r) 5 { 6 if(l==r){ 7 node[pos].sum=a[l]; 8 return; 9 } 10 int mid=l+r>>1;//µÈ¼Û(l+r)/2,λÔËËãÓÅÏȼ¶×îµÍ 11 int lson=pos*2,rson=pos*2+1; 12 build(lson,l,mid); 13 build(rson,mid+1,r); 14 node[pos].sum=node[pos*2].sum+node[pos*2+1].sum; 15 }
void build(int pos,int l,int r) { ll[pos]=l,rr[pos]=r; if(l==r) { max1[pos]=a[l]; return; } int mid=l+r>>1; build(pos<<1,l,mid); build(pos<<1|1,mid+1,r); max1[pos]=MAX(max1[pos<<1],max1[pos<<1|1]); }
查询
比如说,你要查区间[5,6]的和,你会发现,这就是线段树上的一个节点,那直接找到这个节点就行。
//pos表示当前走到了哪个节点,l,r,表示这个节点所代表的区间,ql,qr表示你要查询的区间
我们考虑一下终止条件,是l=ql,r=qr。这个时候我们就可以直接返回node的信息,考虑一下如果不是,有哪些情况,首先我们先算出当前这段区间的中点,考虑一下,如果qr<=mid,那么我们要查询的区间都在左儿子里,这时候直接返回查询左儿子的值就行了。
还有什么情况呢,如果ql>mid,那是不是整个区间都在右儿子里,这时候就直接查右儿子就行了。
那考虑一下,如果查询的区间是一部分在左儿子里,一部分在右儿子里,要怎么办?
我们把查询的区间切开,原本要查[ql,qr],现在变成[ql,mid]+[mid+1,qr],那就是在左儿子里查询[ql,mid],在右儿子里查询[mid+1,qr]
1 int query(int pos,int l,int r,int ql,int qr) 2 { 3 if(l==ql&&r==qr)return node[pos].sum; 4 int mid=l+r>>1,lson=pos*2,rson=pos*2+1; 5 if(qr<=mid)return query(lson,l,mid,ql,qr); 6 else if(ql>mid) return query(rson,mid+1,r,ql,qr); 7 else return query(lson,l,mid,ql,mid)+query(rson,mid+1,r,mid+1,qr); 8 }
int query(int pos,int l,int r)//sum { l=max(ll[pos],l),r=min(rr[pos],r); if(l>r)return 0;//查询max就-inf,min就inf,sum就0 if(l==ll[pos]&&r==rr[pos]) return max1[pos]; int mid=l+r>>1; return (query(pos<<1 , l , r)+query(pos<<1|1 , l , r); }
int query(int pos,int l,int r)//max { l=max(ll[pos],l),r=min(rr[pos],r); if(l>r)return -1*inf;//查询max就-inf,min就inf,sum就0 if(l==ll[pos]&&r==rr[pos]) return max1[pos]; int mid=l+r>>1; return max(query(pos<<1 , l , r) , query(pos<<1|1 , l , r)); }
int query(int pos,int l,int r)//min { l=max(ll[pos],l),r=min(rr[pos],r); if(l>r)return inf; if(l==ll[pos]&&r==rr[pos])return min1[pos]; int mid=l+r>>1; return min(query(pos<<1,l,r),query(pos<<1|1,l,r)); }
修改
修改有两种情况,单点修改和区间修改。
单点修改:首先还是pos表示当前节点,l,r表示当前节点代表的区间,m表示要修改的位置,显然递归终止的条件是l=r。那考虑一下,其实一个点,要么在mid左边 要么在mid右边,直接判断一下在哪边递归进去做就行了
如果l=r,说明找到这个点了,sum+=v,然后退出
否则的话判断m是不是<=mid
是的话说明在左儿子里,递归下去修改。
不是的话,反之(模仿上述处理)
修改完了之后要注意,要记得更新节点的信息,这个点的和等于左儿子的和加上右儿子的和。
1 int modify_dot(int pos,int l,int r,int m,int v) 2 { 3 if(l==r){ 4 node[pos].sum 5 return; 6 } 7 int mid=l+r>>1,lson=pos*2,rson=pos*2+1; 8 if(m<=mid)modify_dot(lson,l,mid,m,v); 9 else modify_dot(rson,mid+1,r,m,v); 10 node[pos].sum=node[lson].sum+node[rson].sum; 11 12 }
int n,k; int a[200005]; int ll[maxn],rr[maxn]; int max1[maxn],min1[maxn],sum[maxn]; int lazy[maxn];//下传标记->区间修改用的,初值0 void downpush(int x) { if(ll[x]==rr[x])return; if(lazy[x]) { sum[x<<1]+=lazy[x]; lazy[x<<1]+=lazy[x]; sum[x<<1|1]+=lazy[x]; lazy[x<<1|1]+=lazy[x]; lazy[x]=0; } //lazyx表示x被加了lazyx 遍历到这个节点的时候顺便down一下维护他的儿子 } void modify(int x,int l,int r,int k)//区间修改 当前节点x,l~r区间加k { downpush(x); l=max(l,ll[x]),r=min(r,rr[x]); if(l>r)return; if(l==ll[x]&&r==rr[x]) { lazy[x]=k; sum[x]+=k*(r-l+1); } else modify(x<<1,l,r,k),modify(x<<1|1,l,r,k); }
e.g.
先给你n,m表示序列长度,和操作次数,接下来n个数,表示原序列,接下来m行,表示操作
如果输入格式是1 l r
表示查询l到r的和
如果输入格式是2 l r
表示查询l到r的最小值
如果输入格式是3 l r
表示查询l到r的最大值
如果输入格式是4 m v
表示序列中第m个数变成v
#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> using namespace std; long long a[100005]; struct data { int sum,max,min; } node[400001]; void build(int pos,int l,int r) { if(l==r) { node[pos].sum=a[l]; node[pos].max=a[l]; node[pos].min=a[l]; return; } int mid=l+r>>1; build(pos*2,l,mid); build(pos*2+1,mid+1,r); node[pos].sum=node[pos*2].sum+node[pos*2+1].sum; node[pos].max=node[pos*2].max>node[pos*2+1].max?node[pos*2].max:node[pos*2+1].max; node[pos].min=node[pos*2].min>node[pos*2+1].min?node[pos*2+1].min:node[pos*2].min; } int query(int pos,int l,int r,int ql,int qr,int x) { if (l==ql && r==qr) if (x==3) return node[pos].max; else if (x==2) return node[pos].min; else if (x==1) return node[pos].sum; else; else { int mid=l+r>>1; if (qr<=mid) return query(pos*2,l,mid,ql,qr,x); else if (ql>mid) return query(pos*2+1,mid+1,r,ql,qr,x); else { int a=query(pos*2,l,mid,ql,mid,x),b=query(pos*2+1,mid+1,r,mid+1,qr,x); if (x==3) return a>b?a:b; else if (x==2) return a>b?b:a; else if (x==1) return a+b; } } } void modify_dot(int pos,int v,int n) { node[pos].max=v; node[pos].min=v; node[pos].sum=v; while (pos!=1) { if (pos%2==0) { //×ó node[pos/2].max=node[pos+1].max>node[pos].max?node[pos+1].max:node[pos].max; node[pos/2].min=node[pos+1].min<node[pos].min?node[pos+1].min:node[pos].min; node[pos/2].sum=node[pos].sum+node[pos+1].sum; } else { //ÓÒ node[pos/2].max=node[pos-1].max>node[pos].max?node[pos-1].max:node[pos].max; node[pos/2].min=node[pos-1].min<node[pos].min?node[pos-1].min:node[pos].min; node[pos/2].sum=node[pos].sum+node[pos-1].sum; } pos=pos/2;//°Ö°Ö } return; } int getpos(int pos,int l,int r,int m) { if (l==r && l==m) return pos; int mid=l+r>>1; if (m<=mid) return getpos(pos*2,l,mid,m); else return getpos(pos*2+1,mid+1,r,m); } int main() { int n,m; memset(a,0,sizeof(a)); scanf("%d%d",&n,&m); for(int i=1; i<=n; i++)scanf("%d",&a[i]); build(1,1,n); /* for (int j=1; j<=5; ++j) printf("%d µÄmax:%d min:%d sum:%d\n",j,node[j].max,node[j].min,node[j].sum); printf("\n"); */ for(int i=1; i<=m; i++) { int x,ql,qr; scanf("%d%d%d",&x,&ql,&qr); if(x==1) { printf("%d\n",query(1,1,n,ql,qr,1)); } if(x==2) { printf("%d\n",query(1,1,n,ql,qr,2)); } if(x==3) { printf("%d\n",query(1,1,n,ql,qr,3)); } if(x==4) { int pos=getpos(1,1,n,ql); modify_dot(pos,qr,n); } } return 0; }