【BZOJ】2653: middle
2653: middle
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 2381 Solved: 1340
[Submit][Status][Discuss]
Description
一个长度为n的序列a,设其排过序之后为b,其中位数定义为b[n/2],其中a,b从0开始标号,除法取下整。给你一个
长度为n的序列s。回答Q个这样的询问:s的左端点在[a,b]之间,右端点在[c,d]之间的子序列中,最大的中位数。
其中a<b<c<d。位置也从0开始标号。我会使用一些方式强制你在线。
Input
第一行序列长度n。接下来n行按顺序给出a中的数。
接下来一行Q。然后Q行每行a,b,c,d,我们令上个询问的答案是
x(如果这是第一个询问则x=0)。
令数组q={(a+x)%n,(b+x)%n,(c+x)%n,(d+x)%n}。
将q从小到大排序之后,令真正的
要询问的a=q[0],b=q[1],c=q[2],d=q[3]。
输入保证满足条件。
第一行所谓“排过序”指的是从小到大排序!
n<=20000,Q<=25000
Output
Q行依次给出询问的答案。
Sample Input
5
170337785
271451044
22430280
969056313
206452321
3
3 1 0 2
2 3 1 4
3 1 4 0
170337785
271451044
22430280
969056313
206452321
3
3 1 0 2
2 3 1 4
3 1 4 0
Sample Output
271451044
271451044
969056313
271451044
969056313
HINT
Source
clj dalao出的神题%%%
一道很好的思维题吧。首先可以发现,满足条件的中位数是具有二分性(即单调性)的。比如数列中一个数$x$,我们把所有比$x$小的数的位置标记成-1,比$x$大的数的位置标记成1,如果和大于0,表示比它大的数枚举多了,所以把答案往右移,反之往左移。满足二分性,所以二分check时满足区间最大值>=0就表示这个数可以作为答案,更新答案。(<=0也可以用来判断,不过维护的是最小值即可)
可以发现,问题中的$bc$区间一定被包括,所以查询$ab$的右缀连续区间最大值和$cd$的左缀连续区间最大值以及$bc$的整个区间和即可。【注意】所有区间都是左闭右开。
考虑如何维护。给每个值建一棵主席数,预处理出每棵主席数上-1和1的情况,每次就在对应二分的pos的主席树上查询即可。
最后是合并节点的问题,$lmax$是从左儿子的$lmax$或者左儿子的$sum$加右儿子的$lmax$更新过来,$rmax$同理。
#include<iostream> #include<cstdio> #include<algorithm> using namespace std; int n, ans, a, b, c, d, q[5]; struct QwQ { int v, id; } A[20005]; bool cmp ( QwQ a, QwQ b ) { return a.v < b.v; } struct node { node *ls, *rs; int sum, lmax, rmax; void update ( ) { sum = ls -> sum + rs -> sum; lmax = max ( ls -> lmax, ls -> sum + rs -> lmax ); rmax = max ( rs -> rmax, rs -> sum + ls -> rmax ); } } *zero, *root[20005], pool[20005*32], *tail = pool; node *newnode ( ) { node *nd = ++ tail; nd -> ls = zero; nd -> rs = zero; nd -> sum = 0; nd -> lmax = 0; nd -> rmax = 0; return nd; } node *build ( int l, int r ) { node *nd = newnode ( ); if ( l == r ) { nd -> sum = nd -> lmax = nd -> rmax = 1; return nd; } int mid = ( l + r ) >> 1; nd -> ls = build ( l, mid ); nd -> rs = build ( mid + 1, r ); nd -> update ( ); return nd; } void init ( ) { zero = ++ tail; zero -> ls = zero; zero -> rs = zero; zero -> sum = 0; zero -> lmax = 0; zero -> rmax = 0; root[1] = build ( 1, n ); } node *insert ( node *nd, int l, int r, int pos ) { node *nnd = newnode ( ); if ( l == r ) { nnd -> sum = -1; nnd -> lmax = nnd -> rmax = -1; return nnd; } int mid = ( l + r ) >> 1; if ( mid >= pos ) { nnd -> rs = nd -> rs; nnd -> ls = insert ( nd -> ls, l, mid, pos ); } else { nnd -> ls = nd -> ls; nnd -> rs = insert ( nd -> rs, mid + 1, r, pos ); } nnd -> update ( ); return nnd; } int query_s ( node *nd, int l, int r, int L, int R ) { if ( L > R ) return 0; if ( l >= L && r <= R ) return nd -> sum; int mid = ( l + r ) >> 1, ans = 0; if ( L <= mid ) ans += query_s ( nd -> ls, l, mid, L, R ); if ( R > mid ) ans += query_s ( nd -> rs, mid + 1, r, L, R ); return ans; } int query_l ( node *nd, int l, int r, int L, int R ) { if ( L > R ) return 0; if ( l >= L && r <= R ) return nd -> lmax; int mid = ( l + r ) >> 1, ans = 0; if ( L <= mid ) ans = max ( ans, query_l ( nd -> ls, l, mid, L, R ) ); if ( R > mid ) { int lsum = query_s ( nd -> ls, l, mid, L, mid ); int rmaxl = query_l ( nd -> rs, mid + 1, r, L, R ); ans = max ( ans, lsum + rmaxl ); } return ans; } int query_r ( node *nd, int l, int r, int L, int R ) { if ( L > R ) return 0; if ( l >= L && r <= R ) return nd -> rmax; int mid = ( l + r ) >> 1, ans = 0; if ( R > mid ) ans = max ( ans, query_r ( nd -> rs, mid + 1, r, L, R ) ); if ( L <= mid ) { int rsum = query_s ( nd -> rs, mid + 1, r, mid + 1, R ); int lmaxr = query_r ( nd -> ls, l, mid, L, R ); ans = max ( ans, rsum + lmaxr ); } return ans; } bool check ( int pos ) { int ab = query_r ( root[pos], 1, n, a, b - 1 ); int cd = query_l ( root[pos], 1, n, c + 1, d ); int bc = query_s ( root[pos], 1, n, b, c ); if ( ab + cd + bc >= 0 ) return 1; return 0; } int find ( ) { int l = 1, r = n, ans = 0; while ( l <= r ) { int mid = ( l + r ) >> 1; if ( check ( mid ) ) { ans = mid; l = mid + 1; } else r = mid - 1; } return ans; } int main ( ) { scanf ( "%d", &n ); init ( ); for ( int i = 1; i <= n; i ++ ) { scanf ( "%d", &A[i].v ); A[i].id = i; } sort ( A + 1, A + 1 + n, cmp ); for ( int i = 2; i <= n; i ++ ) root[i] = insert ( root[i-1], 1, n, A[i-1].id ); int Q; scanf ( "%d", &Q ); while ( Q -- ) { for ( int i = 1; i <= 4; i ++ ) { int x; scanf ( "%d", &x ); q[i] = ( x + ans ) % n + 1; } sort ( q + 1, q + 5 ); a = q[1], b = q[2], c = q[3], d = q[4]; ans = A[find ( )].v; printf ( "%d\n", ans ); } return 0; }