洛谷题单指南-线段树的进阶用法-P3380 【模板】树套树
原题链接:https://www.luogu.com.cn/problem/P3380
题意解读:对于序列,实现5个操作:1.查询区间[l,r]范围数值k的排名 2.查询区间[l,r]范围第k小值 3.单调修改某一个位置的值 4.查询区间[l,r]范围数值k的前驱 5.查询区间[l,r]范围数值k的后继。
解题思路:
要实现区间第k小的查询,借助于可持久化线段树的思想,首先考虑建立n+1棵权值线段树,树的根节点为root[N],树root[i]代表序列基于a[1]~a[i]建立的权值线段树,接下来看如何利用这n+1棵权值线段树实现5种操作:
1、查k在区间[l,r]的排名
区间[l,r]范围内k的排名 = [l,r]范围1~k-1的个数 + 1
借助于前缀和思想,[l,r]范围1~k-1的个数 = root[r]树中1~k-1的个数 - root[l-1]树中1~k-1的个数
时间复杂度:O(logn)
2、查区间[l,r]第k小值
先看[l,r]范围内所有左子树元素个数leftcnt,leftcnt = root[r]左子树元素个数 - root[l-1]左子树元素个数
如果leftcnt >= k,则继续递归在所有左子树查询,否则递归在右子树查询,直到叶子节点即找到第k小值。
时间复杂度:O(logn)
3、单点更新
要修改a[pos]的值,就要将root[pos]~root[n]的所有权值线段树都更新,将原来的值个数-1,将新值的个数加1。
时间复杂度:O(nlogn),这里显然不符合要求。
4、查k在区间[l,r]的前驱
借助于1/2操作,先查k在[l,r]的排名rank,在查[l,r]范围第rank-1小的值,即为前驱。
5、查k在区间[l,r]的后继
借助于1/2操作,先查k+1在[l,r]的排名rank,再查[l,r]范围低rank小值,即为后继。
注意:要对所有涉及到的元素值进行离散化处理。
基于以上分析,总体时间复杂度在单点更新时不满足要求,基于前缀和的思考,单点更新a[pos]会影响root[pos]~root[n]所有线段树。
如何进行优化呢?要优化前缀和以及单点修改,可以想到树状数组!
我们用树状数组来维护所有权值线段树的根节点root[N],
当修改一个值a[pos]的时候,只需要修改logn个线段树即可:
for(int i = pos; i <= n; i += lowbit(i))
修改root[i]中的值
当要查询区间[l,r]范围的元素时,
可以先查询[1~r]范围所有的值累加起来:
for(int i = r; i; i -= lowbit(i))
查询root[i]中的值并累加
再查询[l~l-1]范围所有的值累加起来:
for(int i = l-1; i; i -= lowbit(i))
查询root[i]中的值并累加
再将[1~r]的值减[1~l-1]的值,即得[l,r]的值
因此,修改和查询的时间复杂度都在O(n*logn*logn)。
到这里,可以揭示一下树状数组存的信息到底是什么:
树状数组维护的是线段树的根节点数组root[N],具体到某一棵权值线段树root[i],存的是[i-lowbit(i)+1, i]区间范围内所有的元素值的个数。
下面介绍5种操作的具体实现:
1、查k在区间[l,r]的排名
先将for(int i = r; i; i -= lowbit(i))的每棵线段树中1~k-1的元素个数查出来累加
再将for(int i = l - 1; i; i -= lowbit(i))的每棵线段树中1~k-1的元素个数查出来累加
以上两者相减,得到[l,r]范围1~k-1的个数,再加1即得k的排名
//在根为u的线段树中查询符合x~y之间元素的个数
int query_cnt(int u, int l, int r, int x, int y)
{
if(l >= x && r <= y)
{
return tr[u].cnt;
}
else if(l > y || r < x) return 0;
else
{
int mid = l + r >> 1;
return query_cnt(tr[u].L, l, mid, x, y) + query_cnt(tr[u].R, mid + 1, r, x, y);
}
}
//查询[l,r]范围内x的排名
//查询x的排名,就是查询root[l]~root[r]的线段树中1~x-1有多少个,然后加1
int find_rank(int l, int r, int x)
{
int sum = 0;
//利用数状数组查询root[1]~root[r]的线段树范围内元素1~x-1的数量
//实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
for(int i = r; i; i -= lowbit(i)) sum += query_cnt(root[i], 1, b.size(), 1, x - 1);
//利用数状数组查询root[1]~root[l-1]的线段树范围内元素1~x-1的数量
//实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
for(int i = l - 1; i; i -= lowbit(i)) sum -= query_cnt(root[i], 1, b.size(), 1, x - 1);
return sum + 1;
}
2、查区间[l,r]第k小值
这里的查询和1有所不同,由于需要整体判断所有左子树节点数,因此要先将树状数组操作中所涉及到的线段树根节点缓存下来,然后再批量查询左子树节点数并累加,具体划分两个函数:
//在tempr、templ保存的线段树中查找符合find_kth里指定的范围的第k小
int query_kth(int l, int r, int k)
{
if(l == r) return l;
int leftcnt = 0;
for(int i = 1; i <= cntr; i++) leftcnt += tr[tr[tempr[i]].L].cnt;
for(int i = 1; i <= cntl; i++) leftcnt -= tr[tr[templ[i]].L].cnt;
int mid = l + r >> 1;
if(k <= leftcnt)
{
//所有涉及线段树往左子树递归查找第k小,暂存所有左子结点
for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].L;
for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].L;
return query_kth(l, mid, k);
}
else
{
//所有涉及线段树往右子树递归查找第leftcnt-k小,暂存所有右子结点
for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].R;
for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].R;
return query_kth(mid + 1, r, k - leftcnt);
}
}
//查询[l,r]范围第k小元素值
int find_kth(int l, int r, int k)
{
//利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
//实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
//先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
cntr = 0;
for(int i = r; i; i -= lowbit(i)) tempr[++cntr] = root[i];
//利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
//实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
//先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
cntl = 0;
for(int i = l - 1; i; i -= lowbit(i)) templ[++cntl] = root[i];
return query_kth(1, b.size(), k);
}
3、单点更新
通过树状数组定位到要影响的线段树根节点,然后对线段树进行更新:
//将根节点是pre的权值线段树,通过节点复制,将值为x的节点个数cnt增加v
int update(int pre, int l, int r, int x, int v)
{
int u = ++idx;
tr[u].L = tr[pre].L;
tr[u].R = tr[pre].R;
tr[u].cnt = tr[pre].cnt + v;
if(l == r) return u;
int mid = l + r >> 1;
if(x <= mid) tr[u].L = update(tr[u].L, l, mid, x, v);
else tr[u].R = update(tr[u].R, mid + 1, r, x, v);
return u;
}
//利用树状数组将第x个元素的值v的个数cnt加add,影响到线段树root[x]、root[x+lowbit(x)] ...
void add(int x, int v, int add)
{
for(int i = x; i <= n; i += lowbit(i))
{
root[i] = update(root[i], 1, b.size(), v, add);
//printf("root[%d]:%d 线段树\n", i, root[i]);
}
}
4、查k在区间[l,r]的前驱
//查询[l,r]范围内x的前驱
int find_pre(int l, int r, int x)
{
//先查x的排名
int rank = find_rank(l, r, x);
if(rank == 1) return -INF; //没有前驱
//第rank-1小的值就是x的前驱
return b[find_kth(l, r, rank - 1) - 1];
}
5、查k在区间[l,r]的后继
//查询[l,r]范围内x的后继
int find_next(int l, int r, int x)
{
//先查x+1的排名
int rank = find_rank(l, r, x + 1);
if(rank > r - l + 1) return INF; //后继不存在
//第rank小的值就是x的后继
return b[find_kth(l, r, rank) - 1];
}
分析一下空间复杂度,初始每个元素都会添加到logn棵线段树,每次复制涉及logn个节点,一共n次;更新操作一共可能m次,每次涉及logn棵线段树,每次复制涉及logn个节点;因此总空间再n*logn*logn+m*log*logn = 2n*logn*logn,n最大5000,线段树节点空间可以设为n*600。
100分代码:
#include <bits/stdc++.h>
using namespace std;
const int N = 50005, INF = 2147483647;
struct Node
{
int L, R; //左、右子节点编号
int cnt; //节点所表示值域区间[l,r]的元素个数
} tr[N * 600];
struct Op
{
int opt;
int l, r, pos, k;
} ops[N]; //所有操作
int root[N], idx; //root:所有根节点,idx:节点编号
int a[N]; //原序列
vector<int> b; //用于离散化
int tempr[N], cntr, templ[N], cntl;
int n, m;
int lowbit(int x)
{
return x & -x;
}
//查询x离散化之后的值
int lsh(int x)
{
return lower_bound(b.begin(), b.end(), x) - b.begin() + 1;
}
//将根节点是pre的权值线段树,通过节点复制,将值为x的节点个数cnt增加v
int update(int pre, int l, int r, int x, int v)
{
int u = ++idx;
tr[u].L = tr[pre].L;
tr[u].R = tr[pre].R;
tr[u].cnt = tr[pre].cnt + v;
if(l == r) return u;
int mid = l + r >> 1;
if(x <= mid) tr[u].L = update(tr[u].L, l, mid, x, v);
else tr[u].R = update(tr[u].R, mid + 1, r, x, v);
return u;
}
//在tempr、templ保存的线段树中查找符合find_kth里指定的范围的第k小
int query_kth(int l, int r, int k)
{
if(l == r) return l;
int leftcnt = 0;
for(int i = 1; i <= cntr; i++) leftcnt += tr[tr[tempr[i]].L].cnt;
for(int i = 1; i <= cntl; i++) leftcnt -= tr[tr[templ[i]].L].cnt;
int mid = l + r >> 1;
if(k <= leftcnt)
{
//所有涉及线段树往左子树递归查找第k小,暂存所有左子结点
for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].L;
for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].L;
return query_kth(l, mid, k);
}
else
{
//所有涉及线段树往右子树递归查找第leftcnt-k小,暂存所有右子结点
for(int i = 1; i <= cntr; i++) tempr[i] = tr[tempr[i]].R;
for(int i = 1; i <= cntl; i++) templ[i] = tr[templ[i]].R;
return query_kth(mid + 1, r, k - leftcnt);
}
}
//在根为u的线段树中查询符合x~y之间元素的个数
int query_cnt(int u, int l, int r, int x, int y)
{
if(l >= x && r <= y)
{
return tr[u].cnt;
}
else if(l > y || r < x) return 0;
else
{
int mid = l + r >> 1;
return query_cnt(tr[u].L, l, mid, x, y) + query_cnt(tr[u].R, mid + 1, r, x, y);
}
}
//利用树状数组将第x个元素的值v的个数cnt加add,影响到线段树root[x]、root[x+lowbit(x)] ...
void add(int x, int v, int add)
{
for(int i = x; i <= n; i += lowbit(i))
{
root[i] = update(root[i], 1, b.size(), v, add);
//printf("root[%d]:%d 线段树\n", i, root[i]);
}
}
//查询[l,r]范围第k小元素值
int find_kth(int l, int r, int k)
{
//利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
//实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
//先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
cntr = 0;
for(int i = r; i; i -= lowbit(i)) tempr[++cntr] = root[i];
//利用数状数组查询root[1]~root[r]的线段树范围内元素val的数量
//实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
//先不进行真实查询,而是把涉及到的线段树根节点都保存到tempr
cntl = 0;
for(int i = l - 1; i; i -= lowbit(i)) templ[++cntl] = root[i];
return query_kth(1, b.size(), k);
}
//查询[l,r]范围内x的排名
//查询x的排名,就是查询root[l]~root[r]的线段树中1~x-1有多少个,然后加1
int find_rank(int l, int r, int x)
{
int sum = 0;
//利用数状数组查询root[1]~root[r]的线段树范围内元素1~x-1的数量
//实际的数据是在根节点为root[r]、root[r-lowbit(r)]...的线段树中查询
for(int i = r; i; i -= lowbit(i)) sum += query_cnt(root[i], 1, b.size(), 1, x - 1);
//利用数状数组查询root[1]~root[l-1]的线段树范围内元素1~x-1的数量
//实际的数据是在根节点为root[l-1]、root[l-1-lowbit(l-1)]...的线段树中查询
for(int i = l - 1; i; i -= lowbit(i)) sum -= query_cnt(root[i], 1, b.size(), 1, x - 1);
return sum + 1;
}
//查询[l,r]范围内x的前驱
int find_pre(int l, int r, int x)
{
//先查x的排名
int rank = find_rank(l, r, x);
if(rank == 1) return -INF; //没有前驱
//第rank-1小的值就是x的前驱
return b[find_kth(l, r, rank - 1) - 1];
}
//查询[l,r]范围内x的后继
int find_next(int l, int r, int x)
{
//先查x+1的排名
int rank = find_rank(l, r, x + 1);
if(rank > r - l + 1) return INF; //后继不存在
//第rank小的值就是x的后继
return b[find_kth(l, r, rank) - 1];
}
int main()
{
cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
cin >> n >> m;
for(int i = 1; i <= n; i++)
{
cin >> a[i];
b.push_back(a[i]);
}
for(int i = 1; i <= m; i++)
{
cin >> ops[i].opt;
if(ops[i].opt == 3) cin >> ops[i].pos >> ops[i].k;
else cin >> ops[i].l >> ops[i].r >> ops[i].k;
if(ops[i].opt != 2) b.push_back(ops[i].k); //将值加入b进行离散化
}
//排序去重离散化
sort(b.begin(), b.end());
b.erase(unique(b.begin(), b.end()), b.end());
//将序列离散化后的值利用树状数组构建权值线段树
for(int i = 1; i <= n; i++) add(i, lsh(a[i]), 1);
for(auto o : ops)
{
if(o.opt == 1) cout << find_rank(o.l, o.r, lsh(o.k)) << endl;
else if(o.opt == 2) cout << b[find_kth(o.l, o.r, o.k) - 1] << endl;
else if(o.opt == 3)
{
add(o.pos, lsh(a[o.pos]), -1);
a[o.pos] = o.k;
add(o.pos, lsh(a[o.pos]), 1);
}
else if(o.opt == 4) cout << find_pre(o.l, o.r, lsh(o.k)) << endl;
else if(o.opt == 5) cout << find_next(o.l, o.r, lsh(o.k)) << endl;
}
return 0;
}