【模板】二逼平衡树(树套树)

Description

维护一个可以支持查询区间某值的排名,查询区间某排名的值,修改某值,查询某值在区间内的前驱和后继的数据结构。

区间长度为 \(n\)\(m\) 次询问。

\(1\leq n\ m\leq5\cdot 10^4\ \ \ \ 0\leq a[i]\leq 10^8\)

Solution

刚学了平衡树,发现硬上平衡树仿佛不行。(全是区间内呀)

正好老师也大概讲了一下树套树是什么玩意。

看到大部分操作其实都都是平衡树的基本操作,只不过都在区间内,然后就很自然地想到了线段树。

于是我们就可以考虑线段树套平衡树(虽说最优的还是树状数组套值域线段树)

直接在线段树每个节点上建一颗平衡树。

那么对于这五个操作,除了第三个,其他的直接在线段树上找到合法区间硬上平衡树就行。

第三个操作的话从根节点开始找,碰上一个节点更新哪里平衡树对应的值(删掉 \(a[pos]\) ,加入 \(k\) ),直到找到最后 \(l=r=pos\) 时,更新后,把 \(a[pos]\) 改成 \(k\) 就行了。

#include<bits/stdc++.h>
#define ls(i) spl[i].ch[0]
#define rs(i) spl[i].ch[1]
#define reg register
using namespace std;
typedef long long ll;
const int N=5e4+10;
const int INF=2147483647;
int n,m,a[N],tot;
struct Splay{int ch[2],fa,cnt,val,siz;}spl[N<<6];
struct Seg_tree{int lt,rt,root;}seg[N<<2];
inline int read(){
	int s=0,w=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
	while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
	return s*w;
}
//splay start
inline void pushup(int now){
	spl[now].siz=spl[now].cnt+spl[ls(now)].siz+spl[rs(now)].siz;
}
inline void rotate(int now){
	int nxt=spl[now].fa;
	int nnt=spl[nxt].fa;
	int k1=rs(nxt)==now;
	int k2=rs(nnt)==nxt;
	int pre=spl[now].ch[k1^1];
	spl[nnt].ch[k2]=now;	spl[now].fa=nnt;
	spl[nxt].ch[k1]=pre;	spl[pre].fa=nxt;
	spl[now].ch[k1^1]=nxt;	spl[nxt].fa=now;
	pushup(nxt);pushup(now);
}//虽然肯定会跑得慢,但真tm好看
inline void splay(int now,int S,int it){
	while(spl[now].fa!=S){
		int nxt=spl[now].fa;
		int nnt=spl[nxt].fa;
		int k1=rs(nxt)==now;
		int k2=rs(nnt)==nxt;
		if(nnt!=S)(k1^k2)?rotate(now):rotate(nxt);
		rotate(now);
	}
	if(!S)seg[it].root=now;
}
inline void insert(int now,int it){
	int u=seg[it].root,fth=0;
	while(u&&spl[u].val!=now){
		fth=u;
		u=spl[u].ch[now>spl[u].val];
	}
	if(u)++spl[u].cnt;
	else {
		u=++tot;ls(u)=rs(u)=0;
		if(fth)spl[fth].ch[now>spl[fth].val]=u;
		spl[u].cnt=spl[u].siz=1;
		spl[u].fa=fth;spl[u].val=now;
	}
	splay(u,0,it);
}
inline void find(int now,int it){
	int u=seg[it].root;
	if(!u)return ;
	while(spl[u].ch[now>spl[u].val]&&now!=spl[u].val){
		u=spl[u].ch[now>spl[u].val];
	}
	splay(u,0,it);
}
inline int pre_nxt(int now,int k,int it){
	find(now,it);
	int u=seg[it].root;
	if(spl[u].val<now&&(!k))return u;
	if(spl[u].val>now&&k)return u;
	u=spl[u].ch[k];
	while(spl[u].ch[k^1])u=spl[u].ch[k^1];
	return u;
}
inline void delet(int now,int it){
	int pre=pre_nxt(now,0,it);
	int nxt=pre_nxt(now,1,it);
	splay(pre,0,it);splay(nxt,pre,it);
	int del=ls(nxt);
	if(spl[del].cnt>1){
		--spl[del].cnt;
		splay(del,0,it);
	}
	else {
		ls(nxt)=0;//del
	}
	pushup(pre);
}
//splay end

/*inline int kth_find(int now,int it){
	int u=seg[it].root;
	if(spl[u].siz<now)return 0;
	while(1){
		if(now>spl[ls(u)].siz+spl[u].cnt){
			now-=spl[ls(u)].siz+spl[u].cnt;
			u=rs(u);
		}
		else if(now<=spl[ls(u)].siz)u=ls(u);
		else return spl[u].val;
	}
}*/

//seg_tree start
inline void sbuild(int lt,int rt,int it){
	insert(-INF,it);insert(INF,it);
	if(lt==rt)return ;
	int mid=(lt+rt)>>1;
	sbuild(lt,mid,it<<1);
	sbuild(mid+1,rt,it<<1|1);
}
inline void sinsert(int lt,int rt,int pos,int val,int it){
	insert(val,it);
	if(lt==rt)return ;
	int mid=(lt+rt)>>1;
	if(mid>=pos)sinsert(lt,mid,pos,val,it<<1);
	else sinsert(mid+1,rt,pos,val,it<<1|1);
}
inline int sfind(int lt,int rt,int LT,int RT,int pos,int it){
	if(RT<lt||rt<LT)return 0;
	if(LT<=lt&&rt<=RT){
		find(pos,it);
		int u=seg[it].root;
		if(spl[u].val>=pos)return spl[ls(u)].siz-1;
		else return spl[ls(u)].siz-1+spl[u].cnt;
	}
	int mid=(lt+rt)>>1;
	int ans1=sfind(lt,mid,LT,RT,pos,it<<1);
	int ans2=sfind(mid+1,rt,LT,RT,pos,it<<1|1);
	return ans1+ans2;
}
inline int skth_find(int lt,int rt,int LT,int RT,int val){
	int mid,now,ans;
	while(lt<=rt){
		mid=(lt+rt)>>1;
		now=sfind(1,n,LT,RT,mid,1)+1;
		if(now>val)rt=mid-1;
		else lt=mid+1,ans=mid;
	}
	return ans;
}
inline void smodify(int lt,int rt,int pos,int val,int it){
	delet(a[pos],it);insert(val,it);
	if(lt==rt&&rt==pos){
		a[pos]=val;
		return ;
	}
	int mid=(lt+rt)>>1;
	if(mid>=pos)smodify(lt,mid,pos,val,it<<1);
	else smodify(mid+1,rt,pos,val,it<<1|1);
}
inline int spre_nxt(int sig,int lt,int rt,int LT,int RT,int pos,int it){
	if(RT<lt||rt<LT){
		if(!sig)return -INF;
		return INF;
	}
	if(LT<=lt&&rt<=RT){
		int u=pre_nxt(pos,sig,it);
		return spl[u].val;
	}
	int mid=(lt+rt)>>1;
	int ans1=spre_nxt(sig,lt,mid,LT,RT,pos,it<<1);
	int ans2=spre_nxt(sig,mid+1,rt,LT,RT,pos,it<<1|1);
	if(!sig)return max(ans1,ans2);
	return min(ans1,ans2);
}
//seg_tree end
int main(){
	n=read();m=read();
	sbuild(1,n,1);
	for(int i=1;i<=n;++i){
		a[i]=read();
		sinsert(1,n,i,a[i],1);
	}
	while(m--){
		int opt=read(),l=read(),r=read(),k;
		if(opt==1){
			k=read();
			printf("%d\n",sfind(1,n,l,r,k,1)+1);
		}
		else if(opt==2){
			k=read();
			printf("%d\n",skth_find(0,1e8,l,r,k));
		}
		else if(opt==3){
			smodify(1,n,l,r,1);
		}
		else if(opt==4){
			k=read();
			printf("%d\n",spre_nxt(0,1,n,l,r,k,1));
		}
		else if(opt==5){
			k=read();
			printf("%d\n",spre_nxt(1,1,n,l,r,k,1));
		}
	}
	return 0;
}

(这玩意考场上打得出来??

posted @ 2021-07-09 12:57  Illusory_dimes  阅读(86)  评论(0编辑  收藏  举报