BZOJ2738矩阵乘法——整体二分+二维树状数组
题目描述
给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。
输入
第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
输出
对于每组询问输出第K小的数。
样例输入
2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
2 1
3 4
1 2 1 2 1
1 1 2 2 3
样例输出
1
3
3
提示
矩阵中数字是109以内的非负整数;
20%的数据:N<=100,Q<=1000;
40%的数据:N<=300,Q<=10000;
60%的数据:N<=400,Q<=30000;
100%的数据:N<=500,Q<=60000。
考虑暴力,对于每次询问二分答案,求权值$\le mid$的点的个数是否有$k$个,显然时间复杂度爆炸。
我们将所有询问一起二分,也就是整体二分。
先将每个点的权值排序,每次将权值$\le mid$的点加入到矩阵中,然后对当前要处理的所有询问进行查询。
如果查询矩形内点个数大于等于$k$那么说明这个询问的答案要在$[l,mid]$中,就将这个询问归为左区间,否则将询问的$k$减掉这次查询的结果,归为右区间。这样递归下去即可得到每个询问的答案。
至于查询矩形内点数用二维树状数组差分一下即可。
每层处理完不要忘记把树状数组清空。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<bitset> #include<vector> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int n,m; int v[510][510]; struct lty { int x,y,val; }a[250010]; int cnt; int num; struct miku { int a,b,c,d,k; }q[60010]; int ans[60010]; int s[60010]; int ql[60010]; int qr[60010]; bool cmp(lty a,lty b) { return a.val<b.val; } void add(int x,int y,int val) { for(int i=x;i<=n;i+=i&-i) { for(int j=y;j<=n;j+=j&-j) { v[i][j]+=val; } } } int ask(int x,int y) { int res=0; for(int i=x;i;i-=i&-i) { for(int j=y;j;j-=j&-j) { res+=v[i][j]; } } return res; } int query(int a,int b,int c,int d) { return ask(c,d)-ask(c,b-1)-ask(a-1,d)+ask(a-1,b-1); } void solve(int l,int r,int L,int R) { if(L>R) { return ; } if(l==r) { for(int i=L;i<=R;i++) { ans[s[i]]=a[l].val; } return ; } int mid=(l+r)>>1; for(int i=l;i<=mid;i++) { add(a[i].x,a[i].y,1); } int sl=L,sr=R; for(int i=L;i<=R;i++) { int res=query(q[s[i]].a,q[s[i]].b,q[s[i]].c,q[s[i]].d); if(res>=q[s[i]].k) { ql[sl++]=s[i]; } else { qr[sr--]=s[i]; q[s[i]].k-=res; } } for(int i=L;i<sl;i++) { s[i]=ql[i]; } for(int i=sr+1;i<=R;i++) { s[i]=qr[i]; } for(int i=l;i<=mid;i++) { add(a[i].x,a[i].y,-1); } solve(l,mid,L,sl-1),solve(mid+1,r,sr+1,R); } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { for(int j=1;j<=n;j++) { num++; scanf("%d",&a[num].val); a[num].x=i,a[num].y=j; } } sort(a+1,a+1+num,cmp); for(int i=1;i<=m;i++) { cnt++; scanf("%d%d%d%d%d",&q[cnt].a,&q[cnt].b,&q[cnt].c,&q[cnt].d,&q[cnt].k); s[i]=i; } solve(1,num,1,cnt); for(int i=1;i<=m;i++) { printf("%d\n",ans[i]); } }