线段树学习笔记

本文整理自《算法竞赛进阶指南》IO.wiki

线段树一种用来维护区间信息的通用数据结构

线段树可以在 O(logn) 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

线段树的建树

通过递归将值划分为大于1的若干个区间,把数列划分为树形结构,通过2合并区间来获得信息,每个非根节点[l,r]内部,都可以划分为
[l,mid],[mid+1,r],mid=l+r2

对于序列a[11,16,15,52,7]

我们有

image

递归建树对于初始节点p代表区间[1,n],如果区间[x,x]范围为1直接a[x]赋值然后返回即可,如果不是则继续划分左右子树
[l,mid],[mid+1,r]从叶子节点向上传递信息。

void build(int p, int l, int r) {
	t[p].l = l, t[p].r = r; // 节点p代表区间[l,r]
	if (l == r) { t[p].dat = a[l]; return; } // 叶节点
	int mid = (l + r) / 2; // 划分左右字数
	build(p*2, l, mid); // 左子节点[l,mid],编号p*2
	build(p*2+1, mid+1, r); // 右子节点[mid+1,r],编号p*2+1
	t[p].dat = max(t[p*2].dat, t[p*2+1].dat); // 从下往上传递信息
}

t l r val
1 1 5 52
2 1 3 16
3 1 2 52
4 1 1 16
5 3 3 15
6 4 5 52
7 5 5 7

查询操作

如果要查询某个值a[x]的话直接访问t[x].dat即可,如果查询区间
我们需要

  1. 查询[l,r]判断单前节点是否完全被[l,r]覆盖
  2. 如果是的话返回该节点值到上一层继续参加运算合并
  3. 如果不是就继续划分重叠的区间的子区间。

如:

查询[2,3]

image

int ask(int p, int l, int r) {
	if (l <= t[p].l && r >= t[p].r) return t[p].dat; // 完全包含,直接返回
	int mid = (t[p].l + t[p].r) / 2;
	int val = 0;
	if (l <= mid) val = max(val, ask(p*2, l, r)); // 左子节点有重叠
	if (r > mid) val = max(val, ask(p*2+1, l, r)); // 右子节点有重叠
	return val;
}

如果要修改区间的话我们要遍历整个[l,r]
我们需要懒惰标记

oiwiki上懒惰标记的定义:

懒惰标记,简单来说,就是通过延迟对节点信息的更改,从而减少可能不必要的操作次数。每次执行修改时,我们通过打标记的方法表明该节点对应的区间在某一次操作中被更改,但不更新该节点的子节点的信息。实质性的修改则在下一次访问带有标记的节点时才进行。

即:查询在修改

宏定义减少代码量,分别为加法,乘法定义(lazytag)add,mul

struct node
{
    int l,r;
    ll sum,mul,add;

    #define l(x) trees[x].l
    #define r(x) trees[x].r
    #define sum(x) trees[x].sum
    #define mul(x) trees[x].mul
    #define add(x) trees[x].add
    
}trees[SIZE * 4];

建树

mul懒惰标记自动为1,左右递归划分子区间。
sum自下向上更新

void build(int p,int l,int r)
{
    mul(p) = 1;
    l(p) = l,r(p) = r;
    if(l == r){sum(p) = a[l]%mod;return;}
    int mid = (l + r)/2;
    build(p * 2, l, mid);
    build(p * 2 + 1 ,mid + 1,r);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
}

维护lazytag

void spread(int p)//维护lazytag
{
    
    
        sum(p * 2) = (add(p) * (r(p * 2) - l(p * 2) + 1) + sum(p * 2) * mul(p))% mod;
        sum(p * 2 + 1) = (add(p) * (r(p * 2 + 1) - l(p * 2 + 1) +  1) + sum(p * 2 + 1) * mul(p)) % mod;

        mul(p * 2) = (mul(p) * mul(p * 2)) % mod;
        mul(p * 2 + 1) = (mul(p) * mul(p * 2 + 1)) % mod;

        add(p * 2) = (add(p) + add(p * 2)*mul(p)) % mod;
        add(p * 2 + 1) = (add(p) + add(p * 2 + 1)*mul(p)) % mod;

        mul(p) = 1;//更新lazytag
        add(p) = 0;

}

加法乘法操作

void change_mul(int p,int l,int r,int d)
{
    if(l <= l(p) && r >= r(p))//判断是否在区间内
    {
        sum(p) = (sum(p) * d) % mod;//更新主节点
        mul(p) = (mul(p) * d) % mod;
        add(p) = (add(p) * d) % mod;
        return ;
    }

    spread(p);

    int mid = (l(p) + r(p))/2;
    if(l <= mid)change_mul(p * 2,l,r,d);
    if(r > mid)change_mul(p * 2 + 1,l,r,d);

    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % mod;

}

void change_add(int p,int l,int r,int d)
{
    if(l <= l(p) && r(p) <= r)
    {
        add(p) =(add(p) + d)%mod;
        sum(p) = (d * (r(p) - l(p) + 1) + sum(p)) % mod;//更新主节点
        return ;
    }

    spread(p);

    int mid = (l(p) + r(p))/2;
    if(l <= mid)change_add(p * 2,l,r,d);
    if(r > mid)change_add(p * 2 + 1,l,r,d);
    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % mod;

}

查询

ll ask(int p,int l,int r)
{
    if( l <= l(p) && r >= r(p)){return sum(p);}//是否在区间内
    spread(p);//更新节点
    int mid = (l(p) + r(p))/2;
    ll val = 0;
    if(l <= mid) val = (val + ask(p * 2 ,l ,r)) % mod;
    if(r > mid) val = (val + ask(p * 2 + 1 ,l ,r)) %mod;
    return val;

}
#include <iostream> 

//17.2
using namespace std;
typedef long long ll ;
const int SIZE = 100010;


ll a[SIZE];
int mod;
int n,m;

struct node
{
    int l,r;
    ll sum,mul,add;

    #define l(x) trees[x].l
    #define r(x) trees[x].r
    #define sum(x) trees[x].sum
    #define mul(x) trees[x].mul
    #define add(x) trees[x].add
    
}trees[SIZE * 4];

void build(int p,int l,int r)
{
    mul(p) = 1;
    l(p) = l,r(p) = r;
    if(l == r){sum(p) = a[l]%mod;return;}
    int mid = (l + r)/2;
    build(p * 2, l, mid);
    build(p * 2 + 1 ,mid + 1,r);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
}


void spread(int p)//维护lazytag
{
    
    
        sum(p * 2) = (add(p) * (r(p * 2) - l(p * 2) + 1) + sum(p * 2) * mul(p))% mod;//先乘再加
        sum(p * 2 + 1) = (add(p) * (r(p * 2 + 1) - l(p * 2 + 1) +  1) + sum(p * 2 + 1) * mul(p)) % mod;

        mul(p * 2) = (mul(p) * mul(p * 2)) % mod;
        mul(p * 2 + 1) = (mul(p) * mul(p * 2 + 1)) % mod;

        add(p * 2) = (add(p) + add(p * 2)*mul(p)) % mod;
        add(p * 2 + 1) = (add(p) + add(p * 2 + 1)*mul(p)) % mod;

        mul(p) = 1;//更新lazytag
        add(p) = 0;

}


void change_add(int p,int l,int r,int d)
{
    if(l <= l(p) && r(p) <= r)
    {
        add(p) =(add(p) + d)%mod;
        sum(p) = (d * (r(p) - l(p) + 1) + sum(p)) % mod;//更新主节点
        return ;
    }

    spread(p);

    int mid = (l(p) + r(p))/2;
    if(l <= mid)change_add(p * 2,l,r,d);
    if(r > mid)change_add(p * 2 + 1,l,r,d);
    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % mod;

}

void change_mul(int p,int l,int r,int d)
{
    if(l <= l(p) && r >= r(p))//判断是否在区间内
    {
        sum(p) = (sum(p) * d) % mod;//更新主节点
        mul(p) = (mul(p) * d) % mod;
        add(p) = (add(p) * d) % mod;
        return ;
    }

    spread(p);

    int mid = (l(p) + r(p))/2;
    if(l <= mid)change_mul(p * 2,l,r,d);
    if(r > mid)change_mul(p * 2 + 1,l,r,d);

    sum(p) = (sum(p * 2) + sum(p * 2 + 1)) % mod;

}

ll ask(int p,int l,int r)
{
    if( l <= l(p) && r >= r(p)){return sum(p);}
    spread(p);
    int mid = (l(p) + r(p))/2;
    ll val = 0;
    if(l <= mid) val = (val + ask(p * 2 ,l ,r)) % mod;
    if(r > mid) val = (val + ask(p * 2 + 1 ,l ,r)) %mod;
    return val;

}


int main()
{

    scanf("%d%d%d",&n,&m,&mod);
    for(int i = 1 ; i <= n ; i ++ )scanf("%lld",&a[i]);
    build(1,1,n);

    for(int i = 1 ;i <= m ; i ++ )
    {
        int op, l,r,d;
        scanf("%d%d%d",&op,&l,&r);
        if(op == 2)
        {

            scanf("%d",&d);
            change_add(1,l,r,d);

        }
        if(op == 1)
        {
           
            scanf("%d",&d);
            change_mul(1,l,r,d);
           
        } 
        if(op == 3)
        {
            printf("%lld\n",ask(1,l,r) % mod);
           
        }
    }

    return 0;
}

新板子

线段树

相关数组tr[x]线段树子节点,mi[x]最小标记,laz[x]加法懒惰标记。
ls:左子树,rs右子树;
lsd 是否在查询区间内
uplazy:若此节点有懒惰标记,懒惰标记下传子节点

#define lscc lsqq,z	//更改左区间
#define rscc rsqq,z	//更该右区间
#define lsqq lson,x,y	//查询左区间
#define rsqq rson,x,y	//查询右区间
#define lson p << 1, l, mid	//递归左区间
#define rson p << 1 | 1, mid + 1, r	//递归右区间
#define rd y <= mid	//判定是否在左区间
#define ld x > mid	//判定是否在右区间
#define ls p << 1	//左树子树位置
#define rs p << 1 | 1	//右子树位置
#define mid ((l + r) >> 1)	//划分左右子树
#define rt return
#define vd void
#define ist l == x &&r == y	//是否在区间内
#define LL long long
#define uplazy  \
    if (laz[p]) \
    down(p, l, r)

建树

左右递归建树,最后更新节点


void build(int p, int l, int r){
    if (l == r)rt tr[p] = mi[p] = read(), vd();
    build(lson),build(rson);
    update(p);
}

更新节点


inline void update(int p){
    tr[p] = tr[ls] + tr[rs];
    mi[p] = min(mi[ls], mi[rs]);
}

更新子节点

下传lazy标记
lazy[son]lazy[son]+lazy[fa]
mi[son]mi[son]+lazy[fa]
tr[son]tr[son]+lazy[p]×

inline void down(int p, int l, int r){
    laz[ls] += laz[p];laz[rs] += laz[p];
    tr[ls] += laz[p] * (mid - l + 1); mi[ls] += laz[p];
    tr[rs] += laz[p] * (r - mid); mi[rs] += laz[p];
    laz[p] = 0;
}

区间修改


void change(int p, int l, int r, int x, int y, int z){
    if (ist){
        tr[p] += (LL)z * (r - l + 1);
        mi[p] += z;
        laz[p] += z;
        rt;
    }
    uplazy;
    if (rd)change(lscc);else
    if (ld)change(rscc);else
    change(lson, x, mid, z), change(rson, mid + 1, y, z);
    update(p);
}

查询

##区间和
LL ask(int p, int l, int r, int x, int y){
    if (ist)rt tr[p];  uplazy;
    if (rd)rt ask(lsqq); else
    if (ld)rt ask(rsqq); else
    rt ask(lson, x, mid) + ask(rson, mid + 1, y);
}
## 区间最值
LL query(int p, int l, int r, int x, int y){
    if (ist)return mi[p]; uplazy;
    if (rd)rt query(lsqq); else 
    if (ld)rt query(rsqq); else
    rt min(query(lson, x, mid), query(rson, mid + 1, y));
}

全部代码如下

#include <iostream>
#include <cstdio>
#define lscc lsqq,z
#define rscc rsqq,z
#define lsqq lson,x,y
#define rsqq rson,x,y
#define lson p << 1, l, mid
#define rson p << 1 | 1, mid + 1, r
#define rd y <= mid
#define ld x > mid
#define ls p << 1
#define rs p << 1 | 1
#define mid ((l + r) >> 1)
#define rt return
#define vd void
#define ist l == x &&r == y
#define LL long long
#define uplazy  \
    if (laz[p]) \
    down(p, l, r)

using namespace std;

const int N = 2e5+5,M=N<<2;
int n, q;
LL tr[M], laz[M], mi[M];

inline void update(int p){
    tr[p] = tr[ls] + tr[rs];
    mi[p] = min(mi[ls], mi[rs]);
}

inline void down(int p, int l, int r){
    laz[ls] += laz[p];laz[rs] += laz[p];
    tr[ls] += laz[p] * (mid - l + 1); mi[ls] += laz[p];
    tr[rs] += laz[p] * (r - mid); mi[rs] += laz[p];
    laz[p] = 0;
}
void build(int p, int l, int r){
    if (l == r)rt tr[p] = mi[p] = read(), vd();
    build(lson),build(rson);
    update(p);
}
void change(int p, int l, int r, int x, int y, int z){
    if (ist){
        tr[p] += (LL)z * (r - l + 1);
        mi[p] += z;
        laz[p] += z;
        rt;
    }
    uplazy;
    if (rd)change(lscc);else
    if (ld)change(rscc);else
    change(lson, x, mid, z), change(rson, mid + 1, y, z);
    update(p);
}
LL ask(int p, int l, int r, int x, int y){
    if (ist)rt tr[p];  uplazy;
    if (rd)rt ask(lsqq); else
    if (ld)rt ask(rsqq); else
    rt ask(lson, x, mid) + ask(rson, mid + 1, y);
}
LL query(int p, int l, int r, int x, int y){
    if (ist)return mi[p]; uplazy;
    if (rd)rt query(lsqq); else 
    if (ld)rt query(rsqq); else
    rt min(query(lson, x, mid), query(rson, mid + 1, y));
}
char ch[5];
int x, y, z;
int main(){
    n = read(),q = read();
    build(1, 1, n);
    while (q--){
        scanf("%s", ch + 1);
        x = read(),y = read();
        if (ch[1] == 'P'){
            z = read();
            change(1, 1, n, x, y, z);
        }
        else if (ch[1] == 'M')
            printf("%lld\n", query(1, 1, n, x, y));
        else if (ch[1] == 'S')
            printf("%lld\n", ask(1, 1, n, x, y));
    }
    rt 0;
}


例题

P3870 [TJOI2009] 开关

#include <iostream>
using namespace std;
#define ls (p*2)
#define rs (p*2+1)
typedef long long ll;
const int N = 1e6+10;
ll sum[N];
int lz[N];
void update(int p)
{
	sum[p] = sum[ls] + sum[rs];
}

void pushdown(int p,int l,int r)
{
	int mid = l + r >> 1;
	if(lz[p] == 1)
	{
		lz[ls]^=1;lz[rs]^=1;
		sum[ls]=mid - l + 1 - sum[ls];
		sum[rs]=r - mid - sum[rs];
		lz[p] = 0;
	}
}

void change(int p,int l,int r,int al,int ar)
{
	if(l >= al && r <= ar){
		lz[p]^=1;
		sum[p]= r - l + 1 - sum[p];
		return ;
	}
	int mid = l + r >> 1;
	pushdown(p,l,r);
	if(al <= mid)change(ls,l,mid,al,ar);
	if(ar >  mid)change(rs,mid+1,r,al,ar);
	update(p);
	
}

ll ask(int p,int l,int r,int al,int ar){
	ll ans = 0;
	if(l >= al && r <= ar)return sum[p];
	int mid = l + r >> 1;
	pushdown(p,l,r);
	if(al <= mid) ans += ask(ls,l,mid,al,ar);
	if(ar >  mid) ans += ask(rs,mid+1,r,al,ar);
	return ans;
	
}


int main()
{
	int n,a,b,c,m;
	cin >> n >> m;
	for(int i = 1 ; i <= m ; i ++ )
	{
		cin >> c >> a >> b;
		if(c==1)printf("%lld\n",ask(1,1,n,a,b));
		else change(1,1,n,a,b);
	}
	return 0;
}

posted @   Erfu  阅读(34)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 在鹅厂做java开发是什么体验
· 百万级群聊的设计实践
· WPF到Web的无缝过渡:英雄联盟客户端的OpenSilver迁移实战
· 永远不要相信用户的输入:从 SQL 注入攻防看输入验证的重要性
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
点击右上角即可分享
微信分享提示