[HNOI2017]单旋
Description:
用一种数据结构,模拟单旋splay的以下操作:
1.插入一个数
2.查询最小值,单旋到根
3.查询最大值,单旋到根
4.查询最小值,单旋到根删除
5.查询最大值,单旋到根删除
Hint:
\(n,m<=10^5\)
Solution:
考虑我们需要维护什么东西,不需要维护什么东西
维护树的最大最小值,这个比较容易,插入也是,直接set搞就行了
现在的问题是如何维护树的形态和深度
手模会发现,单旋之后的splay节点深度会区间整体+1或-1
考虑用线段树维护,分类讨论一下
然后再次手模,每次splay只有当前节点的父子关系会发生改变,其他节点不变
于是我们需要维护每个节点的父子关系
考虑操作的实现,
最大值最小值可以直接找
每次单旋到根相当于直接更新当前节点的父子信息
以及其他节点的深度信息
删除操作,由于根一定只有一棵子树,直接断边就行了
题目不难,主要是要敢于找规律,勤手模,做法还是很好想的
代码细节稍多,但只要思路清晰,还是很好码的
#include<bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
using namespace std;
typedef set<int >::iterator sit;
const int mxn=1e6+5;
struct Q {
int opt,val;
}q[mxn];
int n,m,rt,a[mxn],b[mxn],t[mxn],fa[mxn],tag[mxn],ch[mxn][2];
set<int > s;
void push_down(int p) {
if(tag[p]) {
t[ls]+=tag[p];
t[rs]+=tag[p];
tag[ls]+=tag[p];
tag[rs]+=tag[p];
tag[p]=0;
}
}
void add(int l,int r,int ql,int qr,int val,int p)
{
if(ql<=l&&r<=qr) {
tag[p]+=val;
t[p]+=val;
return ;
}
int mid=(l+r)>>1; push_down(p);
if(ql<=mid) add(l,mid,ql,qr,val,ls);
if(qr>mid) add(mid+1,r,ql,qr,val,rs);
}
void update(int l,int r,int pos,int val,int p)
{
if(l==r) {
t[p]=val; return ;
}
int mid=(l+r)>>1; push_down(p);
if(pos<=mid) update(l,mid,pos,val,ls);
else update(mid+1,r,pos,val,rs);
}
int query(int l,int r,int pos,int p)
{
if(l==r) return t[p];
int mid=(l+r)>>1; push_down(p);
if(pos<=mid) return query(l,mid,pos,ls);
else return query(mid+1,r,pos,rs);
}
int ins(int x)
{
sit it=s.insert(x).first; //返回对应插入后该节点的迭代器
if(!rt) {
rt=x; update(1,m,x,1,1);
fa[rt]=0; ch[rt][1]=ch[rt][0]=0; //将根的深度赋为1
return 1;
}
if(it!=s.begin()) {
--it;
if(!ch[*it][1]) ch[fa[x]=*it][1]=x;
++it;
} //找父亲
if(!fa[x]) ++it,ch[fa[x]=*it][0]=x; //找父亲
int dep=query(1,m,fa[x],1)+1;
update(1,m,x,dep,1); //更新深度
return dep;
}
int findmin()
{
int x=*s.begin(),dep=query(1,m,x,1);
if(x==rt) return 1;
if(x+1<=fa[x]-1) add(1,m,x+1,fa[x]-1,-1,1);
add(1,m,1,m,1,1); //分类讨论树中节点的深度变化
ch[fa[x]][0]=ch[x][1]; fa[ch[x][1]]=fa[x];
fa[rt]=x; ch[x][1]=rt; rt=x; //splay的基本操作
update(1,m,x,1,1); return dep;
}
int findmax()
{
int x=*s.rbegin(),dep=query(1,m,x,1);
if(x==rt) return 1;
if(fa[x]+1<=x-1) add(1,m,fa[x]+1,x-1,-1,1);
add(1,m,1,m,1,1);
ch[fa[x]][1]=ch[x][0]; fa[ch[x][0]]=fa[x];
fa[rt]=x; ch[x][0]=rt; rt=x;
update(1,m,x,1,1); return dep;
}
void delmin()
{
printf("%d\n",findmin());
add(1,m,1,m,-1,1); s.erase(rt);
rt=ch[rt][1]; fa[rt]=0; //删除
}
void delmax()
{
printf("%d\n",findmax());
add(1,m,1,m,-1,1); s.erase(rt);
rt=ch[rt][0]; fa[rt]=0;
}
int main()
{
scanf("%d",&n); int opt,x;
for(int i=1;i<=n;++i) {
scanf("%d",&q[i].opt);
if(q[i].opt==1)
scanf("%d",&x),a[i]=b[++m]=x;
}
sort(b+1,b+m+1);
for(int i=1;i<=n;++i)
a[i]=lower_bound(b+1,b+m+1,a[i])-b; //离散化方便处理
for(int i=1;i<=n;++i) {
if(q[i].opt==1) printf("%d\n",ins(a[i]));
else if(q[i].opt==2) printf("%d\n",findmin());
else if(q[i].opt==3) printf("%d\n",findmax());
else if(q[i].opt==4) delmin();
else delmax();
}
return 0;
}