浅谈线段树

  线段树在信息竞赛中是一个非常重要的数据结构,它可以在大多数题目中使用并且很受OIer的钟爱。

  那么在这里,本蒟蒻就来跟大家浅谈一下线段树。大家可以先看一下线段树的模版题。

  线段树特殊在什么地方呢?就是它每个节点代表的不是单独一个数,而是区间内的一个值,这就便于我们以区间为单位对它进行查询与修改(当然单点也是可以的)

  首先,线段树的原理是什么呢?类似一个满二叉树,它的叶子结点(最下方的点)代表一个数(可理解为[1,1],而它的每一个父亲节点表示其两个子结点的并集(至于是最大值还是和,全看题目怎么说)。操作自然也是一个区间内的操作,对于两个区间来说,其情况无非三种:包含,相交或无交集。对于操作的区间包含了现在这个点所表示的区间的时候,可以直接对这个点进行操作,而相交的区间则可以前往其子结点操作,无交集的话则没有继续访问的必要。 

  线段树的好处有很多,那咱们就线谈谈它的坏处吧————码量大,细节多,很容易错(排除那些把模版烂熟于心的大佬)。而起好处也是显而易见的,快————就是其最大的好处。仔细一想,貌似线段树只是把数组变成了一个树的形式,那么线段数的速度究竟高在哪里呢?它的lazy标记。

  那么lazy标记所表示的意思是什么呢?其蕴涵着两个深意:1.该点的val是正确的。2.该点的之后的点的val是错误的,需要把lazy下放。这就意味着我们每次操作不需要遍布每个结点,等到需要操作其子区间时再将标记下放即可,这就是线段树的效率高的重要原因。

  线段树的打法有两种,一种是使用它满二叉树的性质,根据子结点与父亲节点的位置关系,另一种是直接纪录每个节点的左右儿子。相比较来说,第二种方法是比较节省空间的(曾经被一道题卡过空间)

  而这里也为大家提供了两种方法:

  第一种

#include<cstdio>
using namespace std;
#define kb 100010
#define ll long long

inline ll read()
{
    ll ans=0, w=1;
    char ch=getchar();
    while(ch<'0' || ch>'9')
    {
        if(ch=='-')    w=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9')
    {
        ans=ans*10+ch-'0';
        ch=getchar();
    }
    return ans*w;
}

struct SegNode
{
    ll val, lazy, l, r;
}ST[4*kb+1];

void buildT(int L, int R, int x)
{
    ST[x].l=L, ST[x].r=R;
    if(L==R)
    {
        ST[x].val=read();
        return ;
    }
    int mid=(L+R)>>1;
    buildT(L, mid, x<<1);
    buildT(mid+1, R, (x<<1)+1);
    ST[x].val=ST[x<<1].val+ST[(x<<1)+1].val;
}

void PushDown(int x)
{
    if(!ST[x].lazy)    return ;
    ll k=ST[x].lazy;
    ST[x<<1].lazy+=k;
    ST[x<<1].val+=k*(ST[x<<1].r-ST[x<<1].l+1);
    ST[(x<<1)+1].lazy+=k;
    ST[(x<<1)+1].val+=k*(ST[(x<<1)+1].r-ST[(x<<1)+1].l+1);
    ST[x].lazy=0;
    return ;
}

void update(int L, int R, int k, int x=1)
{
    if(ST[x].l>=L && ST[x].r<=R)
    {
        ST[x].lazy+=k;
        ST[x].val+=k*(ST[x].r-ST[x].l+1);
        return ;
    }
    int mid=(ST[x].l+ST[x].r)>>1;
    PushDown(x);
    if(mid>=L)        update(L, R, k, x<<1);
    if(mid<R)        update(L, R, k, (x<<1)+1);
    ST[x].val=ST[x<<1].val+ST[(x<<1)+1].val; 
}

ll query(int L, int R, int x=1)
{
    if(ST[x].l>=L && ST[x].r<=R)
        return ST[x].val;
    int mid=(ST[x].l+ST[x].r)>>1;ll ans=0;
    PushDown(x);
    if(mid>=L)        ans+=query(L, R, x<<1);
    if(mid<R)        ans+=query(L, R, (x<<1)+1);
    return ans;
}

int main()
{
    int n=read();
    int m=read();
    buildT(1, n, 1);
    for(int i=1; i<=m; i++)
    {
        int tj=read();
        if(tj==1)
        {
            int x=read();
            int y=read();
            int z=read();
            update(x, y, z); 
        }
        if(tj==2)
        {
            int x=read();
            int y=read();
            printf("%lld\n", query(x, y)); 
        }
    }
    return 0;
}

  第二种

#include<cstdio>
using namespace std;

#define kb 100010
#define ll long long

inline ll read()
{
    ll ans=0, w=1;
    char ch=getchar();
    while(ch<'0' || ch>'9')
    {
        if(ch=='-') w=-1;
        ch=getchar();
    }
    while(ch>='0' && ch<='9')
    {
        ans=ans*10+ch-'0';
        ch=getchar();
    }
    return ans*w;
}

struct SegNode
{
    ll lazy, val, l, r, ls, rs;
}ST[kb*2];

int tot=1;
void buildT(int L, int R, int s=1)
{
    ST[s].l=L;ST[s].r=R;
    if(L==R)
    {
        ST[s].val=read();
        return ;
    }
    int mid=(ST[s].l+ST[s].r)>>1;ST[s].ls=++tot;ST[s].rs=++tot;
    buildT(L, mid, ST[s].ls);
    buildT(mid+1, R, ST[s].rs);
    ST[s].val=ST[ST[s].ls].val+ST[ST[s].rs].val;
}

void PushDown(int s)
{
    if(!ST[s].lazy) return ;
    ll k=ST[s].lazy;
    ST[ST[s].ls].lazy+=k;
    ST[ST[s].ls].val+=k*(ST[ST[s].ls].r-ST[ST[s].ls].l+1);
    ST[ST[s].rs].lazy+=k;
    ST[ST[s].rs].val+=k*(ST[ST[s].rs].r-ST[ST[s].rs].l+1);
    ST[s].lazy=0;
    return ;
}

void add(int L, int R, int k, int s=1)
{
    if(L<=ST[s].l  && R>=ST[s].r)
    {
        ST[s].lazy+=k;
        ST[s].val+=k*(ST[s].r-ST[s].l+1);
        return ;
    }
    int mid=(ST[s].l+ST[s].r)>>1;
    PushDown(s);
    if(L<=mid)         add(L, R, k, ST[s].ls);
    if(R>=mid+1)    add(L, R, k, ST[s].rs);
    ST[s].val=ST[ST[s].ls].val+ST[ST[s].rs].val;
}

ll query(int L, int R, int s=1)
{
    if(L<=ST[s].l && R>=ST[s].r)
    {
        return ST[s].val;
    }
    int mid=(ST[s].l+ST[s].r)>>1;ll ans=0;
    PushDown(s);
    if(L<=mid)        ans+=query(L, R, ST[s].ls);
    if(R>=mid+1)    ans+=query(L, R, ST[s].rs);
    return ans;
}

int main()
{
    int n=read();int m=read();
    buildT(1, n);
    for(int i=1; i<=m; i++)
    {
        int a=read();int b=read();int c=read();
        if(a==1)
        {
            int d=read();
            add(b, c, d);
        }
        else
        {
            printf("%lld\n", query(b, c));
        }
    }
    return 0;
}

  希望可以帮助到大家,如果还有些不懂的,可以留言,我会尽量帮大家解决。

posted @ 2019-07-15 20:05  king_kb  阅读(173)  评论(0编辑  收藏  举报