可持久化线段树

少写了一点,可持久化的好处就是可以用较低的代价去得到可以变换版本这一功能。

可持久化线段树(主席树) 带注释的代码

/*
    注意, 可持久化线段树很难支持区间修改, 一般涉及区间修改的时候不用
    单点修改是可以的
    
    一样, 直接选这题不大好, 看下面的通用模版, 具有通用性, 不要被这题禁锢了思想 
    
    思路还算简单
    首先你需要一个线段树的框架, 即root[0], 因为可持久化线段树基本不改变其框架
    只改变其中的信息, 比如最大值最小值, 这里线段树可以存下标范围, 也可以存值域, 只要线段树的框架不变就行了
    但是这时候线段树里面的l, r 就存的是它的左右儿子, 而非, 左右边界, 相当于指针(一般用数组进行模拟), 这个点的左右边界可以通过递归传下来
    
    可持久化的思路, 每次只修改新加入的点(对与上一个版本), 因此, 我们可以先复制一份
    一模一样的, 即tr[q] = tr[p] (p是上一版本, q是这一版本)
    这时候只修改要修改的(比如这个点左儿子的信息, 右儿子不变)就行了, 这样依旧可以搜到之前版本中没修改的点, 也不会更改之前版本
    对于修改的点, 每次新创建一个新点(每个版本最多会创建logn个点, 加上之前的骨架, 一共就是(n * 4 + nlogn)个点)
    
    对于这题
    首先得保证线段树的结构不变, 因为没有修改, 所以结构肯定的不变,
    但是为了好算第k小数, 我们可持久化线段树存的就不是数组的下标范围,
    而是一个值域, 离散化后的值域, 因为点的数量n是固定的, 那么这个值域(离散化后的)就是固定的
    就不会改变这个线段树的框架, 同时保序离散化后, 是从小到大排的, 这就方便我们进行求第k小数
    
    这题线段树里面多存一个cnt, 表示这个区间里面的数量, 因为这里保序离散化了, 
    如果一个点左儿子内的数量大于等于k, 那么总体第k小数, 就在左儿子里面, 否则就在右儿子里面
    
    这两句引用我代码里的注释
    如果左儿子cnts >= k, 那么总体第k小的, 在左儿子里还是第k小的
    如果左儿子cnts不满足k个, 说明第k小肯定不在左儿子里面, 就去右儿子里面找, 并且, 因为左儿子里面有cnts个数, 那么在右儿子里面, 总体第k小数, 应该是第k - cnts小的数
    
    而对于限制[l, r], 对于r可以直接在第r个版本去搜数量
    对于[l, r], 可以使用前缀和的思想, 在第l - 1个版本里面的这个区间的数量cnt1, 第r个版本的数量cnt2, cnt2 - cnt1就是[l, r]版本内这个区间新加的数量
    也就是第[l, r]个数直接的数量, 这样区间限制就解决了(注意可持久化线段树, 总体框架不变, 每个点在不同版本都有对应的)
    
    至此此题结束
    
    具体见代码
*/
#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 100010, M = 10010;

int n, m;
int id[N], w[N], idx, cnt; 
int root[N]; // 每个版本的入口

struct Node // 实际上线段树每个节点存的是一个"值"域
{
    int l, r; // 这里的l和r, 不是左右边界, 是左右儿子的下标(idx)
    int cnt; // 每个值域里面数的数量
}tr[N * 4 + 17 * N]; // logn大约是17,

int find(int x) // 保序离散化
{
    int l = 1, r = cnt;
    while (l < r)
    {
        int mid = l + r >> 1;
        if (x <= id[mid]) r = mid;
        else l = mid + 1;
    }
    
    return l;
}

int build(int l, int r) // 建立基本骨架
{
    int p = ++ idx;
    if (l == r) return p;
    
    int mid = l + r >> 1;
    tr[p].l = build(l, mid);
    tr[p].r = build(mid + 1, r);
    
    return p;
}

int insert(int p, int l, int r, int x)
{
    int q = ++ idx;
    tr[q] = tr[p]; // 复制这一点, 注意这里tr[p]内已经有值了
    if (l == r) // 说明搜到这一点了, 把这一点的数量++
    {
        tr[q].cnt ++ ;
        return q;
    }
    
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x); // 更改左儿子
    else tr[q].r = insert(tr[p].r, mid + 1, r, x); // 更改右儿子
    
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt; // 计算cnt, 这里可以写一个pushup
    
    return q;
}


int query(int p, int q, int l, int r, int k)
{
    int cnts = tr[tr[p].l].cnt - tr[tr[q].l].cnt; // 这里我命名的是cnts, 别写错了
    int mid = l + r >> 1;
    if (l == r) return l; // 说明找到这个点了, 返回的是l, 是第几个点
    
    if (cnts >= k) return query(tr[p].l, tr[q].l, l, mid, k); // 如果cnts >= k, 那么总体第k小的, 在左儿子里还是第k小的
    else query(tr[p].r, tr[q].r, mid + 1, r, k - cnts); // 如果cnts不满足k个, 说明第k小肯定不在左儿子里面, 就去右儿子里面找, 并且, 因为左儿子里面有cnts个数, 那么在右儿子里面, 总体第k小数, 应该是第k - cnts小的数
}

int main()
{
    cin >> n >> m; 
    
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]), id[ ++ cnt] = w[i];
    
    sort(id + 1, id + 1 + cnt);
    cnt = unique(id + 1, id + 1 + cnt) - id - 1; // 判重
    
    root[0] = build(1, cnt); // 初始框架

    for (int i = 1; i <= n; i ++ ) root[i] = insert(root[i - 1], 1, cnt, find(w[i]));  // 每个版本
    
    while (m -- )
    {
        int l, r, k;
        scanf("%d%d%d", &l, &r, &k);
        printf("%d\n", id[query(root[r], root[l - 1], 1, cnt, k)]); // 注意每次输出的是原数而不是离散化之后的
    }
    
    return 0;
}


通用模版

洛谷P3919

/*
    不容易, 算是打出来模版了
    还是模版好理解一些, 上来就第k小数, 其实并不能真正了解这个
    这个模版可以让你更加细致了理解可持久化线段树
    洛谷P3919 【模板】可持久化线段树 1(可持久化数组)
    
    根据题意线段树里面随便存个值就行, 最大值也行, 最小也行, 只存叶节点的也行
    刚开始的时候给我看傻了, 后边一想, 好像存什么都行, 这里存最大值
    
    为什么可持久化线段树里面必须存左右儿子下标?
    可持久化线段树里面的l和r都是代表左右儿子, 为什么, 因为可持久化线段树里面有很多版本, 每个版本有新点, 
    新点会打乱下标顺序, 就不能通过堆的方式来找到左右儿子, 所以要存左右儿子的下标来找到它们, 这是必须存的
    而左右边界, 你可以在结构体另开新的变量存, 也可以通过递归传下去, 这里使用递归传下去, 思想要打开
    
    注意题意, 剩下看模版把, 其实和普通线段树差不太多
*/
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstdio>

using namespace std;

const int N = 1000010; 

int n, m, cnt;
int root[N], w[N], idx;

struct Node
{
    int l, r; // 左右儿子
    int v; // 这个区间的最大数
}tr[N * 4 + N * (int)ceil(log(N) / log(2))];

void pushup(int p)
{
    tr[p].v = max(tr[tr[p].l].v, tr[tr[p].r].v);
}

int build(int l, int r) // 这里题目中给出了初始化版本了, 所以初始版本要更新最大值
{
    int p = ++ idx;
    if (l == r) 
    {
        tr[p].v = w[l]; // 叶节点存当前值
        return p;
    }
    int mid = l + r >> 1;
    tr[p].l = build(l, mid);
    tr[p].r = build(mid + 1, r);
    pushup(p); // 就是不pushup也可以过掉, 注意上面说的, 不pushup还快, 被卡了可以试试a
    return p;
}

int insert(int p, int l, int r, int x, int k)
{
    int q = ++ idx;  // 开新点
    tr[q] = tr[p]; // 复制
    
    if (l == r) 
    {
        tr[q].v = k;
        return q;
    }
    
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x, k);
    else tr[q].r = insert(tr[p].r, mid + 1, r, x, k);
    pushup(q); // 就是不pushup也可以过掉
    return q;
}

int query(int p, int l, int r, int x)
{
    if (l == r) return tr[p].v;
    int mid = l + r >> 1;
    
    if (x <= mid) return query(tr[p].l, l, mid, x);
    else return query(tr[p].r, mid + 1, r, x);
}

int main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    
    root[0] = build(1, n); // 初始版本
    
    while (m -- )
    {
        int v, op, d, x;
        scanf("%d%d", &v, &op);
        if (op == 1) 
        {
            scanf("%d%d", &x, &d);
            root[ ++ cnt] = insert(root[v], 1, n, x, d); // 新开版本
        }
        else
        {
            scanf("%d", &x);
            root[ ++ cnt] = root[v]; // 根据题意新开版本
            printf("%d\n", query(root[v], 1, n, x));
        }
    }
    return 0;
}

截图

屏幕截图 2023-10-30 215337.png
屏幕截图 2023-10-31 075836.png
屏幕截图 2023-10-31 075855.png

对于这题的另一种做法

自己想出来的,感觉要容易想到,时间上要比y的慢一倍。大体思想就是,我们从小到大依次加入一个数,每加入一个就记录一个版本,线段树里记录区间里数的数量,在查询时,只要二分出区间数的数量大于等于k的最小版本即可,这个版本对应插入的点就是要求的第 k 小点,时间复杂度是 \(O(n\log^2n)\) 的和 y 是一个量级的,可能是由于常数问题,所以运行上要慢。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 100010;

int n, m;
int idx, root[N], cnt;
int g[N];

struct node
{
    int v, id;
    bool operator<(const node &W)const
    {
        return v < W.v;
    }
}a[N];

struct Node
{
    int l, r;
    int v, sum = 0;
}tr[N * 4 + N * (int)ceil(log2(N))];

void pushup(int u)
{
    int &l = tr[u].l, &r = tr[u].r;
    tr[u].sum = tr[l].sum + tr[r].sum;
}

int build(int l, int r)
{
    int p = ++ idx;
    if (l == r)
    {
        tr[p].v = -0x3f3f3f3f;
        tr[p].sum = 0;
        return p;
    }
    int mid = l + r >> 1;
    tr[p].l = build(l, mid);
    tr[p].r = build(mid + 1, r);
    pushup(p);
    return p;
}

int insert(int p, int l, int r, int x, int k)
{
    int q = ++ idx;
    tr[q] = tr[p];
    if (l == r)
    {
        tr[q].v = k;
        if (k > -0x3f3f3f3f) tr[q].sum = 1;
        return q;
    }
    int mid = l + r >> 1;
    if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x, k);
    else tr[q].r = insert(tr[p].r, mid + 1, r, x, k);
    pushup(q);
    return q;
}

int query(int p, int l, int r, int x, int y)
{
    if (x <= l && r <= y) return tr[p].sum;
    
    int mid = l + r >> 1;
    int sum = 0;
    if (x <= mid) sum += query(tr[p].l, l, mid, x, y);
    if (y > mid) sum += query(tr[p].r, mid + 1, r, x, y);
    
    return sum;
}

bool check(int x, int l, int r, int k)
{
    return query(root[x], 1, n, l, r) >= k;
}

int main()
{
    cin >> n >> m;
    
    root[0] = build(1, n);
    for (int i = 1; i <= n; i ++ ) 
    {
        int x;
        scanf("%d", &x);
        a[i] = {x, i};
        g[i] = x;
    }
    
    sort(a + 1, a + n + 1);
    
    for (int i = 1; i <= n; i ++ ) 
    {
        root[i] = insert(root[i - 1], 1, n, a[i].id, a[i].v);
        // cout << i << endl;
    }
    
    while (m -- )
    {
        int ls, rs, k;
        scanf("%d%d%d", &ls, &rs, &k);
        
        int l = 0, r = n, mid;
        while (l < r)
        {
            mid = l + r >> 1;
            if (check(mid, ls, rs, k)) r = mid;
            else l = mid + 1;
        }
        
        printf("%d\n", a[l].v);
    }
    
    // cout << query(root[5], 1, n, 2, 5);
    
    
    return 0;
    
}
posted @ 2024-11-09 21:18  blind5883  阅读(2)  评论(0编辑  收藏  举报