线段树学习笔记
概述
本篇主要讲朴素线段树以及简单变式。
朴素线段树
时间 \(O(n \log n )\) ,空间开 \(4\) 倍,主要操作有 build , query , update , pushup , pushdown 。
公式
\[lc=p<<1
\]
\[rc=(p<<1)|1
\]
易错点
- 判断是否完全覆盖区间,大于等于和小于等于的方向:
ln<=tr[p].l&&rn>=tr[p].r
。 - query 与 update 都要 pushdown 。
- update 与 build 都要 pushup 。
- pushdown 与 update 时,注意给子节点加懒标记是用 += 而不是 = 。
- 注意初始化时要建树 :
build(1,1,n)
,最好一开始就写上,防止忘记。 - pushdown 操作可以放在函数最前面,防止忘记,但是在 pushup 和 pushdown 里加上边界特判。
- 线段长度为
r-l+1
。 - mid 应取
(l+r)>>1
。 - query 作为有返回值的函数,一样要用 long long 类型。
- update 与 query 函数中 ln,rn 的值恒定不变(存的是要查询的区间),只有 build 中的 ln,rn 要改变(存的是当前节点左右区间)。
- 懒标记是打在最后的节点上,前面的并不会打到,因此 update 操作要 pushup 。
- 到左儿子是判断 $lc \le mid $ ,到右儿子是判断 \(rc \ge mid+1\) ,必须加一。
- 如果不想在 pushup 和 pushdown 里加边界特判的话,就在删掉他们之后把 pushdown 操作放在 query 和 update 里面的 if 判断后面,不然会 RE 。
- lc 和 rc 里的 p 错写成右移 >> ,正确写法应该是左移 << 。
- build 中某个节点初值赋错,例如把
tr[p].v=a[ln]
写成tr[p].v=a[p]
。 - update 中,完全覆盖时修改完区间之后没有 return。
模板
#include <bits/stdc++.h>
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
const int N=100005;
ll a[N];
struct node{
int l,r;
ll sum,add=0;
}tr[4*N];
void pushup(int p)
{
tr[p].sum=0;
if(lc<4*N)tr[p].sum+=tr[lc].sum;
if(rc<4*N)tr[p].sum+=tr[rc].sum;
}
void pushdown(int p)
{
if(tr[p].add!=0)
{
if(lc<4*N)
{
tr[lc].add+=tr[p].add;
tr[lc].sum+=(tr[lc].r-tr[lc].l+1)*tr[p].add;
}
if(rc<4*N)
{
tr[rc].add+=tr[p].add;
tr[rc].sum+=(tr[rc].r-tr[rc].l+1)*tr[p].add;
}
}
tr[p].add=0;
}
void build(int p,int ln,int rn)
{
tr[p].l=ln,tr[p].r=rn,tr[p].add=0;
if(ln==rn)
{
tr[p].sum=a[ln];
return;
}
int mid=(ln+rn)>>1;
build(lc,ln,mid);
build(rc,mid+1,rn);
pushup(p);
}
ll query(int p,int ln,int rn)
{
pushdown(p);
if(ln<=tr[p].l&&rn>=tr[p].r)return tr[p].sum;
int mid=(tr[p].l+tr[p].r)>>1;
ll nsum=0;
if(ln<=mid)nsum+=query(lc,ln,rn);
if(rn>=mid+1)nsum+=query(rc,ln,rn);
return nsum;
}
void update(int p,int ln,int rn,ll k)
{
pushdown(p);
if(ln<=tr[p].l&&rn>=tr[p].r)
{
tr[p].add+=k;
tr[p].sum+=(tr[p].r-tr[p].l+1)*k;
return;
}
int mid=(tr[p].l+tr[p].r)>>1;
if(ln<=mid)update(lc,ln,rn,k);
if(rn>=mid+1)update(rc,ln,rn,k);
pushup(p);
}
int n,m;
int main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(m--)
{
int op;
cin>>op;
if(op==1)
{
ll x,y,k;
cin>>x>>y>>k;
update(1,x,y,k);
}
else
{
ll x,y;
cin>>x>>y;
cout<<query(1,x,y)<<endl;
}
}
return 0;
}
变式:线段树2 (多个懒标记)
对于新加的一个乘法操作,我们需要多加一个懒标记,那么在下传懒标记的时候就要考虑谁先谁后的问题。
于是进行分讨:
加法先,乘法后
如果要使结果正确,原来的数为 \(x\) ,父亲的加法懒标记为 \(fa\) ,父亲乘法懒标记为 \(fm\) ,自己加法懒标记为 \(a\) ,自己乘法懒标记为 \(m\) ,那么正确结果 \(res\) :
\[\begin{aligned}
res &= \left\{[(x+a)*m]+fa \right\}*fm \\
&= \left\{m*x+m*a+fa \right\}*fm \\
&= m*x*fm+m*a+fm+fa*fm \\
&= [x+(a+fa/m)]*m*fm
\end{aligned}
\]
先按加法优先的顺序把式子写出来,化为最简形式,并且最后化为同样加法优先且使结果正确的形式,那么这就是懒标记的下传。
因此加法优先,懒标记为:
\[a=a+fa/m
\]
\[m=m*fm
\]
其中 \(fa/m\) 为浮点数,精度不佳,不能用这种方式,舍去。
乘法先,加法后
如果要使结果正确,那么正确结果 \(res\) :
\[\begin{aligned}
res &= \left\{[(x*m)+a]*fm\right\}+fa \\
&= \left\{[x*m+a]*fm\right\}+fa \\
&= \left\{x*m*fm+a*fm\right\}+fa \\
&= x*m*fm+a*fm+fa \\
&= [x*(m*fm)]+(a*fm+fa)
\end{aligned}
\]
于是懒标记为:
\[a=a*fm+fa
\]
\[m=m*fm
\]
全是乘法,不会损失精度,因此可行。
下传懒标记后的 \(sum\)
首先因为先乘后加,所以我们可以写出以下式子:
\[\begin{aligned}
sum &= (x_{l}*fm+fa)+(x_{l+1}*fm+fa)+ ... + (x_{r-1}*fm+fa)+(x_{r}*fm+fa) \\
&=(x_{l}+x_{l+1}+...+x_{r-1}+x_{r})*fm + fa + fa+ ... + fa+fa \\
&= sum*fm+(r-l+1)*fa
\end{aligned}
\]
注意是父亲的懒标记 \(fm,fa\) ,不是自己的,因为自己的之前已经加过了。
代码
未简化边界特判版(大常数)
#include <bits/stdc++.h>
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
const int N=100005;
ll a[N],mod;
struct node{
int l,r;
ll sum,add=0,mul=1;
}tr[4*N];
void pushup(int p)
{
tr[p].sum=0;
if(lc<4*N)tr[p].sum+=tr[lc].sum;
if(rc<4*N)tr[p].sum+=tr[rc].sum;
tr[p].sum%=mod;
}
void cal(node &t,ll mu,ll ad)
{
t.sum=(t.sum*(mu%mod)%mod+(t.r-t.l+1)*ad%mod)%mod;
t.add=(t.add*mu%mod+ad%mod)%mod;
t.mul=t.mul*(mu%mod)%mod;
}
void pushdown(int p)
{
if(tr[p].add!=0 || tr[p].mul!=1)
{
if(lc<4*N)
{
cal(tr[lc],tr[p].mul,tr[p].add);
}
if(rc<4*N)
{
cal(tr[rc],tr[p].mul,tr[p].add);
}
}
tr[p].add=0;
tr[p].mul=1;
}
void build(int p,int ln,int rn)
{
tr[p].l=ln,tr[p].r=rn,tr[p].add=0,tr[p].mul=1;
if(ln==rn)
{
tr[p].sum=a[ln];
return;
}
int mid=(ln+rn)>>1;
build(lc,ln,mid);
build(rc,mid+1,rn);
pushup(p);
}
ll query(int p,int ln,int rn)
{
pushdown(p);
if(tr[p].l>=ln&&tr[p].r<=rn)return tr[p].sum;
int mid=(tr[p].l+tr[p].r)>>1;
ll nsum=0;
if(ln<=mid)nsum=(nsum+query(lc,ln,rn)%mod)%mod;
if(rn>=mid+1)nsum=(nsum+query(rc,ln,rn)%mod)%mod;
return nsum;
}
void update(int p,int ln,int rn,ll mu,ll ad)
{
pushdown(p);
if(tr[p].l>=ln&&tr[p].r<=rn)
{
cal(tr[p],mu,ad);
return;
}
int mid=(tr[p].l+tr[p].r)>>1;
if(ln<=mid)update(lc,ln,rn,mu,ad);
if(rn>=mid+1)update(rc,ln,rn,mu,ad);
pushup(p);
}
int n,q;
int main()
{
cin>>n>>q>>mod;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(q--)
{
int op;
cin>>op;
if(op==1)
{
ll x,y,k;
cin>>x>>y>>k;
update(1,x,y,k,0);
}
else if(op==2)
{
ll x,y,k;
cin>>x>>y>>k;
update(1,x,y,1,k);
}
else if(op==3)
{
ll x,y;
cin>>x>>y;
cout<<query(1,x,y)%mod<<endl;
}
}
return 0;
}
简化版(小常数)
#include <bits/stdc++.h>
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
const int N=100005;
ll a[N],mod;
struct node{
int l,r;
ll sum,add=0,mul=1;
}tr[4*N];
void pushup(int p)
{
tr[p].sum=(tr[lc].sum+tr[rc].sum)%mod;
}
void cal(node &t,ll mu,ll ad)
{
t.sum=(t.sum*(mu%mod)%mod+(t.r-t.l+1)*ad%mod)%mod;
t.add=(t.add*mu%mod+ad%mod)%mod;
t.mul=t.mul*(mu%mod)%mod;
}
void pushdown(int p)
{
cal(tr[lc],tr[p].mul,tr[p].add);
cal(tr[rc],tr[p].mul,tr[p].add);
tr[p].add=0;
tr[p].mul=1;
}
void build(int p,int ln,int rn)
{
tr[p].l=ln,tr[p].r=rn,tr[p].add=0,tr[p].mul=1;
if(ln==rn)
{
tr[p].sum=a[ln];
return;
}
int mid=(ln+rn)>>1;
build(lc,ln,mid);
build(rc,mid+1,rn);
pushup(p);
}
ll query(int p,int ln,int rn)
{
if(tr[p].l>=ln&&tr[p].r<=rn)return tr[p].sum;
pushdown(p);
int mid=(tr[p].l+tr[p].r)>>1;
ll nsum=0;
if(ln<=mid)nsum=(nsum+query(lc,ln,rn)%mod)%mod;
if(rn>=mid+1)nsum=(nsum+query(rc,ln,rn)%mod)%mod;
return nsum;
}
void update(int p,int ln,int rn,ll mu,ll ad)
{
if(tr[p].l>=ln&&tr[p].r<=rn)
{
cal(tr[p],mu,ad);
return;
}
pushdown(p);
int mid=(tr[p].l+tr[p].r)>>1;
if(ln<=mid)update(lc,ln,rn,mu,ad);
if(rn>=mid+1)update(rc,ln,rn,mu,ad);
pushup(p);
}
int n,q;
int main()
{
cin>>n>>q>>mod;
for(int i=1;i<=n;i++)cin>>a[i];
build(1,1,n);
while(q--)
{
int op;
cin>>op;
if(op==1)
{
ll x,y,k;
cin>>x>>y>>k;
update(1,x,y,k,0);
}
else if(op==2)
{
ll x,y,k;
cin>>x>>y>>k;
update(1,x,y,1,k);
}
else if(op==3)
{
ll x,y;
cin>>x>>y;
cout<<query(1,x,y)%mod<<endl;
}
}
return 0;
}
如果不想在 pushup 和 pushdown 里加边界特判的话,就在删掉他们之后把 pushdown 操作放在 query 和 update 里面的 if 判断后面,不然会 RE 。
朴素线段树例题
XOR的艺术:\(lazytag\) 变成取反的标志,\(sum\) 更新为 \((r-l+1)-sum\) 。
忠诚:去掉 update 和 pushdown 以及懒标记,区间答案设为其最小值即可。