【bzoj2738】矩阵乘法 整体二分+二维树状数组
题目描述
给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。
输入
第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
输出
对于每组询问输出第K小的数。
样例输入
2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
样例输出
1
3
题解
整体二分+二维树状数组
题目答案具有很明显的二分性质,所以可以考虑把询问离线,整体二分。(其实是看了题解才知道是整体二分的,一开始想了权值线段树套线段树套线段树)
整体二分是什么?和标准的二分几乎是一样的,只不过在判定的时候是把所有答案属于这个区间的询问都一一判定,根据是否满足条件,把询问分成答案属于[l,mid]和[mid+1,r]中的,然后对这两个子区间分别处理。最后当l=r时说明所有属于这个区间内的询问的答案都是l。
那么对于本题,我们可以按照矩阵内的数从大到小排序。然后把所有询问读进来,整体二分。
用solve(b,e,l,r)表示要解决询问区间在[b,e]内,点区间在[l,r]内的答案,那么当l=r时,ans[b~e]=v[l]。
当l≠r时,令mid=(l+r)/2。此时需要判定的就是是否有大于等于k个数在[l,mid]之间。
所以把[l,mid]内对应位置的权值+1,要求的就是矩形内的权值之和,可以使用二维树状数组来解决。
如果矩形内的数的个数大于等于k,则答案在[l,mid]内,把它放到对应的区间中;否则答案在[mid+1,r]中,并且把k值减去数的个数(因为这些数一定是前k小的)。
最后把二维树状数组清空(不能使用memset,需要动态清零),然后递归处理左右区间即可。
时间复杂度$O(q\log^3n)$
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; struct POINT { int x , y , v; bool operator<(const POINT a)const {return v < a.v;} }a[250010]; struct QUERY { int x1 , y1 , x2 , y2 , k , id; }q[60010] , t[60010]; int f[510][510] , n , ans[60010]; void add(int x , int y , int a) { int i , j; for(i = x ; i <= n ; i += i & -i) for(j = y ; j <= n ; j += j & -j) f[i][j] += a; } int query(int x , int y) { int i , j , ans = 0; for(i = x ; i ; i -= i & -i) for(j = y ; j ; j -= j & -j) ans += f[i][j]; return ans; } void solve(int b , int e , int l , int r) { if(b > e) return; int mid = (l + r) >> 1 , i , tl = b , tr = e , c; if(l == r) { for(i = b ; i <= e ; i ++ ) ans[q[i].id] = a[l].v; return; } for(i = l ; i <= mid ; i ++ ) add(a[i].x , a[i].y , 1); for(i = b ; i <= e ; i ++ ) { c = query(q[i].x2 , q[i].y2) - query(q[i].x1 - 1 , q[i].y2) - query(q[i].x2 , q[i].y1 - 1) + query(q[i].x1 - 1 , q[i].y1 - 1); if(c >= q[i].k) t[tl ++ ] = q[i]; else q[i].k -= c , t[tr -- ] = q[i]; } for(i = b ; i <= e ; i ++ ) q[i] = t[i]; for(i = l ; i <= mid ; i ++ ) add(a[i].x , a[i].y , -1); solve(b , tr , l , mid) , solve(tl , e , mid + 1 , r); } int main() { int m , i , j; scanf("%d%d" , &n , &m); for(i = 1 ; i <= n ; i ++ ) for(j = 1 ; j <= n ; j ++ ) scanf("%d" , &a[(i - 1) * n + j].v) , a[(i - 1) * n + j].x = i , a[(i - 1) * n + j].y = j; sort(a + 1 , a + n * n + 1); for(i = 1 ; i <= m ; i ++ ) scanf("%d%d%d%d%d" , &q[i].x1 , &q[i].y1 , &q[i].x2 , &q[i].y2 , &q[i].k) , q[i].id = i; solve(1 , m , 1 , n * n); for(i = 1 ; i <= m ; i ++ ) printf("%d\n" , ans[i]); return 0; }