洛谷 P3380 【模板】二逼平衡树(树套树)
线段树套treap:
就是线段树每个节点放一个treap。建树复杂度应该是$n log n$,操作1,3,4,5的复杂度是$(log n)^2$,操作2的复杂度是$(log n)^3$。
操作3:找到线段树的对应叶子节点后找到要删除的值,在回溯的时候更新线段树相关的每一个节点(在treap中去掉要删除的值,再加入要加入的值)
操作1:将操作转化为统计(这个区间[l,r]内小于x的数的个数)+1。那么通过线段树将区间分解,然后对分解出的每一个区间对应的treap求小于x的数的个数,最后将这些答案加起来再加一得到最终答案。
操作4:通过线段树将区间分解,然后对分解出的每一个区间对应的treap求x的前驱,取这些前驱的最大值。
操作5:通过线段树将区间分解,然后对分解出的每一个区间对应的treap求x的后继,取这些后继的最小值。
操作2:咋一看似乎没什么好方法。看题解,可以将操作转化为找出这个区间[l,r]内最小的x,使得区间内小于x的数不少于k-1个(即x的排名不低于k)。二分答案即可。
1 #include<cstdio> 2 #include<algorithm> 3 using namespace std; 4 #define MAXI 2147483647 5 #define lc (num<<1) 6 #define rc (num<<1|1) 7 #define mid ((l+r)>>1) 8 int rand1() 9 { 10 static int x=471; 11 return x=(48271LL*x+1)%2147483647; 12 } 13 struct Node 14 { 15 Node* ch[2]; 16 int r;//优先级 17 int v;//value 18 int size;//维护子树的节点个数 19 int num;//当前数字出现次数 20 int cmp(int x) const//要在当前节点的哪个子树去查找,0左1右 21 { 22 if(x==v) return -1; 23 return v<x; 24 } 25 void upd() 26 { 27 size=num; 28 if(ch[0]!=NULL) size+=ch[0]->size; 29 if(ch[1]!=NULL) size+=ch[1]->size; 30 } 31 }nodes[3001000]; 32 int mem; 33 void rotate(Node* &o,int d) 34 { 35 Node* t=o->ch[d^1];o->ch[d^1]=t->ch[d];t->ch[d]=o; 36 o->upd();t->upd(); 37 o=t; 38 } 39 Node* getnode(){ return &nodes[mem++];} 40 void insert(Node* &o,int x) 41 { 42 if(o==NULL) 43 { 44 o=getnode();o->ch[0]=o->ch[1]=NULL; 45 o->v=x;o->r=rand1();o->num=1; 46 } 47 else 48 { 49 if(o->v==x) ++(o->num); 50 else 51 { 52 int d=o->v < x; 53 insert(o->ch[d],x); 54 if(o->r < o->ch[d]->r) rotate(o,d^1); 55 } 56 } 57 o->upd(); 58 } 59 void remove(Node* &o,int x) 60 { 61 int d=o->cmp(x); 62 if(d==-1) 63 { 64 if(o->num > 0) 65 { 66 --(o->num); 67 } 68 if(o->num == 0) 69 { 70 if(o->ch[0]==NULL) o=o->ch[1]; 71 else if(o->ch[1]==NULL) o=o->ch[0]; 72 else 73 { 74 int d2=o->ch[1]->r < o->ch[0]->r; 75 rotate(o,d2); 76 remove(o->ch[d2],x); 77 } 78 } 79 } 80 else remove(o->ch[d],x); 81 if(o!=NULL) o->upd(); 82 } 83 bool find(Node* o,int x) 84 { 85 int d; 86 while(o!=NULL) 87 { 88 d=o->cmp(x); 89 if(d==-1) return 1; 90 else o=o->ch[d]; 91 } 92 return 0; 93 } 94 int kth(Node* o,int k) 95 { 96 if(o==NULL||k<=0||k > o->size) return 0; 97 int s= o->ch[0]==NULL ? 0 : o->ch[0]->size; 98 if(k>s&&k<=s+ o->num) return o->v; 99 else if(k<=s) return kth(o->ch[0],k); 100 else return kth(o->ch[1],k-s- o->num); 101 } 102 int rk(Node* o,int x) 103 { 104 if(o==NULL) return 0; 105 int r=o->ch[0]==NULL ? 0 : o->ch[0]->size; 106 if(x==o->v) return r; 107 else if(x<o->v) return rk(o->ch[0],x); 108 else return r+ o->num +rk(o->ch[1],x); 109 } 110 int pre(Node* o,int x) 111 { 112 if(o==NULL) return -MAXI; 113 int d=o->cmp(x); 114 if(d<=0) return pre(o->ch[0],x); 115 else return max(o->v,pre(o->ch[1],x)); 116 } 117 int nxt(Node* o,int x) 118 { 119 if(o==NULL) return MAXI; 120 int d=o->cmp(x); 121 if(d!=0) return nxt(o->ch[1],x); 122 else return min(o->v,nxt(o->ch[0],x)); 123 } 124 Node* root[200100]; 125 int x,L,R,k,d,n,m;//所有操作的操作区间用[L,R]表示而不是[l,r] 126 int a[50010]; 127 128 129 int rk1(int l,int r,int num)//返回区间内小于x的数的个数 130 { 131 if(L<=l&&r<=R) 132 { 133 return rk(root[num],x); 134 } 135 int ans=0; 136 if(L<=mid) ans+=rk1(l,mid,lc); 137 if(mid<R) ans+=rk1(mid+1,r,rc); 138 return ans; 139 } 140 int pre1(int l,int r,int num) 141 { 142 if(L<=l&&r<=R) 143 { 144 return pre(root[num],x); 145 } 146 int ans=-2147483647; 147 if(L<=mid) ans=max(ans,pre1(l,mid,lc)); 148 if(mid<R) ans=max(ans,pre1(mid+1,r,rc)); 149 return ans; 150 } 151 int nxt1(int l,int r,int num) 152 { 153 if(L<=l&&r<=R) 154 { 155 return nxt(root[num],x); 156 } 157 int ans=2147483647; 158 if(L<=mid) ans=min(ans,nxt1(l,mid,lc)); 159 if(mid<R) ans=min(ans,nxt1(mid+1,r,rc)); 160 return ans; 161 } 162 void change(int l,int r,int num) 163 { 164 if(l<r) 165 { 166 if(k<=mid) change(l,mid,lc); 167 else change(mid+1,r,rc); 168 } 169 else 170 { 171 d=kth(root[num],1); 172 } 173 remove(root[num],d); 174 insert(root[num],x); 175 } 176 void build(int l,int r,int num) 177 { 178 for(int i=l;i<=r;i++) insert(root[num],a[i]); 179 if(l<r) 180 { 181 build(l,mid,lc); 182 build(mid+1,r,rc); 183 } 184 } 185 int main() 186 { 187 int i,idx,l,r; 188 scanf("%d%d",&n,&m); 189 for(i=1;i<=n;i++) scanf("%d",&a[i]); 190 build(1,n,1); 191 for(i=1;i<=m;i++) 192 { 193 scanf("%d",&idx); 194 if(idx==1) 195 { 196 scanf("%d%d%d",&L,&R,&x); 197 printf("%d\n",rk1(1,n,1)+1); 198 } 199 else if(idx==2) 200 { 201 scanf("%d%d%d",&L,&R,&k); 202 l=-1;r=100000000; 203 while(l<r-1) 204 { 205 x=(l+r)/2; 206 if(rk1(1,n,1)+1<=k) l=x; 207 else r=x; 208 } 209 x=r;printf("%d\n",pre1(1,n,1));; 210 } 211 else if(idx==3) 212 { 213 scanf("%d%d",&k,&x); 214 change(1,n,1); 215 } 216 else if(idx==4) 217 { 218 scanf("%d%d%d",&L,&R,&x); 219 printf("%d\n",pre1(1,n,1)); 220 } 221 else if(idx==5) 222 { 223 scanf("%d%d%d",&L,&R,&x); 224 printf("%d\n",nxt1(1,n,1)); 225 } 226 } 227 return 0; 228 }