树状数组题【1】

题目描述

给你一个\(N*N\)的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第\(K\)小数。

输入描述:

第一行两个数\(N,Q\),表示矩阵大小和询问组数;
接下来\(N\)\(N\)列一共\(N*N\)个数,表示这个矩阵;
再接下来\(Q\)行每行5个数描述一个询问:\(x1,y1,x2,y2,k\)
表示找到以\((x1,y1)\)为左上角、以\((x2,y2)\)为右下角的子矩形中的第\(K\)小数。

输出描述:

对于每组询问输出第\(K\)小的数。

  • 输入

2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
  • 输出

1
3

数据范围:

\[N<=500,Q<=60000 \]

题解

这是一道整体二分的经典题目。
这道题显然可以给每个询问二分答案,统计该询问矩阵中小于等于mid的元素个数。如果大于等于k,说明猜大了,否则说明猜小了。
如果用这种方法的话,对于每个询问都至少要用O(询问矩阵大小*log值域)的时间复杂度解决,多组询问的话时间不能接受。
发现多个询问的二分答案是可以同时被检验的,我们可以为所有询问同时二分答案,把所有答案小于等于mid的询问放在询问序列的左侧,大于mid的放到询问序列的右侧然后递归处理。
这样为什么会快呢?我们每次可以用把矩阵中小于等于mid的元素染成黑色,剩下的元素保持白色。这样对于每一个询问的检验,就相当于是统计某个子矩阵中黑点的个数。为什么不用二维树状数组维护前缀和呢?
最开始的时候整个矩阵都设为0,然后让所有小于等于mid的位置加1,\(O(N^2\log^2N)\)
N)处理完整个矩阵之后再\(O(\log^2N)\)去应付每一个询问。被询问的子矩阵中可能有很多重叠的部分,这样保证了在每次询问过程中,矩阵中的每个元素只对运行时间做一次贡献,所以这个算法要比对于每个询问单独二分要快得多。
这样做的时间复杂度为什么是对的呢?我们是在二分值域,考虑二分过程中的每一层对答案的贡献。

  1. 对于每一层二分,矩阵中的每个元素最多被加入树状数组一次。
  2. 对于每一层二分,每个询问只会被处理一次。
  3. 二分值域的过程中最多只会出现\(O(\log N)\)层。
    时间复杂度\(O((N^2 + Q)\log^3)\)

code

#include <bits/stdc++.h>
#define int long long
#define re register int
using namespace std;
inline void read(int &x){
    x=0;char ch=getchar();
    for(;!isdigit(ch);ch=getchar());
    for(; isdigit(ch);ch=getchar())
		x=(x<<1)+(x<<3)+(ch^48);}
inline int print(int x){if(x>9) print(x/10);putchar(x%10+48);}
const int N=510,Q=6e4+10;;
int n,test,cnt,x,t[N][N];
struct rec{int x,y,k;}a[N*N];
struct node{int a,b,c,d,k,num,ans;}q[Q],L[Q],R[Q];
bool cmp(node x,node y){return x.num<y.num;}
bool cmp1(rec x,rec y){return x.k<y.k;}
void ins(int x,int y,int d){
    for(re i=x;i<=n;i+=i&(-i))
        for(re j=y;j<=n;j+=j&(-j))
            t[i][j]+=d;
}
int get(int x,int y){
    int sum=0;
    for(re i=x;i>0;i-=i&(-i))
        for(re j=y;j>0;j-=j&(-j))
            sum+=t[i][j];
    return sum;
}
int count(int a,int b,int c,int d){return get(c,d)-get(a-1,d)-get(c,b-1)+get(a-1,b-1);}
void solve(int l,int r,int ql,int qr){
    if(ql>qr) return;
    if(l==r){for(re i=ql;i<=qr;++i)q[i].ans=a[l].k;return;}
    int mid=(l+r)/2;    
    for(re i=l;i<=mid;++i) ins(a[i].x,a[i].y,1);
    int cnt1=0,cnt2=0;
    for(re d,i=ql;i<=qr;++i){
        d=count(q[i].a,q[i].b,q[i].c,q[i].d);
        if (q[i].k<=d) L[++cnt1]=q[i];
        else q[i].k-=d,R[++cnt2]=q[i];
    }
    for(re i=l;i<=mid ;++i) ins(a[i].x,a[i].y,-1);
    for(re i=1;i<=cnt1;++i) q[ql+i-1]=L[i];
    for(re i=1;i<=cnt2;++i) q[ql+cnt1-1+i]=R[i];
    solve(l,mid,ql,ql+cnt1-1);
    solve(mid+1,r,ql+cnt1,qr);
}
signed main(){
    read(n),read(test);
    for(re i=1;i<=n;++i)
        for(re j=1;j<=n;++j)
            read(x),a[++cnt]=(rec){i,j,x};
    for(re i=1;i<=test;++i)read(q[i].a),read(q[i].b),read(q[i].c),read(q[i].d),read(q[i].k),q[i].num=i;
    sort(a+1,a+cnt+1,cmp1),solve(1,cnt,1,test),sort(q+1,q+test+1,cmp);
    for (re i=1;i<=test;++i) print(q[i].ans),puts("");
}
posted @ 2018-09-23 21:23  Sparks_Pion  阅读(125)  评论(0编辑  收藏  举报