【BZOJ2738】矩阵乘法 整体二分

【BZOJ2738】矩阵乘法

Description

  给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

Input

  第一行两个数N,Q,表示矩阵大小和询问组数;
  接下来N行N列一共N*N个数,表示这个矩阵;
  再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

Output

  对于每组询问输出第K小的数。

Sample Input

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3

Sample Output

1
3

HINT

  矩阵中数字是109以内的非负整数;
  20%的数据:N<=100,Q<=1000;
  40%的数据:N<=300,Q<=10000;
  60%的数据:N<=400,Q<=30000;
  100%的数据:N<=500,Q<=60000。

题解:根据整体二分的思想,我们将所有数排序,然后二分。我们将[1,mid]中的所有数扔到二维树状数组中去,然后看一看那些矩阵中的元素个数≥K。我们将满足条件的放在左边,不满足的放在右边,然后继续递归下去,直至出解。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
int n,m,n2,tot,now;
struct node
{
	int x,y,val;
}v[500*510];
int q1[60010],q2[60010],q3[60010],q4[60010],qk[60010],ans[60010];
int s[510][510],p[60010],q[60010],sum[60010];
int rd()
{
	int ret=0,f=1;	char gc=getchar();
	while(gc<'0'||gc>'9')	{if(gc=='-')f=-f;	gc=getchar();}
	while(gc>='0'&&gc<='9')	ret=ret*10+gc-'0',gc=getchar();
	return ret*f;
}
bool cmp(node a,node b)
{
	return a.val<b.val;
}
void updata(int x,int y,int val)
{
	int i,j;
	for(i=x;i<=n;i+=i&-i)
		for(j=y;j<=n;j+=j&-j)
			s[i][j]+=val;
}
int query(int x,int y)
{
	int ret=0,i,j;
	for(i=x;i;i-=i&-i)
		for(j=y;j;j-=j&-j)
			ret+=s[i][j];
	return ret;
}
void solve(int l,int r,int L,int R)
{
	if(l>r)	return ;
	if(L==R)
	{
		for(int i=l;i<=r;i++)	ans[p[i]]=v[L].val;
		return ;
	}
	int MID=L+R>>1,i,mid=l-1;
	while(now<MID)	now++,updata(v[now].x,v[now].y,1);
	while(now>MID)	updata(v[now].x,v[now].y,-1),now--;
	for(i=l;i<=r;i++)
	{
		sum[p[i]]=query(q1[p[i]]-1,q2[p[i]]-1)+query(q3[p[i]],q4[p[i]])-query(q1[p[i]]-1,q4[p[i]])-query(q3[p[i]],q2[p[i]]-1);
		if(sum[p[i]]>=qk[p[i]])	mid++;
	}
	int l1=l,l2=mid+1;
	for(i=l;i<=r;i++)
	{
		if(sum[p[i]]>=qk[p[i]])	q[l1++]=p[i];
		else	q[l2++]=p[i];
	}
	for(i=l;i<=r;i++)	p[i]=q[i];
	solve(l,mid,L,MID),solve(mid+1,r,MID+1,R);
}
int main()
{
	n=rd(),m=rd();
	int i,j;
	for(i=1;i<=n;i++)
		for(j=1;j<=n;j++)
			v[++n2].val=rd(),v[n2].x=i,v[n2].y=j;
	sort(v+1,v+n2+1,cmp);
	for(i=1;i<=m;i++)	q1[i]=rd(),q2[i]=rd(),q3[i]=rd(),q4[i]=rd(),qk[i]=rd(),p[i]=i;
	solve(1,m,1,n2);
	for(i=1;i<=m;i++)	printf("%d\n",ans[i]);
	return 0;
}
posted @ 2017-05-22 16:40  CQzhangyu  阅读(592)  评论(0编辑  收藏  举报