[IOI2019]矩形区域
题意:
给一个\(n×m\)矩阵,问有多少子矩阵,对于其中任意一个数,都满足它小于它的上下左右四个方向中第一个在矩阵外面的数。
\(1\leq n,m\leq 2500\)。
写一个\(O(nmlognm)\)的做法。
首先,对于每一行,每一列,分别求出其中的合法区间。
显然,只要区间中的最大值满足条件(即它小于区间外侧的两个数)即可。
因此,我们可以枚举最大值,用单调栈求出两侧不大于它的区间\((l,r)\),那么,只有\((l-1,r+1)\)可能满足条件。
这样,每个最大值只会贡献一个区间,因此总的区间数是\(O(nm)\)的。
求出这部分后,只要子矩阵的每行每列都能对应一个合法区间,它就是合法的了。
再求出每个子区间向下(向右)能延伸的长度,这个可以边扫描,边更新。
求答案时,我们可以枚举这个子矩阵最上边的那行,并枚举那行中的合法区间,设它的向下延伸长度为\(x\),区间长度为\(y\)。
之后,再枚举它的左端点的列中,上端点位于这行的一个区间。
那么,只要这个区间长度不大于\(x\),向右延伸长度不小于\(y\),就会产生一个答案。
可以发现,这是一个类似二维数点的问题,直接扫描线+树状数组维护即可。
注意细节
代码:
#include <stdio.h>
#include <vector>
using namespace std;
struct SJd
{
int l,r;
SJd(){}
SJd(int L,int R)
{
l=L;r=R;
}
};
int getlr(int sz[2502],int n,SJd jg[2502])
{
int l[2502],r[2502],m=0;
for(int i=0;i<n;i++)
l[i]=r[i]=i;
for(int i=0;i<n;i++)
{
while(l[i]>0&&sz[l[i]-1]<=sz[i])
l[i]=l[l[i]-1];
}
for(int i=n-1;i>=0;i--)
{
while(r[i]+1<n&&sz[r[i]+1]<sz[i])
r[i]=r[r[i]+1];
}
for(int i=0;i<n;i++)
{
if(l[i]>0&&r[i]+1<n&&sz[r[i]+1]>sz[i])
jg[m++]=SJd(l[i],r[i]);
}
return m;
}
SJd ha[2502][2502],li[2502][2502];
int sh[2502],sl[2502],sz[2502][2502];
int dn[2502][2502],ri[2502][2502],dp[2502][2502];
bool bk[2502][2502];
struct SPx
{
int j,ri;
SPx(){}
SPx(int J,int RI)
{
j=J;ri=RI;
}
};
vector<SPx> ve[2502][2502];
int C[2502][2502],N;
void add(int a,int i)
{
i+=1;
for(int j=i;j>0;j-=j&(-j))
C[a][j]+=1;
}
void clean(int a,int i)
{
i+=1;
for(int j=i;j>0;j-=j&(-j))
C[a][j]=0;
}
int sum(int a,int i)
{
int jg=0;i+=1;
for(int j=i;j<=N;j+=j&(-j))
jg+=C[a][j];
return jg;
}
vector<int> vv[2510];
int main()
{
int n,m;
scanf("%d%d",&n,&m);N=m;
for(int i=0;i<n;i++)
{
for(int j=0;j<m;j++)
scanf("%d",&sz[i][j]);
}
for(int i=0;i<n;i++)
sh[i]=getlr(sz[i],m,ha[i]);
for(int i=0;i<m;i++)
{
int jh[2502];
for(int j=0;j<n;j++)
jh[j]=sz[j][i];
sl[i]=getlr(jh,n,li[i]);
}
for(int i=0;i<m;i++)
{
for(int j=0;j<m;j++)
dp[i][j]=-1;
}
for(int i=n-1;i>=0;i--)
{
for(int j=0;j<sh[i];j++)
{
int l=ha[i][j].l,r=ha[i][j].r;
dp[l][r]+=1;bk[l][r]=true;
dn[i][j]=dp[l][r];
}
if(i<n-1)
{
for(int j=0;j<sh[i+1];j++)
{
int l=ha[i+1][j].l,r=ha[i+1][j].r;
if(!bk[l][r])
dp[l][r]=-1;
}
}
for(int j=0;j<sh[i];j++)
{
int l=ha[i][j].l,r=ha[i][j].r;
bk[l][r]=false;
}
}
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
dp[i][j]=-1;
}
for(int i=m-1;i>=0;i--)
{
for(int j=0;j<sl[i];j++)
{
int l=li[i][j].l,r=li[i][j].r;
dp[l][r]+=1;bk[l][r]=true;
ri[i][j]=dp[l][r];
}
if(i<m-1)
{
for(int j=0;j<sl[i+1];j++)
{
int l=li[i+1][j].l,r=li[i+1][j].r;
if(!bk[l][r])
dp[l][r]=-1;
}
}
for(int j=0;j<sl[i];j++)
{
int l=li[i][j].l,r=li[i][j].r;
bk[l][r]=false;
}
}
for(int i=0;i<m-1;i++)
{
for(int j=0;j<sl[i];j++)
ve[li[i][j].l][li[i][j].r-li[i][j].l].push_back(SPx(i,ri[i][j]));
}
int ans=0;
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
vv[j].clear();
for(int j=0;j<sh[i];j++)
vv[dn[i][j]].push_back(j);
for(int k=0;k<n;k++)
{
for(int s=0;s<ve[i][k].size();s++)
add(ve[i][k][s].j,ve[i][k][s].ri);
for(int s=0;s<vv[k].size();s++)
{
int j=vv[k][s],l=ha[i][j].l,r=ha[i][j].r,z=r-l;
ans+=sum(l,z);
}
}
for(int k=0;k<n;k++)
{
for(int s=0;s<ve[i][k].size();s++)
clean(ve[i][k][s].j,ve[i][k][s].ri);
}
}
printf("%d",ans);
return 0;
}