2476. 树套树
题目链接
2476. 树套树
请你写出一种数据结构,来维护一个长度为 \(n\) 的数列,其中需要提供以下操作:
1 l r x
,查询整数 \(x\) 在区间 \([l,r]\) 内的排名。2 l r k
,查询区间 \([l,r]\) 内排名为 \(k\) 的值。3 pos x
,将 \(pos\) 位置的数修改为 \(x\)。4 l r x
,查询整数 \(x\) 在区间 \([l,r]\) 内的前驱(前驱定义为小于 \(x\),且最大的数)。5 l r x
,查询整数 \(x\) 在区间 \([l,r]\) 内的后继(后继定义为大于 \(x\),且最小的数)。
数列中的位置从左到右依次标号为 \(1 \sim n\)。
区间 \([l,r]\) 表示从位置 \(l\) 到位置 \(r\) 之间(包括两端点)的所有数字。
区间内排名为 \(k\) 的值指区间内从小到大排在第 \(k\) 位的数值。(位次从 \(1\) 开始)
输入格式
第一行包含两个整数 \(n,m\),表示数列长度以及操作次数。
第二行包含 \(n\) 个整数,表示数列。
接下来 \(m\) 行,每行包含一个操作指令,格式如题目所述。
输出格式
对于所有操作 \(1,2,4,5\),每个操作输出一个查询结果,每个结果占一行。
数据范围
\(1 \le n,m \le 5 \times 10^4\),
\(1 \le l \le r \le n\),
\(1 \le pos \le n\),
\(1 \le k \le r-l+1\),
\(0 \le x \le 10^8\),
有序数列中的数字始终满足在 \([0,10^8]\) 范围内,
数据保证所有操作一定合法,所有查询一定有解。
输入样例:
9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5
输出样例:
2
4
3
4
9
解题思路
树套树,线段树套平衡树
本题外层线段树,内层平衡树,要求查找排名,则不能用 STL 代替平衡树,线段树可以恰好将一个区间分成 \(O(logn)\) 段小区间,要求找某个数 \(x\) 的排名,即找有多少个数小于 \(x\),在每个小区间中计算再累加即可;还要求查找排名为 \(k\) 的值,直接找不好做,可以通过某个数的排名二分答案进而求解;修改某个数会涉及到线段树中的 \(O(logn)\) 个状态节点,而每个状态节点对应一棵 \(splay\) 树,即转化为在 \(O(logn)\) 棵 \(splay\) 中修改某个数,即删除该数和插入修改的数,这里简单说下平衡树中删除操作:找到该数的一个节点,将该节点转到根节点,进而找到其前驱和后继,旋转前驱为根节点,后继为前驱的右子树,此时要删除的数为后继的左子树;查询区间某数 \(x\) 的前驱找对应小区间内小于 \(x\) 的最大值,这些小区间最大值取最大即为答案,查询区间某数 \(x\) 的后继同理
- 时间复杂度:\((nlog^3n)\)
代码
// Problem: 树套树
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/2478/
// Memory Limit: 128 MB
// Time Limit: 4000 ms
//
// Powered by CP Editor (https://cpeditor.org)
// %%%Skyqwq
#include <bits/stdc++.h>
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
template <typename T> void inline read(T &x) {
int f = 1; x = 0; char s = getchar();
while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
x *= f;
}
const int N=5e4+5,M=1500000,inf=1e9;
int n,m,w[N],L[N*4],R[N*4],cnt,root[N*4];
struct Tr
{
int s[2],p,v,sz;
void init(int _p,int _v)
{
p=_p,v=_v;
sz=1;
}
}tr[M];
void pushup(int u)
{
tr[u].sz=tr[tr[u].s[0]].sz+tr[tr[u].s[1]].sz+1;
}
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
int k=tr[y].s[1]==x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int &root,int x,int k)
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
{
if((tr[z].s[1]==y)^(tr[y].s[1]==x))rotate(x);
else
rotate(y);
}
rotate(x);
}
if(!k)root=x;
}
void insert(int &root,int v)
{
int u=root,p=0;
while(u)p=u,u=tr[u].s[v>tr[u].v];
u=++cnt;
if(p)tr[p].s[v>tr[p].v]=u;
tr[u].init(p,v);
splay(root,u,0);
}
int kth(int root,int v)
{
int u=root,res=0;
while(u)
{
if(tr[u].v<v)res+=tr[tr[u].s[0]].sz+1,u=tr[u].s[1];
else
u=tr[u].s[0];
}
return res;
}
void update(int &root,int x,int y)
{
int u=root;
while(u)
{
if(tr[u].v==x)break;
if(tr[u].v<x)u=tr[u].s[1];
else
u=tr[u].s[0];
}
splay(root,u,0);
int l=tr[u].s[0],r=tr[u].s[1];
while(tr[l].s[1])l=tr[l].s[1];
while(tr[r].s[0])r=tr[r].s[0];
splay(root,l,0),splay(root,r,l);
tr[r].s[0]=0;
pushup(r),pushup(l);
insert(root,y);
}
int pre(int root,int x)
{
int u=root,res=-inf;
while(u)
{
if(tr[u].v<x)res=max(res,tr[u].v),u=tr[u].s[1];
else
u=tr[u].s[0];
}
return res;
}
int suc(int root,int x)
{
int u=root,res=inf;
while(u)
{
if(tr[u].v>x)res=min(res,tr[u].v),u=tr[u].s[0];
else
u=tr[u].s[1];
}
return res;
}
void build(int p,int l,int r)
{
L[p]=l,R[p]=r;
insert(root[p],-inf),insert(root[p],inf);
for(int i=l;i<=r;i++)insert(root[p],w[i]);
if(l==r)return ;
int mid=l+r>>1;
build(p<<1,l,mid),build(p<<1|1,mid+1,r);
}
int ask(int p,int l,int r,int x)
{
if(l<=L[p]&&R[p]<=r)return kth(root[p],x)-1;
int mid=L[p]+R[p]>>1,res=0;
if(l<=mid)res+=ask(p<<1,l,r,x);
if(r>mid)res+=ask(p<<1|1,l,r,x);
return res;
}
void change(int p,int x,int y)
{
update(root[p],w[x],y);
if(L[p]==R[p])return ;
int mid=L[p]+R[p]>>1;
if(x<=mid)change(p<<1,x,y);
else
change(p<<1|1,x,y);
}
int get_pre(int p,int l,int r,int x)
{
if(l<=L[p]&&R[p]<=r)return pre(root[p],x);
int mid=L[p]+R[p]>>1,res=-inf;
if(l<=mid)res=max(res,get_pre(p<<1,l,r,x));
if(r>mid)res=max(res,get_pre(p<<1|1,l,r,x));
return res;
}
int get_suc(int p,int l,int r,int x)
{
if(l<=L[p]&&R[p]<=r)return suc(root[p],x);
int mid=L[p]+R[p]>>1,res=inf;
if(l<=mid)res=min(res,get_suc(p<<1,l,r,x));
if(r>mid)res=min(res,get_suc(p<<1|1,l,r,x));
return res;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",&w[i]);
build(1,1,n);
while(m--)
{
int op,l,r,x,pos,k;
scanf("%d",&op);
if(op==1)
{
scanf("%d%d%d",&l,&r,&x);
printf("%d\n",ask(1,l,r,x)+1);
}
else if(op==2)
{
scanf("%d%d%d",&l,&r,&k);
int L=0,R=1e8;
while(L<R)
{
int mid=L+R+1>>1;
if(ask(1,l,r,mid)+1<=k)L=mid;
else
R=mid-1;
}
printf("%d\n",L);
}
else if(op==3)
{
scanf("%d%d",&pos,&x);
change(1,pos,x);
w[pos]=x;
}
else if(op==4)
{
scanf("%d%d%d",&l,&r,&x);
printf("%d\n",get_pre(1,l,r,x));
}
else
{
scanf("%d%d%d",&l,&r,&x);
printf("%d\n",get_suc(1,l,r,x));
}
}
return 0;
}