zkw线段树
强行学习zkw线段树qwq
参考资料:强烈安利zkw线段树出处 zkw的 统计的力量
以及推荐某位神犇的博客, 讲解很详细qwq
线段树==树状数组??
下面代码里维护的有最大值、最小值、区间和。
一、建树:
zkw线段树使用堆式存储qwq
首先你要写个循环,让\(m\)(非叶节点数)大于\(n\)(叶结点数),以此保证这棵树的叶子能够容纳你要维护的\(n\)个值
然后你要从\(m\)倒推到 \(1\) 号节点(注意是\(M\)倒推回\(1\) ,保证维护每个节点时该节点的孩子都已经被维护完毕),让每个节点维护它左右孩子的信息。
直接一波艹完叶子然后建树 不递归。
下面代码里维护的是最大值、最小值、区间和。
inline void build(int n){
for(m=1; m<n; m<<=1);//开大空间
for(int i=m+1; i<=m+n; i++) a[i]=read();//叶子节点|原始数组
for(int i=m-1; i; --i){
sum[i]=a[i<<1]+a[i<<1|1];
mn[i]=min(mn[i<<1],mn[i<<1|1]);
mx[i]=max(mx[i<<1],mx[i<<1|1]);
}
}
但这样写不支持修改qwq所以差分来写
注意从这里开始下面都用了差分的办法
inline void build(){
for(m=1; m<=n; m<<=1);
for(int i=m+1; i<=m+n; ++i)
sum[i]=mn[i]=mx[i]=read();
for(int i=m-1; i; --i){
sum[i]=sum[i<<1]+sum[i<<1|1];
mn[i]=min(mn[i<<1],mn[i<<1|1]),
mn[i<<1]-=mn[i],mn[i<<1|1]-=mn[i];
mx[i]=max(mx[i<<1],mx[i<<1|1]),
mx[i<<1]-=mx[i],mx[i<<1|1]-=mx[i];
}
}
二、修改操作:
单点修改:
维护差分数组 有点类似树状数组
inline void update_node(int x,int v,int A=0){
x+=m,mx[x]+=v,mn[x]+=v,sum[x]+=v;
for(; x>1; x>>=1){
sum[x]+=v;
A=min(mn[x],mn[x^1]);
mn[x]-=A,mn[x^1]-=A,mn[x>>1]+=A;
A=max(mx[x],mx[x^1]),
mx[x]-=A,mx[x^1]-=A,mx[x>>1]+=A;
}
}
区间修改:
修改\([s, t]\)这段区间,我们转化为\((s-1, t+1)\)这样的开区间处理
定义 \(lc\) 代表左端点所处的节点下有多少长度的区间在更新区间内, \(rc\) 同理 ,通俗一点地说,就是 \(s\) 和 \(t\) 所分别走过的节点中包含的更新过的区间的总长
inline void update_part(int s, int t, int v){
int A=0, lc=0, rc=0, len=1;
for(s+=m-1,t+=m+1; s^t^1; s>>=1,t>>=1,len<<=1){ //add就是标记数组
if(s&1^1) add[s^1]+=v,lc+=len, mn[s^1]+=v,mx[s^1]+=v;
if(t&1) add[t^1]+=v,rc+=len, mn[t^1]+=v,mx[t^1]+=v;
sum[s>>1]+=v*lc, sum[t>>1]+=v*rc;
A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A,
A=min(mn[t],mn[t^1]),mn[t]-=A,mn[t^1]-=A,mn[t>>1]+=A;
A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A,
A=max(mx[t],mx[t^1]),mx[t]-=A,mx[t^1]-=A,mx[t>>1]+=A;
}
for(lc+=rc; s; s>>=1){
sum[s>>1]+=v*lc;
A=min(mn[s],mn[s^1]),mn[s]-=A,mn[s^1]-=A,mn[s>>1]+=A,
A=max(mx[s],mx[s^1]),mx[s]-=A,mx[s^1]-=A,mx[s>>1]+=A;
}
}
三、查询操作
单点查询:
把差分的 加回来 有点类似树状数组
inline int query_mn_node(int x, int ans=0){
for(x+=m; x; x>>=1) ans+=mn[s];
return ans;
}
inline int query_mx_node(int x, int ans=0){
for(x+=m; x; x>>=1) ans+=mx[s];
return ans;
}
inline int query_sum_node(int x, int ans=0){
for(x+=m; x; x>>=1) ans+=sum[s];
return ans;
}
区间查询:
被我咕掉了qwq
先上代码qwq
inline int query_sum(int s,int t){
int lc=0,rc=0,len=1,ans=0;
for(s+=m-1,t+=m+1;s^t^1;s>>=1,t>>=1,len<<=1){
if(s&1^1) ans+=sum[s^1]+len*add[s^1],lc+=len;
if(t&1) ans+=sum[t^1]+len*add[t^1],rc+=len;
if(add[s>>1]) ans+=add[s>>1]*lc;
if(add[t>>1]) ans+=add[t>>1]*rc;
}
for(lc+=rc,s>>=1;s;s>>=1)
if(add[s]) ans+=add[s]*lc;
return ans;
}
inline int query_min(int s,int t,int L=0,int R=0,int ans=0){
if(s==t) return query_node(s); // 单点要特判, 下同
for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){ // 这里 s 和 t 直接加上 m
L+=mn[s],R+=mn[t];
if(s&1^1) L=min(L,mn[s^1]);
if(t&1) R=min(R,mn[t^1]);
}
for(ans=min(L,R),s>>=1;s;s>>=1) ans+=mn[s];
return ans;
}
inline int query_max(int s,int t,int L=0,int R=0,int ans=0){
if(s==t) return query_node(s);
for(s+=m,t+=m;s^t^1;s>>=1,t>>=1){
L+=mx[s],R+=mx[t];
if(s&1^1) L=max(L,mx[s^1]);
if(t&1) R=max(R,mx[t^1]);
}
for(ans=max(L,R),s>>=1;s;s>>=1) ans+=mx[s];
return ans;
}
然后 其他比较强大的用途 都被我咕咕咕咕咕qwq
不如吃茶去