[TJOI2010] 中位数
给一个序列,要求在支持插入元素的情况下,维护序列中位数。
这道题解法就比较多了,显然可以用各种平衡树搞。
这里给出用堆维护的算法。
先维护一个大根堆和一个小根堆。
可以脑补一下,如果这个序列从左到右是单调不减的,小根堆就代表序列左边的一半,大根堆则代表右边的一半。
两个堆的堆顶则代表这两半的边界元素。
这里的“一半”不是严格的一半,而可能是一边是一小半,另一边是一大半。
比如1 2 3||4 5 6。
一开始先把所有元素全部加到一边。
插入元素的时候,根据新元素与两个堆顶的大小关系决定该加到哪个堆里。
如果向1 2 3||4 5 6中插入5,显然应该插到右边。
也就是要使小根堆的堆顶永远不小于大根堆的堆顶。
即:右边的最左元素不小于左边的最右元素。
这样两边的分界处才是一个正确的分界处。
询问的时候通过插入和删除操作,平衡两个堆的大小(两堆的大小之差不大于1)。
这个操作对应到序列上就是将分界处移到序列中间。
即:1||2 3 4 5 6 --> 1 2 3||4 5 6
此时答案一定是两个堆的堆顶之一。
然后分类讨论一下该输出哪个堆的堆顶。
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 using std::swap; 5 using std::min; 6 7 int n,m; 8 9 struct min_heap 10 { 11 int sz; 12 int v[110005]; 13 void insert(int val) 14 { 15 int p=++sz; 16 v[p]=val; 17 while(p!=1&&v[p]<v[p>>1]) 18 swap(v[p],v[p>>1]),p>>=1; 19 } 20 int top(){return v[1];} 21 int check(int p) 22 { 23 if((p<<1)>sz)return -1; 24 if((p<<1|1)>sz) 25 { 26 if(v[p<<1]<v[p])return 0; 27 else return -1; 28 } 29 if(v[p]<v[p<<1]&&v[p]<v[p<<1|1])return -1; 30 if(v[p<<1]<v[p<<1|1])return 0; 31 return 1; 32 } 33 void pop() 34 { 35 swap(v[1],v[sz]); 36 v[sz--]=0; 37 int p=1,k=check(p); 38 while(k!=-1) 39 { 40 swap(v[p],v[(p<<1)+k]); 41 p=(p<<1)+k; 42 k=check(p); 43 } 44 } 45 }h2; 46 47 struct max_heap 48 { 49 int sz; 50 int v[110005]; 51 void insert(int val) 52 { 53 int p=++sz; 54 v[p]=val; 55 while(p!=1&&v[p]>v[p>>1]) 56 swap(v[p],v[p>>1]),p>>=1; 57 } 58 int top(){return v[1];} 59 int check(int p) 60 { 61 if((p<<1)>sz)return -1; 62 if((p<<1|1)>sz) 63 { 64 if(v[p<<1]>v[p])return 0; 65 else return -1; 66 } 67 if(v[p]>v[p<<1]&&v[p]>v[p<<1|1])return -1; 68 if(v[p<<1]>v[p<<1|1])return 0; 69 return 1; 70 } 71 void pop() 72 { 73 swap(v[1],v[sz]); 74 v[sz--]=0; 75 int p=1,k=check(p); 76 while(k!=-1) 77 { 78 swap(v[p],v[(p<<1)+k]); 79 p=(p<<1)+k; 80 k=check(p); 81 } 82 } 83 }h1; 84 85 int abv(int r){return r>0?r:(-r);} 86 87 int read() 88 { 89 int ret=0,fl=1; 90 char c=getchar(); 91 while(c<'0'||c>'9'){if(c=='-')fl=-1;c=getchar();} 92 while(c>='0'&&c<='9')ret=ret*10+c-'0',c=getchar(); 93 return ret*fl; 94 } 95 96 int main() 97 { 98 scanf("%d",&n); 99 for(int i=1;i<=n;i++) 100 { 101 int t=read(); 102 h1.insert(t); 103 } 104 scanf("%d",&m); 105 for(int i=1;i<=m;i++) 106 { 107 char op[10]; 108 scanf("%s",op+1); 109 if(op[1]=='a') 110 { 111 int t=read(); 112 if(t>h1.top())h2.insert(t); 113 else h1.insert(t); 114 } 115 if(op[1]=='m') 116 { 117 while(abv(h1.sz-h2.sz)>1) 118 { 119 if(h1.sz>h2.sz) 120 { 121 int t=h1.top(); 122 h1.pop(); 123 h2.insert(t); 124 }else 125 { 126 int t=h2.top(); 127 h2.pop(); 128 h1.insert(t); 129 } 130 } 131 if(h1.sz>h2.sz)printf("%d\n",h1.top()); 132 if(h1.sz<h2.sz)printf("%d\n",h2.top()); 133 if(h1.sz==h2.sz)printf("%d\n",min(h1.top(),h2.top())); 134 } 135 } 136 return 0; 137 }