zkw线段树学习笔记
ZKW线段树
应某迪要求,写一篇数据结构学习笔记。
实际上还没有学很多东西,只是一些基础的操作。
zkw线段树的学习资料,网上有很多,这里记录的只是自己的一些理解。
建树
1 inline void build(){ 2 for(bit=1,n=read();bit<=n+1;bit<<=1); 3 for(int i=bit+1;i<=bit+n;++i) sum[i]=read(); 4 for(int i=bit-1;i;--i) sum[i]=sum[i<<1]+sum[i<<1|1]; 5 }
$zkw$线段树构造了一棵完美二叉树,只有最后一层叶子节点管辖的区间大小为1。
$zkw$线段树是基于位运算的,对于节点$p$,$p<<1$为它的左儿子,$p<<1|1$为它的右儿子。
因为是一棵完美二叉树,除掉叶子节点的部分一定为$2^k-1$的形式,将这个$2^k$记为$bit$,可以方便我们之后的操作。
其意义是,对于原序列的点$i$,可以直接得到对应线段树上的节点$i+bit$。
注意这里我们忽略了$bit$也就是$2^k$这一个节点,以后再提。
同时建树的一个细节是$bit$应当大于$n+1$,其原因也可以留到后面。
单点修改
1 inline void modify(int p,int val){ 2 for(p+=bit;p;p>>=1) sum[p]+=val; 3 }
找到位置之后,直接修改一条祖先链。
区间修改
1 inline void modify(int l,int r,int val){ 2 int lc=0,rc=0,len=1; 3 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 4 sum[l]+=lc*val; sum[r]+=rc*val; 5 if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,lc+=len; 6 if(r&1) sum[r^1]+=len*val,add[r^1]+=val,rc+=len; 7 } 8 for(;l;l>>=1,r>>=1) sum[l]+=lc*val,sum[r]+=rc*val; 9 }
$lc$:当前左指针包含的区间长度。$rc$:当前右指针包含的区间长度。$len$:当前翻到的节点层管辖区间的长度。
这里我们将$l$,$r$都作为开区间。所以分别加$bit-1$,$bit+1$处理。
因为操作是自下而上进行的,$zkw$线段树一般不维护懒标记,
因而我们用一个数组$add$进行标记永久化,表示这个区间的所有序列应该被加上这个值,显然这个值是不能下传的。
当$l$的最后一位为0,也就是说$l$指针为左儿子,那么l的右兄弟在当前修改的区间内。
同理$r$的左兄弟会在修改区间内。
当$l$,$r$两个指针已经成为兄弟,也就是说二者在二进制下只有最后一位不同,即异或值为1,那么全部的修改操作已经完成,可以结束。
然而祖先链上的$sum$值仍然需要修改。
这里可以解释,为什么$bit$应该大于$n+1$而不是$n$,为什么$bit+0$这个节点需要被空出来,因为我们需要开区间来进行操作。
然而似乎使$bit$仅保证大于$n$的打法是正确的,手玩确实没有错误。
单点查询
1 inline int query(int p){ 2 int ans=0; 3 for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p]; 4 return ans; 5 }
统计叶子节点的$sum$值,并不断加上祖先链的$add$标记即可。
应当注意的是不要加上叶子节点的$add$标记,这个标记是无意义的。
区间查询
1 inline int query(int l,int r){ 2 int ans=0,lc=0,rc=0,len=1; 3 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 4 ans+=add[l]*lc+add[r]*rc; 5 if(~l&1) ans+=sum[l^1],lc+=len; 6 if(r&1) ans+=sum[r^1],rc+=len; 7 } 8 for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc; 9 return ans; 10 }
区间查询的打法是类似于区间修改的。
首先将$l$,$r$设为开区间。
不断翻祖先链,记得加上兄弟节点整体的$sum$值和祖先链上部分的$add$标记就可以了。
应当注意的是循环中统计$add$标记和兄弟$sum$值的顺序不可交换,否则可能导致$lc$,$rc$变量维护的含义错误。
区间最值
思想大概是将儿子的最值不断差分到父亲身上。
因为不同的部分已经被差分掉,可以直接修改区间的最值。
应当注意的是,区间最值的求法与区间求和不同。
为了减少特判,原本的开区间被转化为闭区间。
但这样产生一个问题,如果查询区间长度为1会导致一些问题:左右端点永远不会成为兄弟,故导致了死循环。
所以要加一个单点查询的特判。
因为要维护区间最值,修改操作同时也要不断差分,于是打的麻烦了许多,代码可以参考下面。
基础操作
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int N=1e6+7; 4 inline int read(register int x=0,register char ch=getchar(),register int f=0){ 5 while(!isdigit(ch)) f=ch=='-',ch=getchar(); 6 while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); 7 return f?-x:x; 8 } 9 int n,m,bit; 10 int sum[N<<2],add[N<<2],mn[N<<2],mx[N<<2]; 11 inline void build(){ 12 for(bit=1;bit<=n+1;bit<<=1); 13 for(int i=bit+1;i<=bit+n;++i) mx[i]=mn[i]=sum[i]=read(); 14 for(int i=bit-1;i;--i){ 15 sum[i]=sum[i<<1]+sum[i<<1|1]; 16 mn[i]=min(mn[i<<1],mn[i<<1|1]); mn[i<<1]-=mn[i]; mn[i<<1|1]-=mn[i]; 17 mx[i]=max(mx[i<<1],mx[i<<1|1]); mx[i<<1]-=mx[i]; mx[i<<1|1]-=mx[i]; 18 } 19 } 20 inline int query(int p){ 21 int ans=0; 22 for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p]; 23 return ans; 24 } 25 inline int query(int l,int r){ 26 int ans=0,lc=0,rc=0,len=1; 27 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 28 ans+=add[l]*lc+add[r]*rc; 29 if(~l&1) ans+=sum[l^1],lc+=len; 30 if(r&1) ans+=sum[r^1],rc+=len; 31 } 32 for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc; 33 return ans; 34 } 35 inline int query_min(int l,int r){ 36 if(l==r) return query(l); 37 int lans=0,rans=0; 38 for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){ 39 lans+=mn[l]; rans+=mn[r]; 40 if(~l&1) lans=min(lans,mn[l^1]); 41 if(r&1) rans=min(rans,mn[r^1]); 42 } 43 for(lans=min(lans+mn[l],rans+mn[r]),l>>=1;l;l>>=1) lans+=mn[l]; 44 return lans; 45 } 46 inline int query_max(int l,int r){ 47 if(l==r) return query(l); 48 int lans=0,rans=0; 49 for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){ 50 lans+=mx[l]; rans+=mx[r]; 51 if(~l&1) lans=max(lans,mx[l^1]); 52 if(r&1) rans=max(rans,mx[r^1]); 53 } 54 for(lans=max(lans+mx[l],rans+mx[r]),l>>=1;l;l>>=1) lans+=mx[l]; 55 return lans; 56 } 57 inline void modify(int l,int r,int val){ 58 int lc=0,rc=0,len=1,x; 59 for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){ 60 sum[l]+=lc*val; sum[r]+=rc*val; 61 if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,mn[l^1]+=val,lc+=len; 62 if(r&1) sum[r^1]+=len*val,add[r^1]+=val,mn[r^1]+=val,rc+=len; 63 x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x; 64 x=min(mn[r],mn[r^1]); mn[r]-=x; mn[r^1]-=x; mn[r>>1]+=x; 65 } 66 for(;l;l>>=1,r>>=1){ 67 sum[l]+=lc*val; sum[r]+=rc*val; 68 x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x; 69 x=max(mx[l],mx[l^1]); mx[l]-=x; mx[l^1]-=x; mx[l>>1]+=x; 70 } 71 } 72 inline void modify(int p,int val){ 73 int x; 74 for(p+=bit;p;p>>=1){ 75 sum[p]+=val; mn[p]+=val; mx[p]+=val; 76 x=min(mn[p],mn[p^1]); mn[p]-=x; mn[p^1]-=x; mn[p>>1]+=x; 77 x=max(mx[p],mx[p^1]); mx[p]-=x; mx[p^1]-=x; mx[p>>1]+=x; 78 } 79 } 80 int main(){ 81 n=read(); 82 build(); 83 return 0; 84 }
区间信息合并(山海经)
思想大概与区间查询最值一致。
为了减少特判将开区间转化为闭区间。
需要注意的是信息合并有左右的先后顺序。
所以左右指针的写法并不相同,最后将$l$,$r$扫过的信息合并就可以了。
1 #include<bits/stdc++.h> 2 const int N=100010; 3 inline int read(register int x=0,register char ch=getchar(),bool f=0){ 4 while(!isdigit(ch)) f=ch=='-',ch=getchar(); 5 while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); 6 return f?-x:x; 7 } 8 struct Ans{ 9 int l,r,val; 10 }; 11 struct Node{ 12 Ans la,ra,mx,tot; 13 }s[N<<2]; 14 int n,m,bit; 15 inline bool operator <(const Ans &a,const Ans &b){ 16 return a.val<b.val||(a.val==b.val&&a.l>b.l)||(a.val==b.val&&a.l==b.l&&a.r>b.r); 17 } 18 inline Ans operator +(const Ans &a,const Ans &b){ 19 return (Ans){a.l,b.r,a.val+b.val}; 20 } 21 inline Node operator +(const Node &a,const Node &b){ 22 return (Node){std::max(a.la,a.tot+b.la),std::max(b.ra,a.ra+b.tot),std::max(std::max(a.mx,b.mx),a.ra+b.la),a.tot+b.tot}; 23 } 24 void build(){ 25 for(bit=1;bit<=n+1;bit<<=1); 26 for(int i=bit+1;i<=bit+n;++i) s[i].tot=s[i].mx=s[i].la=s[i].ra=(Ans){i-bit,i-bit,read()}; 27 for(int i=bit-1;i;--i) s[i]=s[i<<1]+s[i<<1|1]; 28 } 29 Ans query(int r,int l){ 30 if(l==r) return s[l+bit].mx; 31 Node L=s[l+bit],R=s[r+bit]; 32 for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){ 33 if(~l&1) L=L+s[l^1]; 34 if(r&1) R=s[r^1]+R; 35 } 36 return (L+R).mx; 37 } 38 void print(const Ans &x){ 39 printf("%d %d %d\n",x.l,x.r,x.val); 40 } 41 int main(){ 42 n=read(); m=read(); build(); 43 for(int i=1;i<=m;++i) print(query(read(),read())); 44 return 0; 45 }