寻找新序列的中位数 - 树状数组+二分
Problem Statement
We will define the median of a sequence b of length M, as follows:
Let b' be the sequence obtained by sorting b in non-decreasing order. Then, the value of the (M?2+1)-th element of b' is the median of b. Here, ? is integer division, rounding down.
For example, the median of (10,30,20) is 20; the median of (10,30,20,40) is 30; the median of (10,10,10,20,30) is 10.
Snuke comes up with the following problem.
You are given a sequence a of length N. For each pair (l,r) (1≤l≤r≤N), let ml,r be the median of the contiguous subsequence (al,al+1,…,ar) of a. We will list ml,r for all pairs (l,r) to create a new sequence m. Find the median of m.
Constraints
1≤N≤105
ai is an integer.
1≤ai≤109
Input
Input is given from Standard Input in the following format:
N
a1 a2 … aN
Output
Print the median of m.
Sample Input 1
Copy
3
10 30 20
Sample Output 1
Copy
30
The median of each contiguous subsequence of a is as follows:
The median of (10) is 10.
The median of (30) is 30.
The median of (20) is 20.
The median of (10,30) is 30.
The median of (30,20) is 30.
The median of (10,30,20) is 20.
Thus, m=(10,30,20,30,30,20) and the median of m is 30.
Sample Input 2
Copy
1
10
Sample Output 2
Copy
10
Sample Input 3
Copy
10
5 9 5 9 8 9 3 5 4 3
Sample Output 3
Copy
8
题意 : 不断的求一个子序列的中位数,构成一个新的序列,最后问你新的序列的中位数是多少?
思路分析 :
暴力的想法是找到所有的子区间求其中位数,最后再求整个序列的中位数,显然是超时
想一下可以发现,最后求的答案一定是符合单调性的,我们就直接去二分答案即可,然后对于每一次二分的答案,我们去对原数组进行一次操作,将大于二分出的数 记做是 +1, 小于的数记做是 -1 ,做一次前缀和,然后树状数组去维护即可,每次可以 n*logn 的时间找到所有序列里中位数大于等于我们二分的答案的有多少个, 当这个个数如果大于等于 全部序列中位数个数的一半 , 那么即代表是一个合法的
代码示例 :
#include <bits/stdc++.h> using namespace std; #define ll long long const ll maxn = 1e5+5; ll n; ll pre[maxn]; ll c[maxn<<1], s[maxn]; ll lowbit(ll x) {return x&(-x);} ll query(ll x){ ll res = 0; for(ll i = x; i ; i -= lowbit(i)){ res += c[i]; } return res; } void add(ll x){ for(ll i = x; i <= 2*n; i += lowbit(i)){ c[i] += 1; } } bool check(ll x){ memset(c, 0, sizeof(c)); s[0] = 0; for(ll i = 1; i <= n; i++) { s[i] = pre[i]>=x?1:-1; s[i] += s[i-1]; } ll sum = 0; for(ll i = 0; i <= n; i++){ sum += query(s[i]+n); add(s[i]+n); } ll num = n*(n+1)/2, ss; if (num%2) ss = (num+1)/2; else ss = num/2; if (sum >= ss) return true; return false; } int main () { cin >> n; ll l = 1, r = 1; for(ll i = 1; i <= n; i++) { scanf("%lld", &pre[i]); if (pre[i] > r) r = pre[i]; } ll ans; while(l <= r){ ll mid = (l+r)>>1; if (check(mid)) ans = mid, l = mid+1; else r = mid-1; } printf("%lld\n", ans); return 0; }