线段树

线段树(SegTree)

引入

现在有一个序列an,对于这个序列有一些操作,分别为:
1.将区间[i,j]内的所有数都加上k
2.询问区间[i,j]内的和
这是一个动态处理的过程,如果强行进行暴力的话,时间复杂度将会非常的高,于是,边考虑一种数据结构专门解决这一类问题,于是,就有了*线段树*

算法思想

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。 对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。 使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。所以,用线段树可以在O(m*logn)的时间内完成这道题目,大概是10^5的数量级,可以承受。 那么线段树到底怎么用呢?
线段树是建立在线段的基础上,每个结点都代表了一条线段[a,b]。长度为1的线段称为元线段。非元线段都有两个子结点,左结点代表的线段为[a,(a + b) / 2],右结点代表的线段为[((a + b) / 2)+1,b]。
在查询时,我们从根节点开始自顶向下找到待查询线段的左边界和右边界,则“夹在中间”的所有叶子节点不重复不遗漏地覆盖了整个待查询线段。(如查询【2,5】)
[1 2 3 4 5 6 7 8]

    /               \

[1 2 3 4] [5 6 7 8]

  /       \          /      \

[1 2] [3 4]   [5 6]  [7 8]

/     \  /   \     /   \   /    \

1    2 3 4     5  6 7      8

从图中不难发现,树的左右各有一条“主线”,虽有分叉,但每层最多只有两个结点继续向下延伸(整棵树的左右子树各一个)。如上图所示[2,5]=[2]+[3,4]+[5]。在后文中,凡是遇到这样的区间分解,就把分解的区间叫做边界区间,因为它们对应与分解过程的递归边界。
如何更新线段树呢?update(x,v)显然需要更新[x]对应的结点,然后还要更新他的所有祖先结点。
直接上代码,注意如果当前区间没有完全包含std区间的话,一定要push down

只支持区间加和的SegTree

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
#define MAXN 1000001

using namespace  std;

ll n,m,a[MAXN],ans[MAXN<<2],tag[MAXN<<2];

inline ll Read()
{
	ll s=0,k=1;
	char c=getchar();
	while(c!='-'&&(c<'0'||c>'9')) c=getchar();
	if(c=='-') {k=-1;c=getchar();}
	while(c>='0'&&c<='9') {s=(s<<3)+(s<<1)+c-'0';c=getchar();}
	return s*k;
}



inline ll ls (ll x) {return x<<1;}

inline ll rs (ll x) {return x<<1|1;}

inline void push_up(ll x)//利用子节点更新父节点信息 
{
	ans[x]=ans[ls(x)]+ans[rs(x)];
}

inline void build (ll x,ll l,ll r)//建树 
{
	tag[x]=0;
	if(l==r){ans[x]=a[l]; return;}
    ll mid= (l+r)>>1 ;
	build(ls(x),l,mid);build(rs(x),mid+1,r);
	push_up(x);	
}

inline void Add (ll x,ll l,ll r,ll k)//将l~r区间整体加上k 
{
	tag[x]+=k;
	ans[x]+=k*(r-l+1);    //直接修改和 
}

inline void push_down (ll x,ll l,ll r)//下传懒标记 ,每次只下传一层 
{
	ll mid=(l+r)>>1;
	Add(ls(x),l,mid,tag[x]);
	Add(rs(x),mid+1,r,tag[x]);
	tag[x]=0;//清除懒标记 
}

inline void update(ll x,ll nl,ll nr,ll l,ll r,ll k)//拆分区间,本质上是分块 nl nr 是需要修改的区间 
{
	if(nl<=l&&r<=nr)
	{ 
	    ans[x]+=k*(r-l+1);
		tag[x]+=k;
		return;
	}
	push_down(x,l,r);
	ll mid=(l+r)>>1;
	if(nl<=mid)update(ls(x),nl,nr,l, mid ,k);
	if(nr >mid)update(rs(x),nl,nr,mid+1,r,k);
	push_up(x);
}

inline ll query(ll x,ll q_x,ll q_y,ll l,ll r)//询问 
{
	ll res=0;
	if(q_x<=l&&r<=q_y) return ans[x];
	ll mid=(l+r)>>1;
	push_down(x,l,r);//没有完全包含的话需要先下传懒标记 
	if(q_x<=mid) res+=query(ls(x),q_x,q_y,l, mid );
	if(q_y >mid) res+=query(rs(x),q_x,q_y,mid+1,r);
	return res;
}

int main ()
{
	ll op,nl,nr,k,qx,qy;
	
	n=Read();m=Read();
	for(ll i=1;i<=n;i++) a[i]=Read();
	
	build(1,1,n);
	
	for(ll i=1;i<=m;i++)
	{
		op=Read();
		switch(op)
		{
			case 1:
			{
				nl=Read();nr=Read();k=Read();
				update(1,nl,nr,1,n,k);
				break;
			}
			case 2:
			{
				qx=Read();qy=Read();
				printf("%lld\n",query(1,qx,qy,1,n));
				break;
			}
			default:break;
		}
	}
	return 0;

  

支持区间乘法

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAXN 100005
#define ll long long

using namespace std;

ll n,m,mod,num[MAXN<<2],tag_p[MAXN<<2],tag_m[MAXN<<2],sum[MAXN<<2];


inline ll Read()
{
	int s=0,k=1;
	char c=getchar();
	while(c!='-'&&(c>'9'||c<'0')) c=getchar();
	if(c=='-'){k=-1;c=getchar();}
	while(c>='0'&&c<='9'){s=(s<<3)+(s<<1)+(c-'0');c=getchar();}
	return s*k;
}

inline ll ls(ll x){return x<<1;} 

inline ll rs(ll x){return x<<1|1;}

inline void push_up(ll x)
{
	sum[x]=(sum[ls(x)]+sum[rs(x)])%mod; 
}

inline void build(ll x,ll l,ll r)
{
	tag_p[x]=0;tag_m[x]=1;sum[x]=0;
	if(l==r){sum[x]=num[l];return ;}
	ll mid= (l+r)>>1 ;
	build(ls(x),l,mid);build(rs(x),mid+1,r);
	push_up(x);
}

inline void modify(ll x,ll l,ll r,ll m_k,ll p_k)
{
	sum[x]=(m_k*sum[x]+(r-l+1)*p_k)%mod;
	tag_p[x]=(tag_p[x]*m_k+p_k)%mod;
	tag_m[x]=(tag_m[x]*m_k)%mod; 
}

inline void push_down(ll x,ll l,ll r)
{
	ll mid =(l+r)>>1;
	modify(ls(x),l,mid,tag_m[x],tag_p[x]);
	modify(rs(x),mid+1,r,tag_m[x],tag_p[x]);
	tag_p[x]=0;tag_m[x]=1; 
}

inline void update_p(ll x,ll l,ll r,ll stdl,ll stdr,ll k)
{
	if(l>stdr || r<stdl) return ;
	
	if(l>=stdl && r<=stdr)
	{
		tag_p[x]=(tag_p[x]+k)%mod;
		sum[x]=(sum[x]+((r-l+1)*k)%mod)%mod;
		return ;
	}
	ll mid =(l+r)>>1;
	push_down(x,l,r);
	update_p(ls(x),l, mid ,stdl,stdr,k);
	update_p(rs(x),mid+1,r,stdl,stdr,k);
	push_up(x);
}

inline void update_m(ll x,ll l,ll r,ll stdl,ll stdr,ll k)
{
	if(l>stdr || r<stdl) return ;
	
	if(l>=stdl && r<=stdr)
	{
		tag_p[x]=(tag_p[x]*k)%mod;tag_m[x]=(tag_m[x]*k)%mod;
		sum[x]=(sum[x]*k)%mod;
		return ;
	}
	ll mid = (l+r)>>1;
	push_down(x,l,r);
	update_m(ls(x),l, mid ,stdl,stdr,k);
	update_m(rs(x),mid+1,r,stdl,stdr,k);
	push_up(x);
}

inline ll query(ll x,ll l,ll r,ll stdl,ll stdr)
{
	if(r<stdl || l>stdr) return 0;
	ll res=0;
	if(l>=stdl && r<=stdr) return sum[x];
	ll mid=(l+r)>>1;
	push_down(x,l,r);
	res+=query(ls(x),l, mid ,stdl,stdr);
	res+=query(rs(x),mid+1,r,stdl,stdr);
	 return res%mod;
 } 

int main ()
{
	n=Read();m=Read();mod=Read();
	for(ll i=1;i<=n;i++) num[i]=Read();
	build(1,1,n);
	
	ll op,x,y,k;
	for(ll i=1;i<=m;i++)
	{
		op=Read();
		switch(op)
		{
			case 1:{
				x=Read();y=Read();k=Read();
				update_m(1,1,n,x,y,k);
				break;
			}
			
			case 2:{
				x=Read();y=Read();k=Read();
				update_p(1,1,n,x,y,k);
				break;
			}
			
			case 3:{
				x=Read();y=Read();
				printf("%lld\n",query(1,1,n,x,y)%mod); 
				break;
			}
			default: break;
		 } 
	 } 
	return 0;
 } 

  

posted @ 2021-05-14 19:19  Roy0_0  阅读(61)  评论(0编辑  收藏  举报
Live2D