【洛谷P1527】【JZOJ2908】矩阵乘法【二维树状数组】【整体二分】

题目大意:

题目链接:https://www.luogu.org/problem/P1527
给你一个n×nn\times n的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第kk小数。


思路:

整体二分的好题。
这道题的空间吃的很紧,所以在线做法几乎是不可能的。所以可以考虑离线预处理。
这道题没有修改操作,所以可以把所有询问整体二分,然后再分别选择进入区间[l,mid][l,mid]还是[mid+1,r][mid+1,r]
我们每次二分midmid,可以把问题模型转换为小于等于midmid的数字全部染成黑色,然后求矩阵里黑色点的数量。如果小于等于kk就进入左区间,否则进入右区间。
所以我们需要维护支持二维加法以及前缀和的数据结构,用二维树状数组即可。常数还小。


思路:

#include <cstdio>
#include <algorithm>
using namespace std;

const int N=1010,M=150010;
int n,m,ans[M];

struct Martix
{
	int x,y,k;
}a[N*N];

struct Ask
{
	int x1,x2,y1,y2,k,id;
}ask[M],p1[M],p2[M];

struct Bit
{
	int c[N][N];
	
	int lowbit(int x)
	{
		return x&-x;
	}
	
	void add(int x,int y,int val)
	{
		if (!x || !y) return;
		for (int i=x;i<=n;i+=lowbit(i))
			for (int j=y;j<=n;j+=lowbit(j))
				c[i][j]+=val;
	}
	
	int ask(int x,int y)
	{
		if (!x || !y) return 0;
		int ans=0;
		for (int i=x;i;i-=lowbit(i))
			for (int j=y;j;j-=lowbit(j))
				ans+=c[i][j];
		return ans;
	}
}bit;

bool cmp(Martix x,Martix y)
{
	return x.k<y.k;
}

void solve(int l,int r,int ql,int qr)
{
	if (ql>qr) return;
	if (l==r)
	{
		for (int i=ql;i<=qr;i++)
			ans[ask[i].id]=a[l].k;
		return;
	}
	int mid=(l+r)>>1,tot1=0,tot2=0;
	for (int i=l;i<=mid;i++)
		bit.add(a[i].x,a[i].y,1);
	for (int i=ql;i<=qr;i++)
	{
		int sum=bit.ask(ask[i].x2,ask[i].y2)-bit.ask(ask[i].x1-1,ask[i].y2)-bit.ask(ask[i].x2,ask[i].y1-1)+bit.ask(ask[i].x1-1,ask[i].y1-1);
		if (sum>=ask[i].k) p1[++tot1]=ask[i];
			else ask[i].k-=sum,p2[++tot2]=ask[i];
	}
	for (int i=l;i<=mid;i++)
		bit.add(a[i].x,a[i].y,-1);
	int i=1; tot2+=tot1;
	for (;i<=tot1;i++) ask[ql+i-1]=p1[i];
	for (;i<=tot2;i++) ask[ql+i-1]=p2[i-tot1];
	solve(l,mid,ql,ql+tot1-1); solve(mid+1,r,ql+tot1,qr);
}

int main()
{
	scanf("%d%d",&n,&m);
	for (int i=1;i<=n;i++)
		for (int j=1;j<=n;j++)
		{
			scanf("%d",&a[i*n-n+j].k);
			a[i*n-n+j].x=i;
			a[i*n-n+j].y=j;
		}
	sort(a+1,a+1+n*n,cmp);
	for (int i=1;i<=m;i++)
	{
		scanf("%d%d%d%d%d",&ask[i].x1,&ask[i].y1,&ask[i].x2,&ask[i].y2,&ask[i].k);
		ask[i].id=i;
	}
	solve(1,n*n,1,m);
	for (int i=1;i<=m;i++)
		printf("%d\n",ans[i]);
	return 0;
}
posted @ 2019-08-23 20:55  全OI最菜  阅读(103)  评论(0编辑  收藏  举报