题解 网格图

传送门

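关键观察：把某个 \(k\times k\) 正方形内的格子全部变为空地后，包含该正方形的连通块大小 = 与正方形相交或四连通相邻的所有（互不相同的）空地连通块大小之和 + 正方形内原本为障碍（非 `.`）的格子数。对每个正方形都重新扫描其内部是 \(O(n^4)\) 的；固定列区间后让正方形逐行下移（扫描线），每步只需处理正方形四周及新进入的一行共 \(O(k)\) 个格子，总复杂度降为 \(O(n^3)\)，即可通过。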

Code:
#include <bits/stdc++.h>
using namespace std;
const int INF = 0x3f3f3f3f;
// Bound for every per-row / per-column array; indices 0 and n+1 are used as a
// zero-filled border, so the grid side n must stay below N - 1.
const int N = 510;
typedef long long ll;
//#define int long long

// Grid side n and square side k, both read in main().
int n, k;
// Per-cell tables, 1-based:
//   id  - linear index of the cell, used as its union-find node
//   bel - union-find root (component label) of the cell
//   mp  - 1 for a '.' cell, 2 for any other character, 0 outside the grid
int id[N][N];
int bel[N][N];
int mp[N][N];
// Union-find parent links and component sizes (meaningful at roots).
int fa[300010];
int siz[300010];
// Number of cell ids handed out so far.
int tot;
// Input buffer for one grid row, read into s[1..n].
char s[N];
// The four orthogonal neighbour offsets {d_row, d_col}.
const int dlt[][2]={{-1,0},{1,0},{0,-1},{0,1}};
// Returns the union-find root of node p and re-points every node on the path
// from p directly at that root (full path compression), without recursion.
inline int find(int p) {
	int root=p;
	while (fa[root]!=root) root=fa[root];
	while (fa[p]!=root) {
		const int up=fa[p];
		fa[p]=root;
		p=up;
	}
	return root;
}

namespace force{
	int sta[300010], top, ans, sum[N][N];
	bool vis[300010];
	inline int qsum(int i, int j) {return sum[i][j]-sum[i-k][j]-sum[i][j-k]+sum[i-k][j-k];}
	void solve() {
		int lim=n*n;
		for (int i=1; i<=lim; ++i) fa[i]=i, siz[i]=1;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) id[i][j]=++tot;
		for (int i=1; i<=n; ++i)
			for (int j=1; j<=n; ++j) {
				if (mp[i][j]==1) {
					for (int k=0; k<4; ++k) {
						int x=i+dlt[k][0], y=j+dlt[k][1];
						if (mp[x][y]==1) {
							int f1=find(id[x][y]), f2=find(id[i][j]);
							if (f1!=f2) fa[f1]=f2, siz[f2]+=siz[f1];
						}
					}
				}
				else sum[i][j]=1;
			}
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) bel[i][j]=find(id[i][j]);
		// cout<<"---bel---"<<endl;
		// for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<bel[i][j]<<' '; cout<<endl;}
		// cout<<"siz: "; for (int i=1; i<=tot; ++i) cout<<siz[i]<<' '; cout<<endl;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) sum[i][j]+=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1];
		for (int i=k; i<=n; ++i) {
			for (int j=k; j<=n; ++j) {
				// cout<<"ij: "<<i<<' '<<j<<endl;
				int sum2=0;
				for (int s=i-k+1; s<=i; ++s) {
					for (int t=j-k+1; t<=j; ++t) if (mp[s][t]==1 && !vis[bel[s][t]]) {
						vis[bel[s][t]]=1;
						sta[++top]=bel[s][t];
						sum2+=siz[bel[s][t]];
						// cout<<"add: "<<bel[s][t]<<' '<<siz[bel[s][t]]<<endl;
					}
				}
				for (int x=i-k,y=j-k+1; y<=j; ++y) if (mp[x][y]==1 && !vis[bel[x][y]]) {vis[bel[x][y]]=1; sta[++top]=bel[x][y]; sum2+=siz[bel[x][y]];}
				for (int x=i+1,y=j-k+1; y<=j; ++y) if (mp[x][y]==1 && !vis[bel[x][y]]) {vis[bel[x][y]]=1; sta[++top]=bel[x][y]; sum2+=siz[bel[x][y]];}
				for (int x=i-k+1,y=j-k; x<=i; ++x) if (mp[x][y]==1 && !vis[bel[x][y]]) {vis[bel[x][y]]=1; sta[++top]=bel[x][y]; sum2+=siz[bel[x][y]];}
				for (int x=i-k+1,y=j+1; x<=i; ++x) if (mp[x][y]==1 && !vis[bel[x][y]]) {vis[bel[x][y]]=1; sta[++top]=bel[x][y]; sum2+=siz[bel[x][y]];}
				// cout<<"sum: "<<sum2<<' '<<qsum(i, j)<<endl;
				sum2+=qsum(i, j);
				ans=max(ans, sum2);
				while (top) vis[sta[top--]]=0;
			}
		}
		printf("%d\n", ans);
		exit(0);
	}
}

namespace task1{
	int sta[300010], top, ans, sum[N][N];
	bool vis[300010], vis2[300010];
	inline int qsum(int i, int j) {return sum[i][j]-sum[i-k][j]-sum[i][j-k]+sum[i-k][j-k];}
	void solve() {
		int lim=n*n;
		for (int i=1; i<=lim; ++i) fa[i]=i, siz[i]=1;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) id[i][j]=++tot;
		for (int i=1; i<=n; ++i)
			for (int j=1; j<=n; ++j) {
				if (mp[i][j]==1) {
					for (int k=0; k<4; ++k) {
						int x=i+dlt[k][0], y=j+dlt[k][1];
						if (mp[x][y]==1) {
							int f1=find(id[x][y]), f2=find(id[i][j]);
							if (f1!=f2) fa[f1]=f2, siz[f2]+=siz[f1];
						}
					}
				}
				else sum[i][j]=1;
			}
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) bel[i][j]=find(id[i][j]);
		// cout<<"---bel---"<<endl;
		// for (int i=1; i<=n; ++i) {for (int j=1; j<=n; ++j) cout<<bel[i][j]<<' '; cout<<endl;}
		// cout<<"siz: "; for (int i=1; i<=tot; ++i) cout<<siz[i]<<' '; cout<<endl;
		for (int i=1; i<=n; ++i) for (int j=1; j<=n; ++j) sum[i][j]+=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1];
		for (int j=k; j<=n; ++j) {
			int insum=0;
			memset(vis, 0, sizeof(vis));
			for (int s=1; s<=k; ++s) {
				for (int t=j-k+1; t<=j; ++t) if (mp[s][t]==1 && !vis[bel[s][t]]) {
					vis[bel[s][t]]=1;
					insum+=siz[bel[s][t]];
					// cout<<"add: "<<bel[s][t]<<' '<<siz[bel[s][t]]<<endl;
				}
			}
			for (int i=k; i<=n; ++i) {
				int outsum=0;
				for (int x=i-k,y=j-k+1; y<=j; ++y) if (mp[x][y]==1 && !vis2[bel[x][y]]) {
					if (vis[bel[x][y]]) vis[bel[x][y]]=0, insum-=siz[bel[x][y]];
					vis2[bel[x][y]]=1; sta[++top]=bel[x][y]; outsum+=siz[bel[x][y]];
				}
				for (int x=i+1,y=j-k+1; y<=j; ++y) if (mp[x][y]==1 && !vis2[bel[x][y]]) {
					if (vis[bel[x][y]]) vis[bel[x][y]]=0, insum-=siz[bel[x][y]];
					vis2[bel[x][y]]=1; sta[++top]=bel[x][y]; outsum+=siz[bel[x][y]];
				}
				for (int x=i-k+1,y=j-k; x<=i; ++x) if (mp[x][y]==1 && !vis2[bel[x][y]]) {
					if (vis[bel[x][y]]) vis[bel[x][y]]=0, insum-=siz[bel[x][y]];
					vis2[bel[x][y]]=1; sta[++top]=bel[x][y]; outsum+=siz[bel[x][y]];
				}
				for (int x=i-k+1,y=j+1; x<=i; ++x) if (mp[x][y]==1 && !vis2[bel[x][y]]) {
					if (vis[bel[x][y]]) vis[bel[x][y]]=0, insum-=siz[bel[x][y]];
					vis2[bel[x][y]]=1; sta[++top]=bel[x][y]; outsum+=siz[bel[x][y]];
				}
				ans=max(ans, insum+outsum+qsum(i, j));
				while (top) vis2[sta[top--]]=0;
				for (int x=i+1,y=j-k+1; y<=j; ++y) if (mp[x][y]==1 && !vis[bel[x][y]]) vis[bel[x][y]]=1, insum+=siz[bel[x][y]];
			}
		}
		printf("%d\n", ans);
		exit(0);
	}
}

signed main()
{
	// Input is read from grid.in and the answer is written to grid.out.
	freopen("grid.in", "r", stdin);
	freopen("grid.out", "w", stdout);

	scanf("%d%d", &n, &k);
	for (int row=1; row<=n; ++row) {
		scanf("%s", s+1);
		// '.' cells become 1, every other character becomes 2; mp is never
		// written outside rows/columns 1..n, so the border stays 0.
		for (int col=1; col<=n; ++col) {
			if (s[col]=='.') mp[row][col]=1;
			else mp[row][col]=2;
		}
	}
	// force::solve();   // brute-force cross-check
	task1::solve();      // prints the answer and exits

	return 0;
}
posted @ 2021-09-29 19:02  Administrator-09  阅读(5)  评论(0)  编辑  收藏  举报