洛谷题单指南-线段树的进阶用法-P3380 【模板】树套树

原题链接:https://www.luogu.com.cn/problem/P3380

题意解读:对于序列,实现5个操作:1.查询区间[l,r]范围数值k的排名 2.查询区间[l,r]范围第k小值 3.单调修改某一个位置的值 4.查询区间[l,r]范围数值k的前驱 5.查询区间[l,r]范围数值k的后继。

解题思路:

要实现区间第k小的查询,借助于可持久化线段树的思想,首先考虑建立n+1棵权值线段树,树的根节点为root[N],树root[i]代表序列基于a[1]~a[i]建立的权值线段树,接下来看如何利用这n+1棵权值线段树实现5种操作:

1、查k在区间[l,r]的排名

区间[l,r]范围内k的排名 = [l,r]范围1~k-1的个数 + 1

借助于前缀和思想,[l,r]范围1~k-1的个数 = root[r]树中1~k-1的个数 - root[l-1]树中1~k-1的个数

时间复杂度:O(logn)

2、查区间[l,r]第k小值

先看[l,r]范围内所有左子树元素个数leftcnt,leftcnt = root[r]左子树元素个数 - root[l-1]左子树元素个数

如果leftcnt >= k,则继续递归在所有左子树查询,否则递归在右子树查询,直到叶子节点即找到第k小值。

时间复杂度:O(logn)

3、单点更新

要修改a[pos]的值,就要将root[pos]~root[n]的所有权值线段树都更新,将原来的值个数-1,将新值的个数加1。

时间复杂度:O(nlogn),这里显然不符合要求。

4、查k在区间[l,r]的前驱

借助于1/2操作,先查k在[l,r]的排名rank,在查[l,r]范围第rank-1小的值,即为前驱。

5、查k在区间[l,r]的后继

借助于1/2操作,先查k+1在[l,r]的排名rank,再查[l,r]范围低rank小值,即为后继。

注意:要对所有涉及到的元素值进行离散化处理。

基于以上分析,总体时间复杂度在单点更新时不满足要求,基于前缀和的思考,单点更新a[pos]会影响root[pos]~root[n]所有线段树。

如何进行优化呢?要优化前缀和以及单点修改,可以想到树状数组!

我们用树状数组来维护所有权值线段树的根节点root[N],

当修改一个值a[pos]的时候,只需要修改logn个线段树即可:

for(int i = pos; i <= n; i += lowbit(i))
  修改root[i]中的值

当要查询区间[l,r]范围的元素时,

可以先查询[1~r]范围所有的值累加起来:

for(int i = r; i; i -= lowbit(i))
  查询root[i]中的值并累加

再查询[l~l-1]范围所有的值累加起来:

for(int i = l-1; i; i -= lowbit(i))
  查询root[i]中的值并累加

再将[1~r]的值减[1~l-1]的值,即得[l,r]的值

因此,修改和查询的时间复杂度都在O(n*logn*logn)。

到这里,可以揭示一下树状数组存的信息到底是什么:

树状数组维护的是线段树的根节点数组root[N],具体到某一棵权值线段树root[i],存的是[i-lowbit(i)+1, i]区间范围内所有的元素值的个数。

下面介绍5种操作的具体实现:

1、查k在区间[l,r]的排名

先将for(int i = r; i; i -= lowbit(i))的每棵线段树中1~k-1的元素个数查出来累加

再将for(int i = l - 1; i; i -= lowbit(i))的每棵线段树中1~k-1的元素个数查出来累加

以上两者相减,得到[l,r]范围1~k-1的个数,再加1即得k的排名

//在根为u的线段树中查询符合x~y之间元素的个数
int query_cnt(int u, int l, int r, int x, int y)
{
    if(l >= x && r <= y) 
    {
        return tr[u].cnt;
    }
    else if(l > y || r < x) return 0;
    else 
    {
        int mid = l + r >> 1;
        return query_cnt(tr[u].L, l, mid, x, y) + query_cnt(tr[u].R, mid + 1, r, x, y);
    }
}
//查询[l,r]范围内x的排名
//查询x的排名,就是查询root[l]~root[r]的线段树中1~x-1有多少个,然后加1
int find_rank(int l, int r, int x)
{
    int sum = 0;
    //利用数状数组查询root[1]~root[r]的线段树范围内元素1~x-1的数量
    //实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
    for(int i = r; i; i -= lowbit(i)) sum += query_cnt(root[i], 1, b.size(), 1, x - 1);
    //利用数状数组查询root[1]~root[l-1]的线段树范围内元素1~x-1的数量
    //实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
    for(int i = l - 1; i; i -= lowbit(i)) sum -= query_cnt(root[i], 1, b.size(), 1, x - 1);
    return sum + 1;
}

2、查区间[l,r]第k小值

这里的查询和1有所不同,由于需要整体判断所有左子树节点数,因此要先将树状数组操作中所涉及到的线段树根节点缓存下来,然后再批量查询左子树节点数并累加,具体划分两个函数:

//在tempr、templ保存的线段树中查找符合find_kth里指定的范围的第k小
int query_kth(int l, int r, int k)
{
    if(l == r) return l;
    int leftcnt = 0;
    for(int i = 1; i <= cntr; i++) leftcnt += tr[tr[tempr[i]].L].cnt;
    for(int i = 1; i <= cntl; i++) leftcnt -= tr[tr[templ[i]].L].cnt;
    int mid = l + r >> 1;
    if(k <= leftcnt)
    {
        //所有涉及线段树往左子树递归查找第k小,暂存所有左子结点
        for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].L;
        for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].L;
        return query_kth(l, mid, k);
    }
    else
    {
        //所有涉及线段树往右子树递归查找第leftcnt-k小,暂存所有右子结点
        for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].R;
        for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].R;
        return query_kth(mid + 1, r, k - leftcnt);
    }
}

//查询[l,r]范围第k小元素值
int find_kth(int l, int r, int k)
{
    //利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
    //实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
    //先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
    cntr = 0;
    for(int i = r; i; i -= lowbit(i))  tempr[++cntr] = root[i];
    //利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
    //实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
    //先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
    cntl = 0;
    for(int i = l - 1; i; i -= lowbit(i)) templ[++cntl] = root[i];
    return query_kth(1, b.size(), k);
}

3、单点更新

通过树状数组定位到要影响的线段树根节点,然后对线段树进行更新:

//将根节点是pre的权值线段树,通过节点复制,将值为x的节点个数cnt增加v
int update(int pre, int l, int r, int x, int v)
{
    int u = ++idx;
    tr[u].L = tr[pre].L;
    tr[u].R = tr[pre].R;
    tr[u].cnt = tr[pre].cnt + v;
    if(l == r) return u;
    int mid = l + r >> 1;
    if(x <= mid) tr[u].L = update(tr[u].L, l, mid, x, v);
    else tr[u].R = update(tr[u].R, mid + 1, r, x, v);
    return u;
}

//利用树状数组将第x个元素的值v的个数cnt加add,影响到线段树root[x]、root[x+lowbit(x)] ...
void add(int x, int v, int add)
{
    for(int i = x; i <= n; i += lowbit(i))
    {
        root[i] = update(root[i], 1, b.size(), v, add);
        //printf("root[%d]:%d 线段树\n", i, root[i]);
    }
}

4、查k在区间[l,r]的前驱

//查询[l,r]范围内x的前驱
int find_pre(int l, int r, int x)
{
    //先查x的排名
    int rank = find_rank(l, r, x);
    if(rank == 1) return -INF; //没有前驱
    //第rank-1小的值就是x的前驱
    return b[find_kth(l, r, rank - 1) - 1];
}

5、查k在区间[l,r]的后继

//查询[l,r]范围内x的后继
int find_next(int l, int r, int x)
{
    //先查x+1的排名
    int rank = find_rank(l, r, x + 1);
    if(rank > r - l + 1) return INF; //后继不存在
    //第rank小的值就是x的后继
    return b[find_kth(l, r, rank) - 1];
}

分析一下空间复杂度,初始每个元素都会添加到logn棵线段树,每次复制涉及logn个节点,一共n次;更新操作一共可能m次,每次涉及logn棵线段树,每次复制涉及logn个节点;因此总空间再n*logn*logn+m*log*logn = 2n*logn*logn,n最大5000,线段树节点空间可以设为n*600。

100分代码:

#include <bits/stdc++.h>
using namespace std;

const int N = 50005, INF = 2147483647;

struct Node
{
    int L, R; //左、右子节点编号
    int cnt; //节点所表示值域区间[l,r]的元素个数
} tr[N * 600];

struct Op
{
    int opt;
    int l, r, pos, k;
} ops[N]; //所有操作

int root[N], idx; //root:所有根节点,idx:节点编号
int a[N]; //原序列
vector<int> b; //用于离散化
int tempr[N], cntr, templ[N], cntl; 
int n, m;

int lowbit(int x)
{
    return x & -x;
}

//查询x离散化之后的值
int lsh(int x)
{
    return lower_bound(b.begin(), b.end(), x) - b.begin() + 1;
}

//将根节点是pre的权值线段树,通过节点复制,将值为x的节点个数cnt增加v
int update(int pre, int l, int r, int x, int v)
{
    int u = ++idx;
    tr[u].L = tr[pre].L;
    tr[u].R = tr[pre].R;
    tr[u].cnt = tr[pre].cnt + v;
    if(l == r) return u;
    int mid = l + r >> 1;
    if(x <= mid) tr[u].L = update(tr[u].L, l, mid, x, v);
    else tr[u].R = update(tr[u].R, mid + 1, r, x, v);
    return u;
}

//在tempr、templ保存的线段树中查找符合find_kth里指定的范围的第k小
int query_kth(int l, int r, int k)
{
    if(l == r) return l;
    int leftcnt = 0;
    for(int i = 1; i <= cntr; i++) leftcnt += tr[tr[tempr[i]].L].cnt;
    for(int i = 1; i <= cntl; i++) leftcnt -= tr[tr[templ[i]].L].cnt;
    int mid = l + r >> 1;
    if(k <= leftcnt)
    {
        //所有涉及线段树往左子树递归查找第k小,暂存所有左子结点
        for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].L;
        for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].L;
        return query_kth(l, mid, k);
    }
    else
    {
        //所有涉及线段树往右子树递归查找第leftcnt-k小,暂存所有右子结点
        for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].R;
        for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].R;
        return query_kth(mid + 1, r, k - leftcnt);
    }
}

//在根为u的线段树中查询符合x~y之间元素的个数
int query_cnt(int u, int l, int r, int x, int y)
{
    if(l >= x && r <= y) 
    {
        return tr[u].cnt;
    }
    else if(l > y || r < x) return 0;
    else 
    {
        int mid = l + r >> 1;
        return query_cnt(tr[u].L, l, mid, x, y) + query_cnt(tr[u].R, mid + 1, r, x, y);
    }
}

//利用树状数组将第x个元素的值v的个数cnt加add,影响到线段树root[x]、root[x+lowbit(x)] ...
void add(int x, int v, int add)
{
    for(int i = x; i <= n; i += lowbit(i))
    {
        root[i] = update(root[i], 1, b.size(), v, add);
        //printf("root[%d]:%d 线段树\n", i, root[i]);
    }
        
}

//查询[l,r]范围第k小元素值
int find_kth(int l, int r, int k)
{
    //利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
    //实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
    //先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
    cntr = 0;
    for(int i = r; i; i -= lowbit(i))  tempr[++cntr] = root[i];
    //利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
    //实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
    //先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
    cntl = 0;
    for(int i = l - 1; i; i -= lowbit(i)) templ[++cntl] = root[i];
    return query_kth(1, b.size(), k);
}

//查询[l,r]范围内x的排名
//查询x的排名,就是查询root[l]~root[r]的线段树中1~x-1有多少个,然后加1
int find_rank(int l, int r, int x)
{
    int sum = 0;
    //利用数状数组查询root[1]~root[r]的线段树范围内元素1~x-1的数量
    //实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
    for(int i = r; i; i -= lowbit(i)) sum += query_cnt(root[i], 1, b.size(), 1, x - 1);
    //利用数状数组查询root[1]~root[l-1]的线段树范围内元素1~x-1的数量
    //实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
    for(int i = l - 1; i; i -= lowbit(i)) sum -= query_cnt(root[i], 1, b.size(), 1, x - 1);
    return sum + 1;
}

//查询[l,r]范围内x的前驱
int find_pre(int l, int r, int x)
{
    //先查x的排名
    int rank = find_rank(l, r, x);
    if(rank == 1) return -INF; //没有前驱
    //第rank-1小的值就是x的前驱
    return b[find_kth(l, r, rank - 1) - 1];
}

//查询[l,r]范围内x的后继
int find_next(int l, int r, int x)
{
    //先查x+1的排名
    int rank = find_rank(l, r, x + 1);
    if(rank > r - l + 1) return INF; //后继不存在
    //第rank小的值就是x的后继
    return b[find_kth(l, r, rank) - 1];
}

int main()
{
    cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
    cin >> n >> m;
    for(int i = 1; i <= n; i++) 
    {
        cin >> a[i];
        b.push_back(a[i]);
    }
    for(int i = 1; i <= m; i++)
    {
        cin >> ops[i].opt;
        if(ops[i].opt == 3) cin >> ops[i].pos >> ops[i].k;
        else cin >> ops[i].l >> ops[i].r >> ops[i].k;

        if(ops[i].opt != 2) b.push_back(ops[i].k); //将值加入b进行离散化
    }

    //排序去重离散化
    sort(b.begin(), b.end());
    b.erase(unique(b.begin(), b.end()), b.end());
    
    //将序列离散化后的值利用树状数组构建权值线段树
    for(int i = 1; i <= n; i++) add(i, lsh(a[i]), 1);
    
    for(auto o : ops)
    {
        if(o.opt == 1) cout << find_rank(o.l, o.r, lsh(o.k)) << endl;
        else if(o.opt == 2) cout << b[find_kth(o.l, o.r, o.k) - 1] << endl;
        else if(o.opt == 3)
        {
            add(o.pos, lsh(a[o.pos]), -1);
            a[o.pos] = o.k;
            add(o.pos, lsh(a[o.pos]), 1);
        }
        else if(o.opt == 4) cout << find_pre(o.l, o.r, lsh(o.k)) << endl;
        else if(o.opt == 5) cout << find_next(o.l, o.r, lsh(o.k)) << endl;
    }

    return 0;
}

 

posted @ 2025-01-03 16:18  五月江城  阅读(9)  评论(0编辑  收藏  举报