线段树

线段树

一种我琢磨了很长时间才明白的数据结构
核心思想就是把一个序列,分成一个二叉树,叶子节点存的是每个元素,能够快速修改或访问区间中的数值,功能♂强大
线段树主要分为下面几步:

push_up操作

void push_up(int p){
    tree[p]=tree[lc(p)]+tree[rc(p)];
}

这个很简单

build操作

void build(int l,int r,int p){
    if(l==r){
        tree[p]=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(l,mid,lc(p));
    build(mid+1,r,rc(p));
    push_up(p);
}

其实就是一个递归
先建左子树,再建右子树,最后合并起来更新根节点

但是这里我们要注意哈,如果你是像我一样以这种形式

int a[N],n,m,tree[N<<2],tag[N<<2];
int x,y,k;

存树的话,那么build操作就没别的事了
还有的大佬是把每个节点的.l=cnt 这种形式的,那么还得进行其他操作,具体看别的大佬的博客,当然,我的这种方式是空间最小的

下面一步就是核心了(

pushdown操作

我们把一个tag数组,或者叫它lazy数组,来省事,这么想,如果多次修改了同一个点的大小,我们每次子啊那么大的一个树里,改那么多数看起来是不是很亏,那么不妨在查询之前,我们先把这些要加的数存起来,等到一定时刻需要查询时再加

但是这样有个问题,这些数我们存到哪呢?
伟大的OIer先辈们早已给我们整理好了

如果节点p加上了k
那么先让tree[p]+=k;然后让tag[p]也加上k,

等到需要的时候,我们把这个p传下去,然后让下面的节点的tag[lc(p)]+=k,
这是我们看小标题,这个操作叫做pushdown(push:推,down:下面)(推下去)
那么我们就可以考虑,推下去以后 子树 tree会变,tag会变,我们分别改一下,tag刚才其实已经改过了,只需要该tree就行了,tree更好改了,tag[p]存的是欠下面的数,然后子树的tree就要加上tag[p]* (子树的长度)
此时这个根节点已经做到应该做的了,所以tag[p]=0;
最后放上代码

void pushdown(int l,int r,int p){
    int mid=(l+r)>>1;
    if(tag[p]){
        tag[lc(p)]+=tag[p];
        tag[rc(p)]+=tag[p];
        tree[lc(p)]+=tag[p]*(mid-l+1);
        tree[rc(p)]+=tag[p]*(r-mid);
        tag[p]=0;
    }
}

update和query 区间修改和区间查询就直接背就行了

void update(int l,int r,int k,int m,int n,int p){
    //把l到r这段区间全部加k,总区间为m,n
    if(l<=m&&n<=r){
        tree[p]+=(n-m+1)*k;
        tag[p]+=k;
        return ;
    }
    int mid=(m+n)>>1;
    pushdown(m,n,p);
    if(l<=mid) update(l,r,k,m,mid,lc(p));
    if(r>mid) update(l,r,k,mid+1,n,rc(p));
    push_up(p);
}

int query(int l,int r,int m,int n,int p){
    if(l<=m&&n<=r) return tree[p];
    int mid=(m+n)>>1;
    pushdown(m,n,p);
    int sum=0;
    if(l<=mid) sum+=query(l,r,m,mid,lc(p));
    if(r>mid) sum+=query(l,r,mid+1,n,rc(p));
    return sum;
}

update on:2022.3.2
在给二期说线段树的时候我又重新疏离了一下思路

/*
BlackPink is the Revolution
light up the sky
Blackpink in your area
*/
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
#include <bitset>
#include <vector>
#include <cmath>
#include <queue>
#include <ctime>
#include <map>
#include <set>
#define int long long
#define rep(i, a, b) for(int i = (a); (i) <= (b); ++i)
#define per(i, a, b) for(int i = (a); (i) >= (b); --i)
#define whlie while
using namespace std;
const int N = 1e5 + 5;
typedef long long ll;
typedef pair<int,int> P;

namespace scan {
    template <typename T>
    inline void read(T &x) {
        x = 0; char c = getchar(); int f = 0;
        for (; !isdigit(c); c = getchar()) f |= (c == '-');
        for (; isdigit(c); c=getchar()) x = x * 10 + (c ^ 48);
        if (f) x = -x;
    }
    template <typename T, typename ...Args>
    inline void read(T &x, Args &...args) {
        read(x), read(args...);
    }
    template <typename T>
    inline void write(T x,char ch) {
        if (x < 0) putchar('-'), x = -x;
        static short st[30], tp;
        do st[++tp] = x % 10, x /= 10; while(x);
        while(tp) putchar(st[tp--] | 48);
        putchar(ch);
    }
    template <typename T>
    inline void write(T x) {
        if (x < 0) putchar('-'), x = -x;
        static short st[30], tp;
        do st[++tp] = x % 10, x /= 10; while(x);
        while(tp) putchar(st[tp--] | 48);
    }

    inline void write(char ch) {
        putchar(ch);
    }
}
using namespace scan;
int n, m, mod;
int a[N];
#define lc p << 1
#define rc p << 1 | 1//左移以后最后一位必定为0,所以接着或1即为加1

struct node {
    int sum, add, mul = 1; //建议写成结构体形式,这样的话比较美观,而且可以赋初值
} tr[N << 2];
inline void push_up(int p) {tr[p].sum = (tr[lc].sum + tr[rc].sum) % mod;}
inline void build(int p, int l, int r) { //一般函数前建议加上inline
                                         //上面的函数传进去的三个变量表示,对于[l,r]这个区间来说,它在线段树的节点p上
    if (l == r) { //递归终止条件
        tr[p].sum = a[l]; //叶子节点直接赋值
        return ;
    }
    int mid = (l + r) >> 1; //位运算更快一点
    build(lc, l, mid);
    build(rc, mid + 1, r);
    push_up(p); //下面建完树以后把值合并上来
}

inline void pushdown(int p, int l, int r) { //三个传参同理
    if (tr[p].mul == 1 && !tr[p].add) return ; //当mul和add都没有变化时,就不用往下传变量了
    int mid = (l + r) >> 1;
    tr[lc].add = (tr[lc].add * tr[p].mul % mod + tr[p].add) % mod;
    tr[rc].add = (tr[rc].add * tr[p].mul % mod + tr[p].add) % mod;
    tr[lc].mul = tr[lc].mul * tr[p].mul % mod;
    tr[rc].mul = tr[rc].mul * tr[p].mul % mod;
    tr[lc].sum = (tr[lc].sum * tr[p].mul % mod + tr[p].add * (mid - l + 1) % mod) % mod;
    tr[rc].sum = (tr[rc].sum * tr[p].mul % mod + tr[p].add * (r - mid) % mod) % mod;
    tr[p].add = 0, tr[p].mul = 1;
}

inline void update_add(int p, int l, int r, int ql, int qr, int k) { // [l,r]这段区间在线段树上对应节点p,ql和qr是我们要查询的节点
    if (r < ql || qr < l) return ;
    if (ql <= l && r <= qr) {
        tr[p].sum = (tr[p].sum + k * (r - l + 1) % mod) % mod; //val表示当前节点的值,应该是这一段的每个点都加上k,这一段一共 r-l+1 个点
        tr[p].add = (tr[p].add + k) % mod; //我现在欠你们一人k元,我这边记上
        return ;
    }    
    int mid = (l + r) >> 1;
    pushdown(p, l, r);
    //如果你想判断边界的话这么写
    //if(ql <= mid) update_add(lc, l, mid, ql, qr, k); 我需要查的区间还有一部分在左子树上
    //if(qr > mid)  update_add(rc, mid + 1, r, ql, qr, k); 我需要查的区间还有一部分在右子树上
    update_add(lc, l, mid, ql, qr, k);
    update_add(rc, mid + 1, r, ql, qr, k);
    push_up(p); //我更新完子树以后还是要更改一下自己的,上面我们改的只是那个恰好包含 [l,r] 的节点的值,但是它的父亲节点还是没有修改,所以要把信息重新传给父亲节点
}

inline void update_mul(int p, int l, int r, int ql, int qr, int k) { // [l,r]这段区间在线段树上对应节点p,ql和qr是我们要查询的节点
    if (r < ql || qr < l) return ;
    if (ql <= l && r <= qr) {
        tr[p].sum = tr[p].sum * k % mod;
        tr[p].add = tr[p].add * k % mod;
        tr[p].mul = tr[p].mul * k % mod;
        return ;
    }    
    int mid = (l + r) >> 1;
    pushdown(p, l, r);
    //如果你想判断边界的话这么写
    //if(ql <= mid) update_add(lc, l, mid, ql, qr, k); 我需要查的区间还有一部分在左子树上
    //if(qr > mid)  update_add(rc, mid + 1, r, ql, qr, k); 我需要查的区间还有一部分在右子树上
    update_mul(lc, l, mid, ql, qr, k);
    update_mul(rc, mid + 1, r, ql, qr, k);
    push_up(p); //我更新完子树以后还是要更改一下自己的,上面我们改的只是那个恰好包含 [l,r] 的节点的值,但是它的父亲节点还是没有修改,所以要把信息重新传给父亲节点
}

inline int query(int p ,int l, int r, int ql, int qr) { //和上面的同理
    if (r < ql || qr < l) return 0;
    if (ql <= l && r <= qr) return tr[p].sum;
    int mid = (l + r) >> 1;
    pushdown(p, l, r);
    //想判断边界的话:
    // int res = 0;
    //if(ql <= mid) res += query(lc, l, mid, ql, qr); 我需要查的区间还有一部分在左子树上
    //if(qr > mid)  res += query(rc, mid + 1, r, ql, qr); 我需要查的区间还有一部分在右子树上
    //return res;
    return (query(lc, l, mid, ql, qr) + query(rc, mid + 1, r, ql, qr)) % mod;
}

signed main(){
#ifndef ONLINE_JUDGE
    freopen("1.in", "r", stdin);
    freopen("1.out", "w", stdout);
#endif//这段话的意思是本地上是freopen读入文件,但是交到OJ上以后就会自动忽视这一段话
    read(n, m, mod);
    for (int i = 1; i <= n; i++) read(a[i]);
    build (1, 1, n);
    for (int i = 1; i <= m; i++) {
        int opt, l, r, k;
        read(opt);
        if (opt == 1) {
            read(l, r, k);
            update_mul(1, 1, n, l, r, k);
        }
        if (opt == 2) {
            read(l, r, k);
            update_add(1, 1, n, l, r, k);
        }
        if (opt == 3) {
            read(l, r);
            write(query(1, 1, n, l, r), '\n');
        }
    }
    return 0;
}
//write:RevolutionBP

完结撒花

posted @ 2021-09-10 22:20  RevolutionBP  阅读(51)  评论(0编辑  收藏  举报