luogu6600题解

题意中的字母 T 可以看作一个回文的 \(1\) 串和回文串中心带一个向下的 \(1\) 串。

我们先来考虑朴素做法,可以枚举这个中心然后扩展来找有几个符合要求的串。
朴素做法显然复杂度不对。

沿着朴素的思路优化。
如果只考虑 \(w\ge a\)\(h\ge b\) 这两个要求很容易想到容斥。
此时有四个条件,我们直观上不好容斥,因为 \(w\)\(h\) 会互相影响,在固定一个量与后两个条件的限制下互相设一个不固定的下限。

题目中的 \(a,b,s,x\) 都是固定的,预处理也许会有点用。
预处理的一个直接的想法就是计算出该中心的回文 \(1\) 串最长回文长度和向下 \(1\) 串的最长长度(下面称为最大的 T)对应的 T 拆出来的 T 符合要求的 T 的个数。

可以发现最大的 T 拆出来的 T 的回文 \(1\) 串长度小于等于最大的 T 的长度,向下 \(1\) 串同理。
这个形式有点像二维前缀和。
如果我们沿着前缀和的思路走下去,该前缀和建立的一个矩形数阵上每个数代表的是该点对应的 T 是否符合要求,若符合要求则为 \(1\),否则为 \(0\)

于是我们把这个预处理出来,枚举每个中心即可。
回文串最长长度用 Manacher 算法计算,向下 \(1\) 串通过第一个 \(1\) 处暴力向下求解,其他的 \(1\) 由上一个 \(1\) 的长度减去一,这样均摊的方式计算,算法整体的时间复杂度为 \(O(nm)\)

记得开 long long,否则喜提 \(40\) 分。

代码如下,其中的 mh 是计算该中心的最长回文串长度的,ms 是计算该中心向下最长 \(1\) 串长度的。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
constexpr int MAXN=3e3+10,MAXM=3e3+10;
int n,m,a,b,s,x,sum[MAXN][MAXM],mh[MAXM],ms[MAXN],jz[MAXN][MAXM];
ll ans;
template<typename T>
T read(){
	T x=0;
	char ch=getchar();
	while(ch<'0'||ch>'9'){ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+(ch^48);ch=getchar();}
	return x;
}
int gc(){
	int x=getchar();
	while(x==' '||x=='\r'||x=='\n')x=getchar();
	return x;
}
namespace sol{
	void solve(){
		n=read<int>();m=read<int>();
		a=max(read<int>(),3);b=max(read<int>(),2);s=read<int>();x=read<int>();
		//w>=a,h>=b,w*h>=s,w+h>=s,w横,h竖
		for(int i=1;i<=n;++i){//竖
			for(int j=1;((j<<1)-1)<=m;++j){//横一边长度
				if(((j<<1)-1)>=a&&i>=b&&(((j<<1)-1)*i)>=s&&(((j<<1)-1)+i)>=x){
					sum[i][j]=sum[i-1][j]+sum[i][j-1]+1-sum[i-1][j-1];
				}else{
					sum[i][j]=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1];
				}
			}
		}
		for(int i=1;i<=n;++i){
			for(int j=1;j<=m;++j){
				jz[i][j]=gc()-'0';
			}
		}
		for(int i=1;i<=n;++i){
			memset(mh,0,sizeof(mh));
			int p=0;
			for(int j=1;j<=m;++j){
				if(j<=p+mh[p]-1)mh[j]=min(p+mh[p]-j,mh[p-(j-p)]);
				else mh[j]=jz[i][j];
				while(j+mh[j]<=m&&j-mh[j]>0&&jz[i][j+mh[j]]==1&&jz[i][j-mh[j]]==1)++mh[j];
				if(!jz[i][j])ms[j]=0;
				else if(!ms[j]){
					ms[j]=1;
					while(i+ms[j]<=n&&jz[i+ms[j]][j]==1)++ms[j];
				}else --ms[j];
				ans+=sum[ms[j]][mh[j]];
				if(mh[j]+j>mh[p]+p)p=j;
			}
		}
		printf("%lld\n",ans);
	}
}
int main(){
	sol::solve();
	return 0;
}
posted @ 2024-02-15 11:46  LiJoQiao  阅读(14)  评论(0编辑  收藏  举报