P3625 [APIO2009]采油区域 题解
这道题是一道很好的二位前缀和问题。
然而码量有点大。
下面规定 \(n\) 表示行,\(m\) 表示列,\(n,m\) 同阶。
即计算复杂度的时候视 \(O(nm)\) 为 \(O(n^2)\)。
首先预处理 \(sum_{i,j}\) 表示从 \((1,1)\) 到 \((i,j)\) 的和,也就是二维前缀和,这里略过。
考虑将一个矩阵分成 3 块无非 6 种情况,如下:
上面 6 种情况又分为 2 类:
第一类是前面两种全横排与全竖排。
对于这一类情况,我们需要 \(O(n^2)\) 枚举两个分割行,然后 \(O(1)\) 求出答案。
因此这里引入两个辅助数组 \(Line_{i,j},Col_{i,j}\)。
- \(Line_{i,j}\) 表示第 \(i\) 行到第 \(j\) 行之间的最大 \(k \times k\) 子矩阵元素和。
- \(Col_{i,j}\) 表示第 \(i\) 列到第 \(j\) 列之间的最大 \(k \times k\) 子矩阵元素和。
然后考虑如何预处理出这两个数组。
下面以 \(Line_{i,j}\) 为例。
首先,对于形如 \(Line_{i,i+k-1}\) 的结果,由于总共只有 \(k\) 行,此时我们可以枚举这 \(k\) 行里面的子矩阵,至多只有 \(m-k+1\) 个。
这部分复杂度是 \(O(n^2)\)。
然后对于任意的 \(Line_{i,j}\),如果 \(j-i+1<k\),说明此时行数非法,空着就好(无贡献),否则采用区间 DP 方式转移:\(Line_{i,j}=\max(Line_{i+1,j},Line_{i,j-1})\)。
处理完 \(Line,Col\) 之后,我们就可以 \(O(n^2)\) 枚举,\(O(1)\) 算贡献了。
第二类是后面的四种情况,图我搬过来了:
细心观察上图可以发现,这一类情况的图中都有一个交点(即图中的红点)。
因此我们考虑 \(O(n^2)\) 枚举这个交点。
但是这样我们依然要 \(O(1)\) 算贡献。
因此我们除了前面的 \(Line,Col\),还要引入一个辅助数组:
\(f_{i,j,0/1/2/3}\) 表示以 \((i,j)\) 为分界点,左上/右上/左下/右下 的最大 \(k \times k\) 子矩阵元素和。包括 \((i,j)\) 这个点。
也就是向下面这张图:
接下来以预处理 \(f_{i,j,0}\) 为例:
还是看图。
假设右下角的点为 \((i,j)\),我们要算 \(f_{i,j,0}\)。
这个矩形被我分成了 3 类:
- 绿色的是 \(f_{i,j-1,0}\)。
- 蓝色的是 \(f_{i-1,j,0}\)。
- 黄色的是 \((i-k+1,j-k+1)\) 到 \((i,j)\) 这一块矩形的和。
显然答案只能是上述 3 者的最大值,于是我们可以 \(O(n^2)\) 处理完 \(f_{i,j,0}\)。
别的同理。
综上,对于第二类情况,我们就可以 \(O(n^2)\) 枚举交点,\(O(1)\) 计算答案了。
最后的时间复杂度是 \(O(n^2)\)。
几个小细节:
- 注意在计算 \(f_{i,j,0/1/2/3},Line_{i,j},Col_{i,j}\) 的时候是否合法的判断。
- 在转移的时候要注意分割线不能算两次。
比如说下面这样:
此时的正确答案计算式的其中一种为 \(f_{i,j,0}+f_{i,j+1,1}+Line_{i+1,n}\)。
千万注意分界线不能算两次。
- 注意 \(n,m\) 不能弄错。
- 在计算二维前缀和的时候注意不能出现左上角+右下角与左下角+右上角计算两种情况混着用。
换句话说,假设当前矩阵左上,右上,左下,右下分别为 \((2,2),(2,4),(4,2),(4,4)\),不能出现传进去 \((4,2),(2,4)\) 并且将其当成矩形左上角与右下角 的错误。 - 不能开
long long
,否则会 MLE。
至于答案手算一下就会发现实际上不会炸int
。
代码:
/*
========= Plozia =========
Author:Plozia
Problem:P3625 [APIO2009]采油区域
Date:2021/5/14
========= Plozia =========
*/
#include <bits/stdc++.h>
typedef long long LL;
const int MAXN = 1500 + 10;
int n, m, k, a[MAXN][MAXN], sum[MAXN][MAXN], f[MAXN][MAXN][4], Line[MAXN][MAXN], Col[MAXN][MAXN];
int read()
{
int sum = 0, fh = 1; char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
return sum * fh;
}
int Max(int fir, int sec) { return (fir > sec) ? fir : sec; }
int Min(int fir, int sec) { return (fir < sec) ? fir : sec; }
int Get(int r1, int c1, int r2, int c2)
{
if (r1 > r2) std::swap(r1, r2); if (c1 > c2) std::swap(c1, c2);
return sum[r2][c2] - sum[r2][c1 - 1] - sum[r1 - 1][c2] + sum[r1 - 1][c1 - 1];
}
void init()
{
f[k][k][0] = Get(1, 1, k, k);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
{
if (i > k) { f[i][j][0] = Max(f[i][j][0], f[i - 1][j][0]); }
if (j > k) { f[i][j][0] = Max(f[i][j][0], f[i][j - 1][0]); }
if (i - k + 1 > 0 && j - k + 1 > 0) { f[i][j][0] = Max(f[i][j][0], Get(i - k + 1, j - k + 1, i, j)); }
}
f[k][m - k + 1][1] = Get(1, m - k + 1, k, m);
for (int i = 1; i <= n; ++i)
for (int j = m; j >= 1; --j)
{
if (i > k) { f[i][j][1] = Max(f[i][j][1], f[i - 1][j][1]); }
if (j < m - k + 1) { f[i][j][1] = Max(f[i][j][1], f[i][j + 1][1]); }
if (i - k + 1 > 0 && j + k - 1 <= m) { f[i][j][1] = Max(f[i][j][1], Get(i - k + 1, j, i, j + k - 1)); }
}
f[n - k + 1][k][2] = Get(n - k + 1, 1, n, k);
for (int i = n; i >= 1; --i)
for (int j = 1; j <= m; ++j)
{
if (i < n - k + 1) { f[i][j][2] = Max(f[i][j][2], f[i + 1][j][2]); }
if (j > k) { f[i][j][2] = Max(f[i][j][2], f[i][j - 1][2]); }
if (i + k - 1 <= n && j - k + 1 > 0) { f[i][j][2] = Max(f[i][j][2], Get(i, j - k + 1, i + k - 1, j)); }
}
f[n - k + 1][m - k + 1][3] = Get(n - k + 1, m - k + 1, n, m);
for (int i = n; i >= 1; --i)
for (int j = m; j >= 1; --j)
{
if (i < n - k + 1) { f[i][j][3] = Max(f[i][j][3], f[i + 1][j][3]); }
if (j < m - k + 1) { f[i][j][3] = Max(f[i][j][3], f[i][j + 1][3]); }
if (i + k - 1 <= n && j + k - 1 <= m) { f[i][j][3] = Max(f[i][j][3], Get(i, j, i + k - 1, j + k - 1)); }
}
}
int main()
{
n = read(), m = read(), k = read();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
a[i][j] = read();
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
sum[i][j] = sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1] + a[i][j];
init();
for (int i = 1; i + k - 1 <= n; ++i)
{
for (int j = 1; j <= m; ++j)
{
if (j >= k) { Line[i][i + k - 1] = Max(Line[i][i + k - 1], Get(i, j - k + 1, i + k - 1, j)); }
if (j + k - 1 <= m) { Line[i][i + k - 1] = Max(Line[i][i + k - 1], Get(i, j, i + k - 1, j + k - 1)); }
}
}
for (int len = k + 1; len <= n; ++len)
{
for (int i = 1; i <= n; ++i)
{
int j = i + len - 1; if (j > n) break ;
Line[i][j] = Max(Line[i + 1][j], Line[i][j - 1]);
}
}
for (int j = 1; j + k - 1 <= m; ++j)
{
for (int i = 1; i <= n; ++i)
{
if (i >= k) { Col[j][j + k - 1] = Max(Col[j][j + k - 1], Get(i - k + 1, j, i, j + k - 1)); }
if (i + k - 1 <= n) { Col[j][j + k - 1] = Max(Col[j][j + k - 1], Get(i, j, i + k - 1, j + k - 1)); }
}
}
for (int len = k + 1; len <= m; ++len)
{
for (int i = 1; i <= m; ++i)
{
int j = i + len - 1; if (j > m) break ;
Col[i][j] = Max(Col[i + 1][j], Col[i][j - 1]);
}
}
int ans = 0;
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= m; ++j)
{
ans = Max(ans, Col[1][j - 1] + f[i][j][1] + f[i + 1][j][3]);
ans = Max(ans, Col[j + 1][m] + f[i][j][0] + f[i + 1][j][2]);
ans = Max(ans, f[i][j][0] + f[i][j + 1][1] + Line[i + 1][n]);
ans = Max(ans, f[i][j][2] + f[i][j + 1][3] + Line[1][i - 1]);
}
for (int i = k; i <= n - k + 1; ++i)
for (int j = i + k - 1; j <= n - k + 1; ++j)
ans = Max(ans, Line[1][i] + Line[i + 1][j - 1] + Line[j][n]);
for (int i = k; i <= m - k + 1; ++i)
for (int j = i + k - 1; j <= m - k + 1; ++j)
ans = Max(ans, Col[1][i] + Col[i + 1][j - 1] + Col[j][m]);
printf("%d\n", ans); return 0;
}