可持久化线段树

  1.1可持久化

  顾名思义,数据的可持久化就是不仅能访问该文件的当前版本,也能访问该文件的历史版本,比较常见的应用就是撤销了,而本篇博文要写的就是线段树可持久化的实现。如果没学过线段树的话请先学习线段树以及线段树的动态开点

  1.2实现原理

  考虑最简单的情况:

                                       

  这是一棵非常简单的线段树,设为0号,现在,我们要改变点1的值,(即节点2),生成1号的历史版本,该怎么做呢?

  考虑到改变一个点,只要将该点和它所有的祖宗改一下就可以了,其他的节点不用变,如下所示:

                                    

  该图中,我们新增了5号节点代表点1的值,即虽然2号,5号的区间一样,但值不一样,新增了4号点作为新版本中的根节点,由于4号点的右孩子没改,所以没有新增节点,在新版本中代表区间2的节点依然是3号点。

  读者们也注意到了,图中新增了两个带箭头的数:0、1,它们代表了历史版本所指向的根节点,这样就能锁定该历史版本的整棵树了。

  那么,当树更复杂呢?我们可以考虑一下,将1号,4号节点视为一棵线段树的子树的根,而指向它们的是它们父亲节点,再将2、3节点视为1节点的子树,同样可以求解。

  2、代码实现

  这次总算有【模板】可持久化数组(可持久化线段树/平衡树)了,不用博主自己编题了……

  这可能要动态开点线段树的知识。

  该题中,我们只要实现单点修改,单点查询,我们需要两个数组记录该节点的左右儿子,若要修改的值在左儿子,那么该节点的右儿子就是模式版本的右儿子,再做左儿子就行了。

void jia(long long mo,long long u,long long x,long long k)//mo:模式版本的代表该区间的节点,u:我们要构造的节点,x:区间位置,k:要修改成的值
{
    l[u]=l[mo];
    r[u]=r[mo];
    if (l[u]==r[u])
    {
        z[u]=k;
        return;
    }//这个和build很像
    if (x<=(l[u]+r[u])/2)//x在左儿子
    {
        rr[u]=rr[mo];//右边和模式版本一样
        cnt++;
        ll[u]=cnt;
        jia(ll[mo],ll[u],x,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
    else
    {
        ll[u]=ll[mo];//左边就和模式版本一样
        cnt++;
        rr[u]=cnt;
        jia(rr[mo],rr[u],x,k);
        z[u]=z[ll[u]]+z[rr[u]];//x在右儿子
} }

  单点修改算是可持久化的精髓,以下就是完整代码:

#include<bits/stdc++.h>
using namespace std;
long long l[40000001],r[40000001],z[40000001],ll[40000001],rr[40000001],cnt,n,m,i,loc,v,val,q,a[1000001],he[1000001];//空间可不止4倍,博主比较懒,开了40倍。
void build(long long u,long long l1,long long r1)
{
    l[u]=l1;
    r[u]=r1;
    if (l1==r1)
    {
        z[u]=a[l1];
        return;
    }
    cnt++;
    ll[u]=cnt;//动态开点左区间
    build(ll[u],l1,(l1+r1)/2);
    cnt++;
    rr[u]=cnt;//动态开点右区间
    build(rr[u],(l1+r1)/2+1,r1);
    z[u]=z[ll[u]]+z[rr[u]];//其实这个用不到
}
void jia(long long mo,long long u,long long x,long long k)//单点修改
{
    l[u]=l[mo];
    r[u]=r[mo];
    if (l[u]==r[u])
    {
        z[u]=k;
        return;
    }
    if (x<=(l[u]+r[u])/2)
    {
        rr[u]=rr[mo];
        cnt++;
        ll[u]=cnt;
        jia(ll[mo],ll[u],x,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
    else
    {
        ll[u]=ll[mo];
        cnt++;
        rr[u]=cnt;
        jia(rr[mo],rr[u],x,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
}
long long qui(long long u,long long x)//这和普通线段树的区别就是儿子一个是ll[u],rr[u],一个是u*2,u*2+1。
{
    if (l[u]>x||r[u]<x) return 0;
    if (l[u]==r[u]) return z[u];
    else return qui(ll[u],x)+qui(rr[u],x);
}
int main()
{
    he[0]=1;
    cnt=1;//其实这个也是要不要无所谓……
    scanf("%lld%lld",&n,&m);
    for (i=1;i<=n;i++)
        scanf("%lld",&a[i]);
    build(1,1,n);
    for (i=1;i<=m;i++)
    {
        scanf("%lld%lld",&v,&q);;
        if (q==1)
        {
            cnt++;
            he[i]=cnt;//新增一个版本
            scanf("%lld%lld",&loc,&val);
            jia(he[v],he[i],loc,val);
        }
        else
        {
            scanf("%lld",&loc);
            he[i]=he[v];。。直接就将新版本指向模式版本的根节点
            printf("%lld\n",qui(he[v],loc));
        }
    }
    return 0;
}

  这道题还是相对简单的。接下来难的可就来了。

  3、进阶运用

  博主总算把这道题编好了:【模板】可持久化线段树(标准版),太难编了。

  在这道题中,我们发现从单点修改到了区间修改!我们对区间每一点赋值肯定是不现实的,所以我们就需要一些类似线段树的做法:

 

void xiafang(long long u)//类似线段树的标记下传
{
    cnt++;
    ll[u]=cnt;
    l[cnt]=l[u];
    r[cnt]=(l[u]+r[u])/2;
    z[cnt]=c[u]*(r[cnt]-l[cnt]+1);//实际上是新造了它的左、右儿子
    c[cnt]=c[u];
    cnt++;
    rr[u]=cnt;
    l[cnt]=(l[u]+r[u])/2+1;
    r[cnt]=r[u];
    z[cnt]=c[u]*(r[cnt]-l[cnt]+1);//标记也要下传
    c[cnt]=c[u];
    c[u]=0;
}

 

  然后,就是区间加了:

void jia(long long mo,long long u,long long l1,long long r1,long long k)
{
    l[u]=l[mo];
    r[u]=r[mo];
    if (l[u]>=l1&&r[u]<=r1)//如果完全在赋值区间就不继续了
    {
        z[u]=k*(r[u]-l[u]+1);
        c[u]=k;//打上标记
        return;
    }
    if (c[mo]) xiafang(mo);
    if (r1<=(l[u]+r[u])/2)//在右边
    {
        rr[u]=rr[mo];
        cnt++;
        ll[u]=cnt;
        jia(ll[mo],ll[u],l1,r1,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
    else if (l1>(l[u]+r[u])/2)//在左边
    {
        ll[u]=ll[mo];
        cnt++;
        rr[u]=cnt;
        jia(rr[mo],rr[u],l1,r1,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
    else
    {
        cnt++;
        ll[u]=cnt;
        jia(ll[mo],ll[u],l1,r1,k);//两边都有
        cnt++;
        rr[u]=cnt;
        jia(rr[mo],rr[u],l1,r1,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
}

  因为所有数字小于1e8,所以最大时不会爆long long的。

  以下是完整代码:

#include<bits/stdc++.h>
using namespace std;
long long c[40000001],l[40000001],r[40000001],z[40000001],ll[40000001],rr[40000001],cnt,n,m,i,l1,r1,v,val,q,a[1000001],he[1000001];
void xiafang(long long u)
{
    cnt++;
    ll[u]=cnt;
    l[cnt]=l[u];
    r[cnt]=(l[u]+r[u])/2;
    z[cnt]=c[u]*(r[cnt]-l[cnt]+1);
    c[cnt]=c[u];
    cnt++;
    rr[u]=cnt;
    l[cnt]=(l[u]+r[u])/2+1;
    r[cnt]=r[u];
    z[cnt]=c[u]*(r[cnt]-l[cnt]+1);
    c[cnt]=c[u];
    c[u]=0;
}
void build(long long u,long long l1,long long r1)
{
    l[u]=l1;
    r[u]=r1;
    if (l1==r1)
    {
        z[u]=a[l1];
        return;
    }
    cnt++;
    ll[u]=cnt;
    build(ll[u],l1,(l1+r1)/2);
    cnt++;
    rr[u]=cnt;
    build(rr[u],(l1+r1)/2+1,r1);
    z[u]=z[ll[u]]+z[rr[u]];
}
void jia(long long mo,long long u,long long l1,long long r1,long long k)
{
    l[u]=l[mo];
    r[u]=r[mo];
    if (l[u]>=l1&&r[u]<=r1)
    {
        z[u]=k*(r[u]-l[u]+1);
        c[u]=k;
        return;
    }
    if (c[mo]) xiafang(mo);
    if (r1<=(l[u]+r[u])/2)
    {
        rr[u]=rr[mo];
        cnt++;
        ll[u]=cnt;
        jia(ll[mo],ll[u],l1,r1,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
    else if (l1>(l[u]+r[u])/2)
    {
        ll[u]=ll[mo];
        cnt++;
        rr[u]=cnt;
        jia(rr[mo],rr[u],l1,r1,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
    else
    {
        cnt++;
        ll[u]=cnt;
        jia(ll[mo],ll[u],l1,r1,k);
        cnt++;
        rr[u]=cnt;
        jia(rr[mo],rr[u],l1,r1,k);
        z[u]=z[ll[u]]+z[rr[u]];
    }
}
long long qui(long long u,long long l1,long long r1)
{
    if (l[u]>r1||r[u]<l1) return 0;
    if (l[u]>=l1&&r[u]<=r1) return z[u];
    if (c[u]) xiafang(u);
    return (qui(ll[u],l1,r1)+qui(rr[u],l1,r1))%998244353;
}
int main()
{
    he[0]=1;
    cnt=1;
    scanf("%lld%lld",&n,&m);
    for (i=1;i<=n;i++)
        scanf("%lld",&a[i]);
    build(1,1,n);
    for (i=1;i<=m;i++)
    {
        scanf("%lld%lld",&v,&q);;
        if (q==1)
        {
            cnt++;
            he[i]=cnt;
            scanf("%lld%lld%lld",&l1,&r1,&val);
            jia(he[v],he[i],l1,r1,val);
        }
        else
        {
            scanf("%lld%lld",&l1,&r1);
            he[i]=he[v];
            printf("%lld\n",qui(he[v],l1,r1)%998244353);
        }
    }
    return 0;
}

 

posted @ 2019-08-05 10:25  冰逝  阅读(634)  评论(0编辑  收藏  举报