bzoj2658[ZJOI2012]小蓝的好友(mrx)
题目链接
校内模拟赛出了这题,然后考试时写出了一个做法,评测机貌似有点慢被卡成90了,原题上可以A掉。首先可以知道\(O(n^2)\)的做法,就是只算全0的矩形个数,枚举下边界是哪一行,定义\(h\)值表示这个点向上碰到的第一个1的距离,然后可以从左到右枚举右边界,用单调栈维护以每个列作为左边界的矩形最大高度,那么以某一列作为右边界的对答案的贡献就是单调栈中的元素和。
然后数据随机,可以进一步优化这个做法,用线段树存每个点的高度值,当下边界向下移动时,相当于每个点的高度+1,且新出现的1位置的要设置成0。设新出现1的位置为x,由于每个跨过x的矩形的贡献都是0,要把它减掉,我们按照\(O(n^2)\)的想法,把x左边的单调栈通过线段树给做出来,然后显然后面的点r如果影响到了x以左的单调栈(就是存在一个\(l\le x\)且\(l\)$r-1$的矩形最大高度与$l$\(r\)的最大矩形高度不同),才会使贡献变化(若没有影响贡献不变,直接用原来的贡献乘区间长度即可),然后就维护一下x以左的单调栈。为了找到这样的r,可以从x开始每次往右找第一个高度小于x的点,x跳过去,这样的每个点都是会影响贡献值的点。
复杂度\(O(qklogn)\),k是"从一个点开始每次向右跳到第一个比它小的点的期望跳的次数",由于数据随机k的平均大小是小于\(logn\)级别的。
刚刚去看了看别人的做法,怎么都是用treap啊,只能说想不到想不到。实际运行速度我的做法顶多比treap慢一倍的样子,可能是因为线段树比treap快了蛮多吧。
\(O(n^2)\)的:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<vector>
#include<algorithm>
#include<cmath>
#define P puts("lala")
#define cp cerr<<"lala"<<endl
#define ln putchar('\n')
#define pb push_back
#define fi first
#define se second
#define shmem(x) cerr<<sizeof(x)/(1024*1024.0)<<"MB"
using namespace std;
inline int read()
{
char ch=getchar();int g=1,re=0;
while(ch<'0'||ch>'9') {if(ch=='-')g=-1;ch=getchar();}
while(ch<='9'&&ch>='0') re=(re<<1)+(re<<3)+(ch^48),ch=getchar();
return re*g;
}
typedef long long ll;
typedef pair<int,int> pii;
const int N=2050;
int dot[N][N],up[N][N];
int n,m,T;
pii stk[N]; int top=0;
void wj()
{
freopen("alice.in","r",stdin);
freopen("alice.out","w",stdout);
}
int main()
{
wj();
n=read(); m=read(); T=read();
for(int i=1;i<=T;++i) dot[read()][read()]=1;
for(int i=1;i<=n;++i) for(int j=1;j<=m;++j)
up[i][j]=(dot[i][j]?0:up[i-1][j]+1);
ll ans=0;
for(int i=1;i<=n;++i)
{
ll sum=0;
for(int j=1;j<=m;++j)
{
pii v=pii(up[i][j],1);
while(top&&stk[top].fi>=v.fi)
{
sum-=1ll*stk[top].fi*stk[top].se;
v.se+=stk[top].se;
top--;
}
stk[++top]=v;
sum+=1ll*v.fi*v.se;
ans+=sum;
}
}
printf("%lld\n",(1ll*n*(n+1)/2)*(1ll*m*(m+1)/2)-ans);
return 0;
}
\(O(qklogn)\)的:
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<vector>
#include<algorithm>
#include<cmath>
#include<ctime>
#define P puts("lala")
#define cp cerr<<"lala"<<endl
#define ln putchar('\n')
#define pb push_back
#define fi first
#define se second
#define shmem(x) cerr<<sizeof(x)/(1024*1024.0)<<"MB"
using namespace std;
inline int read()
{
char ch=getchar();int g=1,re=0;
while(ch<'0'||ch>'9') {if(ch=='-')g=-1;ch=getchar();}
while(ch<='9'&&ch>='0') re=(re<<1)+(re<<3)+(ch^48),ch=getchar();
return re*g;
}
typedef long long ll;
typedef pair<int,int> pii;
const int N=70050;
int n,m,T;
int minv[N<<2],addv[N<<2],add=0;
inline void modify(int o,int l,int r,int x,int k)
{
if(l==r) {minv[o]=k; return ;}
int mid=l+r>>1;
if(x<=mid) modify(o<<1,l,mid,x,k);
else modify(o<<1|1,mid+1,r,x,k);
minv[o]=min(minv[o<<1],minv[o<<1|1]);
}
inline int query(int o,int l,int r,int x)
{
if(l==r) return minv[o];
int mid=l+r>>1;
if(x<=mid) return query(o<<1,l,mid,x);
else return query(o<<1|1,mid+1,r,x);
}
int ret=0,found=0,val;
inline void gorig(int o,int l,int r,int x)
{
if(l==r) {if(minv[o]<x) ret=l,found=1,val=minv[o];return ;}
int mid=l+r>>1;
if(minv[o<<1|1]<x) gorig(o<<1|1,mid+1,r,x);
else gorig(o<<1,l,mid,x);
}
inline void findpre(int o,int l,int r,int x,int y,int k)
{
if(found) return ;
if(x<=l&&r<=y)
{
if(minv[o]<k) gorig(o,l,r,k);
return ;
}
int mid=l+r>>1;
if(y>mid) findpre(o<<1|1,mid+1,r,x,y,k);
if(x<=mid) findpre(o<<1,l,mid,x,y,k);
}
inline int getpre(int x,int y,int k)
{
if(x>y) return 0;
ret=0; found=0; val=0; findpre(1,1,m,x,y,k);
return ret;
}
inline void golef(int o,int l,int r,int x)
{
if(l==r) {if(minv[o]<x) ret=l,found=1,val=minv[o];return ;}
int mid=l+r>>1;
if(minv[o<<1]<x) golef(o<<1,l,mid,x);
else golef(o<<1|1,mid+1,r,x);
}
inline void findnex(int o,int l,int r,int x,int y,int k)
{
if(found) return ;
if(x<=l&&r<=y)
{
if(minv[o]<k) golef(o,l,r,k);
return ;
}
int mid=l+r>>1;
if(x<=mid) findnex(o<<1,l,mid,x,y,k);
if(y>mid) findnex(o<<1|1,mid+1,r,x,y,k);
}
inline int getnex(int x,int y,int k)
{
if(x>y) return m+1;
ret=m+1; found=0; val=0; findnex(1,1,m,x,y,k);
return ret;
}
pii stk[N],stk2[N]; int top=0,top2=0;
vector<int>ask[N];
void wj()
{
freopen("alice.in","r",stdin);
freopen("alice.out","w",stdout);
}
int main()
{
wj();
clock_t sta=clock();
n=read(); m=read(); T=read();
for(int i=1;i<=T;++i)
{
int x=read(),y=read();
ask[x].pb(y);
}
ll ans=0,sum=0;
for(int row=1;row<=n;++row)
{
add++;
sum+=1ll*(m+1)*m/2;
for(int i=0,siz=ask[row].size();i<siz;++i)
{
top=0; top2=0;
int in=ask[row][i];
ll tot=0;
int hi=query(1,1,m,in);
while(in)
{
int las=getpre(1,in-1,hi);
stk2[++top2]=pii(hi,in-las);
sum-=1ll*(hi+add)*(in-las);
tot+=1ll*(hi+add)*(in-las);
in=las; hi=val;
}
top=top2;
for(int j=1;j<=top;++j) stk[top-j+1]=stk2[j];
in=ask[row][i];
hi=query(1,1,m,in);
while(in<=m)
{
int las=getnex(in+1,m,hi);
sum-=1ll*(las-in)*tot;
if(in==ask[row][i]) sum+=tot;
if(las>m) break;
pii v=pii((hi=val),0);
while(top&&stk[top].fi>=v.fi)
{
tot-=1ll*(stk[top].fi+add)*stk[top].se;
v.se+=stk[top].se;
top--;
}
stk[++top]=v;
tot+=1ll*(v.fi+add)*v.se;
in=las;
}
modify(1,1,m,ask[row][i],-add);
}
ans+=sum;
}
printf("%lld\n",(1ll*n*(n+1)/2)*(1ll*m*(m+1)/2)-ans);
//cerr<<(1ll*n*(n+1)/2)*(1ll*m*(m+1)/2)-ans<<endl;
clock_t fin=clock();
//cerr<<(double)(fin-sta)/CLOCKS_PER_SEC<<endl;
return 0;
}