vector与树状数组实现平衡树
vector操作非常奇妙,不过慎用,易越界
#include<cstdio> #include<algorithm> #include<vector> using namespace std; int n; vector<int>v; int main(){ scanf("%d",&n); for(int i=1;i<=n;i++){ int op=0,x=0; scanf("%d%d",&op,&x); if(op==1) v.insert(upper_bound(v.begin(),v.end(),x),x); else if(op==2) v.erase(lower_bound(v.begin(),v.end(),x)); else if(op==3) printf("%d\n",lower_bound(v.begin(),v.end(),x)-v.begin()+1); else if(op==4) printf("%d\n",v[x-1]); else if(op==5) printf("%d\n",*--lower_bound(v.begin(),v.end(),x)); else if(op==6) printf("%d\n",*upper_bound(v.begin(),v.end(),x)); } return 0; }
当然树状数组有时也能代替平衡树:
添加:
void add(int pos,int x){ while(pos<=maxn) c[pos]+=x,pos+=lowbit(pos); } void build(){ for(int i=1;i<=n;i++){ int t;cin>>t;add(t,1); } }
求排名:
int myrank(int x){//rank为c++11关键字 int res=1;x--;//等于1是因为还要算上自己,x--将自己排除在外,主要是防止有几个和自身相等的数 while(x) res+=c[x],x-=lowbit(x); return res; }
第k小:
int findKth(int k){ int ans=0,cnt=0; for(int i=30;i>=0;i--){//i实际上为log2(maxn),maxn是a数组中的最大值 ans+=(1<<i); if(ans>maxn || cnt+c[ans]>=k)ans-=(1<<i);//>=k是为了防止有重复元素 else cnt+=c[ans]; } return ++ans;//如果找不到则返回n+1 }
前趋后继:
int pre(int x){ return findKth(myrank(x)-1); } int nxt(int x){ return findKth(myrank(x)+1); }
总代码:
#include<bits/stdc++.h> using namespace std; const int maxn=100050; inline int read() { int x=0,t=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')t=-1;ch=getchar();} while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar(); return x*t; } int n,q[maxn],a[maxn],p[maxn],tot=0,c[maxn]; int hash(int x){return lower_bound(q+1,q+1+tot,x)-q;} int lowbit(int x){return x&-x;} void add(int x,int p) { while(x<=tot) { c[x]+=p; x+=lowbit(x); } } int sum(int x) { int res=0; while(x) { res+=c[x]; x-=lowbit(x); } return res; } int query(int x) { int t=0; for(int i=19;i>=0;i--) { t+=1<<i; if(t>tot||c[t]>=x) t-=1<<i; else x-=c[t]; } return q[t+1]; } int main() { n=read(); for(int i=1;i<=n;i++) { p[i]=read(),a[i]=read(); if(p[i]!=4) q[++tot]=a[i]; } sort(q+1,q+1+tot); tot=unique(q+1,q+1+tot)-(1+q); for(int i=1;i<=n;i++) { if(p[i]==1) add(hash(a[i]),1); if(p[i]==2) add(hash(a[i]),-1); if(p[i]==3) printf("%d\n",sum(hash(a[i])-1)+1); if(p[i]==4) printf("%d\n",query(a[i])); if(p[i]==5) printf("%d\n",query(sum(hash(a[i])-1))); if(p[i]==6) printf("%d\n",query(sum(hash(a[i]))+1)); } return 0; }