2476. 树套树

题目链接

2476. 树套树

请你写出一种数据结构,来维护一个长度为 \(n\) 的数列,其中需要提供以下操作:

  1. 1 l r x,查询整数 \(x\) 在区间 \([l,r]\) 内的排名。
  2. 2 l r k,查询区间 \([l,r]\) 内排名为 \(k\) 的值。
  3. 3 pos x,将 \(pos\) 位置的数修改为 \(x\)
  4. 4 l r x,查询整数 \(x\) 在区间 \([l,r]\) 内的前驱(前驱定义为小于 \(x\),且最大的数)。
  5. 5 l r x,查询整数 \(x\) 在区间 \([l,r]\) 内的后继(后继定义为大于 \(x\),且最小的数)。

数列中的位置从左到右依次标号为 \(1 \sim n\)

区间 \([l,r]\) 表示从位置 \(l\) 到位置 \(r\) 之间(包括两端点)的所有数字。

区间内排名为 \(k\) 的值指区间内从小到大排在第 \(k\) 位的数值。(位次从 \(1\) 开始)

输入格式

第一行包含两个整数 \(n,m\),表示数列长度以及操作次数。

第二行包含 \(n\) 个整数,表示数列。

接下来 \(m\) 行,每行包含一个操作指令,格式如题目所述。

输出格式

对于所有操作 \(1,2,4,5\),每个操作输出一个查询结果,每个结果占一行。

数据范围

\(1 \le n,m \le 5 \times 10^4\),
\(1 \le l \le r \le n\),
\(1 \le pos \le n\),
\(1 \le k \le r-l+1\),
\(0 \le x \le 10^8\),
有序数列中的数字始终满足在 \([0,10^8]\) 范围内,
数据保证所有操作一定合法,所有查询一定有解。

输入样例:

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

输出样例:

2
4
3
4
9

解题思路

树套树,线段树套平衡树

本题外层线段树,内层平衡树,要求查找排名,则不能用 STL 代替平衡树,线段树可以恰好将一个区间分成 \(O(logn)\) 段小区间,要求找某个数 \(x\) 的排名,即找有多少个数小于 \(x\),在每个小区间中计算再累加即可;还要求查找排名为 \(k\) 的值,直接找不好做,可以通过某个数的排名二分答案进而求解;修改某个数会涉及到线段树中的 \(O(logn)\) 个状态节点,而每个状态节点对应一棵 \(splay\) 树,即转化为在 \(O(logn)\)\(splay\) 中修改某个数,即删除该数和插入修改的数,这里简单说下平衡树中删除操作:找到该数的一个节点,将该节点转到根节点,进而找到其前驱和后继,旋转前驱为根节点,后继为前驱的右子树,此时要删除的数为后继的左子树;查询区间某数 \(x\) 的前驱找对应小区间内小于 \(x\) 的最大值,这些小区间最大值取最大即为答案,查询区间某数 \(x\) 的后继同理

  • 时间复杂度:\((nlog^3n)\)

代码

// Problem: 树套树
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/2478/
// Memory Limit: 128 MB
// Time Limit: 4000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

// %%%Skyqwq
#include <bits/stdc++.h>
 
//#define int long long
#define help {cin.tie(NULL); cout.tie(NULL);}
#define pb push_back
#define fi first
#define se second
#define mkp make_pair
using namespace std;
 
typedef long long LL;
typedef pair<int, int> PII;
typedef pair<LL, LL> PLL;
 
template <typename T> bool chkMax(T &x, T y) { return (y > x) ? x = y, 1 : 0; }
template <typename T> bool chkMin(T &x, T y) { return (y < x) ? x = y, 1 : 0; }
 
template <typename T> void inline read(T &x) {
    int f = 1; x = 0; char s = getchar();
    while (s < '0' || s > '9') { if (s == '-') f = -1; s = getchar(); }
    while (s <= '9' && s >= '0') x = x * 10 + (s ^ 48), s = getchar();
    x *= f;
}

const int N=5e4+5,M=1500000,inf=1e9;
int n,m,w[N],L[N*4],R[N*4],cnt,root[N*4];
struct Tr
{
	int s[2],p,v,sz;
	void init(int _p,int _v)
	{
		p=_p,v=_v;
		sz=1;
	}
}tr[M];
void pushup(int u)
{
	tr[u].sz=tr[tr[u].s[0]].sz+tr[tr[u].s[1]].sz+1;
}
void rotate(int x)
{
	int y=tr[x].p,z=tr[y].p;
	int k=tr[y].s[1]==x;
	tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
	tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
	tr[x].s[k^1]=y,tr[y].p=x;
	pushup(y),pushup(x);
}
void splay(int &root,int x,int k)
{
	while(tr[x].p!=k)
	{
		int y=tr[x].p,z=tr[y].p;
		if(z!=k)
		{
			if((tr[z].s[1]==y)^(tr[y].s[1]==x))rotate(x);
			else
				rotate(y);
		}
		rotate(x);
	}
	if(!k)root=x;
}
void insert(int &root,int v)
{
	int u=root,p=0;
	while(u)p=u,u=tr[u].s[v>tr[u].v];
	u=++cnt;
	if(p)tr[p].s[v>tr[p].v]=u;
	tr[u].init(p,v);
	splay(root,u,0);
}
int kth(int root,int v)
{
	int u=root,res=0;
	while(u)
	{
		if(tr[u].v<v)res+=tr[tr[u].s[0]].sz+1,u=tr[u].s[1];
		else
			u=tr[u].s[0];
	}
	return res;
}
void update(int &root,int x,int y)
{
	int u=root;
	while(u)
	{
		if(tr[u].v==x)break;
		if(tr[u].v<x)u=tr[u].s[1];
		else
			u=tr[u].s[0];
	}
	splay(root,u,0);
	int l=tr[u].s[0],r=tr[u].s[1];
	while(tr[l].s[1])l=tr[l].s[1];
	while(tr[r].s[0])r=tr[r].s[0];
	splay(root,l,0),splay(root,r,l);
	tr[r].s[0]=0;
	pushup(r),pushup(l);
	insert(root,y);
}
int pre(int root,int x)
{
	int u=root,res=-inf;
	while(u)
	{
		if(tr[u].v<x)res=max(res,tr[u].v),u=tr[u].s[1];
		else
			u=tr[u].s[0];
	}
	return res;
}
int suc(int root,int x)
{
	int u=root,res=inf;
	while(u)
	{
		if(tr[u].v>x)res=min(res,tr[u].v),u=tr[u].s[0];
		else
			u=tr[u].s[1];
	}
	return res;
}

void build(int p,int l,int r)
{
	L[p]=l,R[p]=r;
	insert(root[p],-inf),insert(root[p],inf);
	for(int i=l;i<=r;i++)insert(root[p],w[i]);
	if(l==r)return ;
	int mid=l+r>>1;
	build(p<<1,l,mid),build(p<<1|1,mid+1,r);
}
int ask(int p,int l,int r,int x)
{
	if(l<=L[p]&&R[p]<=r)return kth(root[p],x)-1;
	int mid=L[p]+R[p]>>1,res=0;
	if(l<=mid)res+=ask(p<<1,l,r,x);
	if(r>mid)res+=ask(p<<1|1,l,r,x);
	return res;
}
void change(int p,int x,int y)
{
	update(root[p],w[x],y);
	if(L[p]==R[p])return ;
	int mid=L[p]+R[p]>>1;
	if(x<=mid)change(p<<1,x,y);
	else
		change(p<<1|1,x,y);
}
int get_pre(int p,int l,int r,int x)
{
	if(l<=L[p]&&R[p]<=r)return pre(root[p],x);
	int mid=L[p]+R[p]>>1,res=-inf;
	if(l<=mid)res=max(res,get_pre(p<<1,l,r,x));
	if(r>mid)res=max(res,get_pre(p<<1|1,l,r,x));
	return res;
}
int get_suc(int p,int l,int r,int x)
{
	if(l<=L[p]&&R[p]<=r)return suc(root[p],x);	
	int mid=L[p]+R[p]>>1,res=inf;
	if(l<=mid)res=min(res,get_suc(p<<1,l,r,x));
	if(r>mid)res=min(res,get_suc(p<<1|1,l,r,x));
	return res;
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",&w[i]);
	build(1,1,n); 
	while(m--)
	{
		int op,l,r,x,pos,k;
		scanf("%d",&op);
		if(op==1)
		{
			
			scanf("%d%d%d",&l,&r,&x);
			printf("%d\n",ask(1,l,r,x)+1);
		}
		else if(op==2)
		{
			
			scanf("%d%d%d",&l,&r,&k);
			int L=0,R=1e8;
			while(L<R)
			{
				int mid=L+R+1>>1;
				if(ask(1,l,r,mid)+1<=k)L=mid;
				else
					R=mid-1;
			}
			printf("%d\n",L);
		}
		else if(op==3)
		{
			
			scanf("%d%d",&pos,&x);
			change(1,pos,x);
			w[pos]=x;
		}
		else if(op==4)
		{
			
			scanf("%d%d%d",&l,&r,&x);
			printf("%d\n",get_pre(1,l,r,x));
		}
		else
		{
			
			scanf("%d%d%d",&l,&r,&x);
			printf("%d\n",get_suc(1,l,r,x));
		}
	}
    return 0;
}
posted @ 2022-07-10 22:17  zyy2001  阅读(23)  评论(0编辑  收藏  举报