BZOJ2738矩阵乘法——整体二分+二维树状数组

题目描述

给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。

输入

 
第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。

输出

对于每组询问输出第K小的数。

样例输入

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3

样例输出

1
3

提示

矩阵中数字是109以内的非负整数;
20%的数据:N<=100,Q<=1000;
40%的数据:N<=300,Q<=10000;
60%的数据:N<=400,Q<=30000;
100%的数据:N<=500,Q<=60000。

 

考虑暴力,对于每次询问二分答案,求权值$\le mid$的点的个数是否有$k$个,显然时间复杂度爆炸。

我们将所有询问一起二分,也就是整体二分。

先将每个点的权值排序,每次将权值$\le mid$的点加入到矩阵中,然后对当前要处理的所有询问进行查询。

如果查询矩形内点个数大于等于$k$那么说明这个询问的答案要在$[l,mid]$中,就将这个询问归为左区间,否则将询问的$k$减掉这次查询的结果,归为右区间。这样递归下去即可得到每个询问的答案。

至于查询矩形内点数用二维树状数组差分一下即可。

每层处理完不要忘记把树状数组清空。

#include<set>
#include<map>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<bitset>
#include<vector>
#include<cstring>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
int n,m;
int v[510][510];
struct lty
{
	int x,y,val;
}a[250010];
int cnt;
int num;
struct miku
{
	int a,b,c,d,k;
}q[60010];
int ans[60010];
int s[60010];
int ql[60010];
int qr[60010];
bool cmp(lty a,lty b)
{
	return a.val<b.val;
}
void add(int x,int y,int val)
{
	for(int i=x;i<=n;i+=i&-i)
	{
		for(int j=y;j<=n;j+=j&-j)
		{
			v[i][j]+=val;
		}
	}
}
int ask(int x,int y)
{
	int res=0;
	for(int i=x;i;i-=i&-i)
	{
		for(int j=y;j;j-=j&-j)
		{
			res+=v[i][j];
		}
	}
	return res;
}
int query(int a,int b,int c,int d)
{
	return ask(c,d)-ask(c,b-1)-ask(a-1,d)+ask(a-1,b-1);
}
void solve(int l,int r,int L,int R)
{
	if(L>R)
	{
		return ;
	}
	if(l==r)
	{
		for(int i=L;i<=R;i++)
		{
			ans[s[i]]=a[l].val;
		}
		return ;
	}
	int mid=(l+r)>>1;
	for(int i=l;i<=mid;i++)
	{
		add(a[i].x,a[i].y,1);
	}
	int sl=L,sr=R;
	for(int i=L;i<=R;i++)
	{
		int res=query(q[s[i]].a,q[s[i]].b,q[s[i]].c,q[s[i]].d);
		if(res>=q[s[i]].k)
		{
			ql[sl++]=s[i];
		}
		else
		{
			qr[sr--]=s[i];
			q[s[i]].k-=res;
		}
	}
	for(int i=L;i<sl;i++)
	{
		s[i]=ql[i];
	}
	for(int i=sr+1;i<=R;i++)
	{
		s[i]=qr[i];
	}
	for(int i=l;i<=mid;i++)
	{
		add(a[i].x,a[i].y,-1);
	}
	solve(l,mid,L,sl-1),solve(mid+1,r,sr+1,R);
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		for(int j=1;j<=n;j++)
		{
			num++;
			scanf("%d",&a[num].val);
			a[num].x=i,a[num].y=j;
		}
	}
	sort(a+1,a+1+num,cmp);
	for(int i=1;i<=m;i++)
	{
		cnt++;
		scanf("%d%d%d%d%d",&q[cnt].a,&q[cnt].b,&q[cnt].c,&q[cnt].d,&q[cnt].k);
		s[i]=i;
	}
	solve(1,num,1,cnt);
	for(int i=1;i<=m;i++)
	{
		printf("%d\n",ans[i]);
	}
}
posted @ 2019-04-03 22:14  The_Virtuoso  阅读(229)  评论(0编辑  收藏  举报