[bzoj1057][ZJOI2007]棋盘制作
本题有多种解法,我选择了单调栈进行解答。
首先进行一个巧妙的问题转换:
将所有i+j为奇数的格子反转,
这样,问题就转换成了求一个最大的0/1子矩阵。
先考虑一维的情况,h[i]表示以i为终点的最长连续0的长度,有h[i]=a[i]==0? h[i-1]+1:0,这样可以O(n)轻松求出。
拓展到高维,首先同样按照一维的方法,h[i][j]表示第i行以j为终点的最长连续0的长度,预处理出h[]。。接下来考虑一列一列来更新答案,对于单独的一列i,若以h[i][j]为子矩阵的一个边长,则它能往上下扩展的最大长度len就是另一个边长,所谓扩展就是向一个方向扫描知道碰到h值比自己小位置。
如果暴力扩展,复杂度就是O(n3)不能满足要求。
实际上,我们想要扩展,只是想要得到这样一个信息:
对于行i,其分别能够向上/向下扩展多少行?
只要我们知道这个信息,结合h[i]就可以求出一个矩形用于更新答案了。
我们考察单调栈:
对于一个新元素,如果比栈顶元素h小,说明这个元素就是扩展的下界,马上弹出栈顶元素进行计算即可。
当一个新元素插入栈时,容易知道,新元素一定比栈顶元素大,那么栈顶元素就是他的上界,这样我们就得到了上界。
具体实现上,我直接使用了stl。
下面上代码。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2005;
int n, m;
struct node {
int x, y, left;
const bool operator > (const node b) {
return this->left > b.left;
}
};
stack<node> s;
int c[maxn][maxn], x[maxn][maxn], y[maxn][maxn], up[maxn];
int ans1 = 0, ans2 = 0;
void solve1() {
for(int j = 1; j <= m; j++) {
memset(up, -1, sizeof(up));
for(int i = 1; i <= n; i++) {
node a = (node){i, j, x[i][j]};
node b =(node){0, j, 0};
if(!s.empty()){
b = s.top();
while(b > a) {
s.pop();
ans2 = max(ans2, (a.x - up[b.x]) * b.left);
ans1 = max(ans1, min(b.left, a.x - up[b.x]) * min(b.left, a.x-up[b.x]));
if(s.empty()){
b = (node){0, j, 0};
break;
}
b = s.top();
}
}
up[a.x] = b.x+1;
s.push(a);
}
while(!s.empty()) {
node a = s.top();
s.pop();
ans2 = max(ans2, (n+1 - up[a.x]) * a.left);
ans1 = max(ans1, min(a.left, n - up[a.x] + 1) * min(a.left, n-up[a.x]+1));
}
}
}
void solve2() {
for(int j = 1; j <= m; j++) {
memset(up, -1, sizeof(up));
for(int i = 1; i <= n; i++) {
node a = (node){i, j, y[i][j]};
node b =(node){0, j, 0};
if(!s.empty()){
b = s.top();
while(b > a) {
s.pop();
ans2 = max(ans2, (a.x - up[b.x]) * b.left);
ans1 = max(ans1, min(b.left, a.x - up[b.x]) * min(b.left, a.x-up[b.x]));
if(s.empty()){
b = (node){0, j, 0};
break;
}
b = s.top();
}
}
up[a.x] = b.x+1;
s.push(a);
}
while(!s.empty()) {
node a = s.top();
s.pop();
ans2 = max(ans2, (n - up[a.x] + 1) * a.left);
ans1 = max(ans1, (min(a.left, (n-up[a.x]+1)) * min(a.left, (n-up[a.x]+1))));
}
}
}
int main() {
//freopen("input", "r", stdin);
scanf("%d %d", &n, &m);
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= m; j++) {
scanf("%d", &c[i][j]);
if((i+j)%2==1) {
c[i][j] ^= 1;
}
x[i][j] = c[i][j] == 1?x[i][j-1]+1:0;
y[i][j] = c[i][j] == 0?y[i][j-1]+1:0;
}
}
solve1();
solve2();
printf("%d\n%d", ans1, ans2);
}