线段树学习笔记

概述

本篇主要讲朴素线段树以及简单变式。

朴素线段树

时间 \(O(n \log n )\) ,空间开 \(4\) 倍,主要操作有 build , query , update , pushup , pushdown 。

公式

\[lc=p<<1 \]

\[rc=(p<<1)|1 \]

易错点

  1. 判断是否完全覆盖区间,大于等于和小于等于的方向:ln<=tr[p].l&&rn>=tr[p].r
  2. query 与 update 都要 pushdown 。
  3. update 与 build 都要 pushup 。
  4. pushdown 与 update 时,注意给子节点加懒标记是用 += 而不是 = 。
  5. 注意初始化时要建树 : build(1,1,n) ,最好一开始就写上,防止忘记。
  6. pushdown 操作可以放在函数最前面,防止忘记,但是在 pushup 和 pushdown 里加上边界特判。
  7. 线段长度为 r-l+1
  8. mid 应取 (l+r)>>1
  9. query 作为有返回值的函数,一样要用 long long 类型。
  10. update 与 query 函数中 ln,rn 的值恒定不变(存的是要查询的区间),只有 build 中的 ln,rn 要改变(存的是当前节点左右区间)。
  11. 懒标记是打在最后的节点上,前面的并不会打到,因此 update 操作要 pushup 。
  12. 到左儿子是判断 $lc \le mid $ ,到右儿子是判断 \(rc \ge mid+1\) ,必须加一。
  13. 如果不想在 pushup 和 pushdown 里加边界特判的话,就在删掉他们之后把 pushdown 操作放在 query 和 update 里面的 if 判断后面,不然会 RE 。
  14. lc 和 rc 里的 p 错写成右移 >> ,正确写法应该是左移 << 。
  15. build 中某个节点初值赋错,例如把 tr[p].v=a[ln] 写成 tr[p].v=a[p]
  16. update 中,完全覆盖时修改完区间之后没有 return。

模板

#include <bits/stdc++.h>
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
const int N=100005;
ll a[N];
struct node{
	int l,r;
	ll sum,add=0;
}tr[4*N];
void pushup(int p)
{
	tr[p].sum=0;
	if(lc<4*N)tr[p].sum+=tr[lc].sum;
	if(rc<4*N)tr[p].sum+=tr[rc].sum;
}
void pushdown(int p)
{
	if(tr[p].add!=0)
	{
		if(lc<4*N)
		{
			tr[lc].add+=tr[p].add;
			tr[lc].sum+=(tr[lc].r-tr[lc].l+1)*tr[p].add;
		}
		if(rc<4*N)
		{
			tr[rc].add+=tr[p].add;
			tr[rc].sum+=(tr[rc].r-tr[rc].l+1)*tr[p].add;
		}
	}
	tr[p].add=0;
}
void build(int p,int ln,int rn)
{
	tr[p].l=ln,tr[p].r=rn,tr[p].add=0;
	if(ln==rn)
	{
		tr[p].sum=a[ln];
		return;
	}
	int mid=(ln+rn)>>1;
	build(lc,ln,mid);
	build(rc,mid+1,rn);
	pushup(p);
}
ll query(int p,int ln,int rn)
{
	pushdown(p);
	if(ln<=tr[p].l&&rn>=tr[p].r)return tr[p].sum;
	int mid=(tr[p].l+tr[p].r)>>1;
	ll nsum=0;
	if(ln<=mid)nsum+=query(lc,ln,rn);
	if(rn>=mid+1)nsum+=query(rc,ln,rn);
	return nsum;
}
void update(int p,int ln,int rn,ll k)
{
	pushdown(p);
	if(ln<=tr[p].l&&rn>=tr[p].r)
	{
		tr[p].add+=k;
		tr[p].sum+=(tr[p].r-tr[p].l+1)*k;
		return;
	}
	int mid=(tr[p].l+tr[p].r)>>1;
	if(ln<=mid)update(lc,ln,rn,k);
	if(rn>=mid+1)update(rc,ln,rn,k);
	pushup(p);
}
int n,m;
int main()
{
	cin>>n>>m;
	for(int i=1;i<=n;i++)cin>>a[i];
	build(1,1,n);
	while(m--)
	{
		int op;
		cin>>op;
		if(op==1)
		{
			ll x,y,k;
			cin>>x>>y>>k;
			update(1,x,y,k);
		}
		else
		{
			ll x,y;
			cin>>x>>y;
			cout<<query(1,x,y)<<endl;
		}
	}
	return 0;
}

变式:线段树2 (多个懒标记)

对于新加的一个乘法操作,我们需要多加一个懒标记,那么在下传懒标记的时候就要考虑谁先谁后的问题。
于是进行分讨:

加法先,乘法后

如果要使结果正确,原来的数为 \(x\) ,父亲的加法懒标记为 \(fa\) ,父亲乘法懒标记为 \(fm\) ,自己加法懒标记为 \(a\) ,自己乘法懒标记为 \(m\) ,那么正确结果 \(res\)

\[\begin{aligned} res &= \left\{[(x+a)*m]+fa \right\}*fm \\ &= \left\{m*x+m*a+fa \right\}*fm \\ &= m*x*fm+m*a+fm+fa*fm \\ &= [x+(a+fa/m)]*m*fm \end{aligned} \]

先按加法优先的顺序把式子写出来,化为最简形式,并且最后化为同样加法优先且使结果正确的形式,那么这就是懒标记的下传。

因此加法优先,懒标记为:

\[a=a+fa/m \]

\[m=m*fm \]

其中 \(fa/m\) 为浮点数,精度不佳,不能用这种方式,舍去。

乘法先,加法后

如果要使结果正确,那么正确结果 \(res\) :

\[\begin{aligned} res &= \left\{[(x*m)+a]*fm\right\}+fa \\ &= \left\{[x*m+a]*fm\right\}+fa \\ &= \left\{x*m*fm+a*fm\right\}+fa \\ &= x*m*fm+a*fm+fa \\ &= [x*(m*fm)]+(a*fm+fa) \end{aligned} \]

于是懒标记为:

\[a=a*fm+fa \]

\[m=m*fm \]

全是乘法,不会损失精度,因此可行。

下传懒标记后的 \(sum\)

首先因为先乘后加,所以我们可以写出以下式子:

\[\begin{aligned} sum &= (x_{l}*fm+fa)+(x_{l+1}*fm+fa)+ ... + (x_{r-1}*fm+fa)+(x_{r}*fm+fa) \\ &=(x_{l}+x_{l+1}+...+x_{r-1}+x_{r})*fm + fa + fa+ ... + fa+fa \\ &= sum*fm+(r-l+1)*fa \end{aligned} \]

注意是父亲的懒标记 \(fm,fa\) ,不是自己的,因为自己的之前已经加过了。

代码

未简化边界特判版(大常数)

#include <bits/stdc++.h>
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
const int N=100005;
ll a[N],mod;
struct node{
	int l,r;
	ll sum,add=0,mul=1;
}tr[4*N];
void pushup(int p)
{
	tr[p].sum=0;
	if(lc<4*N)tr[p].sum+=tr[lc].sum;
	if(rc<4*N)tr[p].sum+=tr[rc].sum;
	tr[p].sum%=mod;
}
void cal(node &t,ll mu,ll ad)
{
	t.sum=(t.sum*(mu%mod)%mod+(t.r-t.l+1)*ad%mod)%mod;
	t.add=(t.add*mu%mod+ad%mod)%mod;
	t.mul=t.mul*(mu%mod)%mod;
}
void pushdown(int p)
{
	if(tr[p].add!=0 || tr[p].mul!=1)
	{
		if(lc<4*N)
		{
			cal(tr[lc],tr[p].mul,tr[p].add);
		}
		if(rc<4*N)
		{
			cal(tr[rc],tr[p].mul,tr[p].add);
		}
	}
	tr[p].add=0;
	tr[p].mul=1;
}
void build(int p,int ln,int rn)
{
	tr[p].l=ln,tr[p].r=rn,tr[p].add=0,tr[p].mul=1;
	if(ln==rn)
	{
		tr[p].sum=a[ln];
		return;
	}
	int mid=(ln+rn)>>1;
	build(lc,ln,mid);
	build(rc,mid+1,rn);
	pushup(p);
}
ll query(int p,int ln,int rn)
{
	pushdown(p);
	if(tr[p].l>=ln&&tr[p].r<=rn)return tr[p].sum;
	int mid=(tr[p].l+tr[p].r)>>1;
	ll nsum=0;
	if(ln<=mid)nsum=(nsum+query(lc,ln,rn)%mod)%mod;
	if(rn>=mid+1)nsum=(nsum+query(rc,ln,rn)%mod)%mod;
	return nsum;
}
void update(int p,int ln,int rn,ll mu,ll ad)
{
	pushdown(p);
	if(tr[p].l>=ln&&tr[p].r<=rn)
	{
		cal(tr[p],mu,ad);
		return;
	}
	int mid=(tr[p].l+tr[p].r)>>1;
	if(ln<=mid)update(lc,ln,rn,mu,ad);
	if(rn>=mid+1)update(rc,ln,rn,mu,ad);
	pushup(p);
}
int n,q;
int main()
{
	cin>>n>>q>>mod;
	for(int i=1;i<=n;i++)cin>>a[i];
	build(1,1,n);
	while(q--)
	{
		int op;
		cin>>op;
		if(op==1)
		{
			ll x,y,k;
			cin>>x>>y>>k;
			update(1,x,y,k,0);
		}
		else if(op==2)
		{
			ll x,y,k;
			cin>>x>>y>>k;
			update(1,x,y,1,k);			
		}
		else if(op==3)
		{
			ll x,y;
			cin>>x>>y;
			cout<<query(1,x,y)%mod<<endl;
		}
	}
	return 0;
}

简化版(小常数)

#include <bits/stdc++.h>
#define lc (p<<1)
#define rc ((p<<1)|1)
using namespace std;
typedef long long ll;
const int N=100005;
ll a[N],mod;
struct node{
	int l,r;
	ll sum,add=0,mul=1;
}tr[4*N];
void pushup(int p)
{
	tr[p].sum=(tr[lc].sum+tr[rc].sum)%mod;
}
void cal(node &t,ll mu,ll ad)
{
	t.sum=(t.sum*(mu%mod)%mod+(t.r-t.l+1)*ad%mod)%mod;
	t.add=(t.add*mu%mod+ad%mod)%mod;
	t.mul=t.mul*(mu%mod)%mod;
}
void pushdown(int p)
{
	cal(tr[lc],tr[p].mul,tr[p].add);
	cal(tr[rc],tr[p].mul,tr[p].add);
	tr[p].add=0;
	tr[p].mul=1;
}
void build(int p,int ln,int rn)
{
	tr[p].l=ln,tr[p].r=rn,tr[p].add=0,tr[p].mul=1;
	if(ln==rn)
	{
		tr[p].sum=a[ln];
		return;
	}
	int mid=(ln+rn)>>1;
	build(lc,ln,mid);
	build(rc,mid+1,rn);
	pushup(p);
}
ll query(int p,int ln,int rn)
{
	if(tr[p].l>=ln&&tr[p].r<=rn)return tr[p].sum;
	pushdown(p);
	int mid=(tr[p].l+tr[p].r)>>1;
	ll nsum=0;
	if(ln<=mid)nsum=(nsum+query(lc,ln,rn)%mod)%mod;
	if(rn>=mid+1)nsum=(nsum+query(rc,ln,rn)%mod)%mod;
	return nsum;
}
void update(int p,int ln,int rn,ll mu,ll ad)
{
	if(tr[p].l>=ln&&tr[p].r<=rn)
	{
		cal(tr[p],mu,ad);
		return;
	}
	pushdown(p);
	int mid=(tr[p].l+tr[p].r)>>1;
	if(ln<=mid)update(lc,ln,rn,mu,ad);
	if(rn>=mid+1)update(rc,ln,rn,mu,ad);
	pushup(p);
}
int n,q;
int main()
{
	cin>>n>>q>>mod;
	for(int i=1;i<=n;i++)cin>>a[i];
	build(1,1,n);
	while(q--)
	{
		int op;
		cin>>op;
		if(op==1)
		{
			ll x,y,k;
			cin>>x>>y>>k;
			update(1,x,y,k,0);
		}
		else if(op==2)
		{
			ll x,y,k;
			cin>>x>>y>>k;
			update(1,x,y,1,k);			
		}
		else if(op==3)
		{
			ll x,y;
			cin>>x>>y;
			cout<<query(1,x,y)%mod<<endl;
		}
	}
	return 0;
}

如果不想在 pushup 和 pushdown 里加边界特判的话,就在删掉他们之后把 pushdown 操作放在 query 和 update 里面的 if 判断后面,不然会 RE 。

朴素线段树例题

XOR的艺术\(lazytag\) 变成取反的标志,\(sum\) 更新为 \((r-l+1)-sum\)

忠诚:去掉 update 和 pushdown 以及懒标记,区间答案设为其最小值即可。

posted @ 2024-07-04 00:50  KS_Fszha  阅读(9)  评论(0编辑  收藏  举报