最大矩阵区间 题解
题意简述
给定 \(n\) 行 \(m\) 列矩阵 \(A\)。对于每一行 \(i\),选择非空区间 \([l_i, r_i]\),满足 \(\forall i \in [1, n)\),\([l_i, r_i]\) 和 \([l_{i + 1}, r_{i + 1}]\) 相交,即 \(\max \{ l_i, l_{i + 1} \} \leq \min \{ r_i, r_{i+1} \}\)。求所有选出区间的 \(A_{i, j}\) 值之和的最大值,即 \(\max \sum \limits _ {i = 1} ^ n \sum \limits _ {j = l_i} ^ {r_i} A_{i, j}\)。
题目分析
一眼 DP。有一个 naive 的想法是记 \(f_{i, l, r}\) 表示前 \(i\) 行,第 \(i\) 行选出了 \([l, r]\) 的价值和,转移可以优化到 \(\Theta(nm^2)\),只列出转移方程,不展开。记 \(sum_{i, j} = \sum \limits _ {k = 1} ^ {j} A_{i, k}\)。
初始:\(f_{1, l, r} = sum_{1, r} - sum_{1, l - 1}\)。
\(g1_{j}\) 表示 \(i - 1\) 行,左端点是 \(j\) 的最大值,即 \(g1_{j} = \max \limits _ {o = j} ^ m f_{i - 1, j, o}\)。
类似 \(g2_{j}\) 表示右端点是 \(j\),即 \(g2_{j} = \max \limits _ {o = 1} ^ j f_{i - 1, o, j}\)。
\(g3_{l, r}\) 表示包含 \([l, r]\) 的最大值,即 \(g3_{l, r} = \max \limits _ {p = 1} ^ l \max \limits _ {q = r} ^ m f_{i - 1, p, q}\)。可以从大区间向小区间递推:\(g3_{l, r} = \max \{ f_{i - 1, l, r}, g3_{l - 1, r}, g3_{l, r + 1}, g3_{l - 1, r + 1} \}\)。
\(f_{i, l, r}\) 的转移:
上述 DP 的瓶颈显然在记录 \([l, r]\) 上了。有一个 naive 的想法,我们只记一个端点 \(j\),至于是左右端点我们不关心,因为我们发现似乎上一个区间的某一个端点在我们这个区间里,这两个区间一定是相交的。但是前者是后者充分不必要条件,考虑 \(i - 1\) 的 \(l', r'\),如果有 \(l' \lt l \leq r \lt r'\),这种完全被包含的情况,我们没考虑到。但是事实证明,应该能骗到好多分。转移方程为:
不妨对前一半讨论:\(\max \limits _ {o = 1} ^ {j} \Big \{ \max \limits _ {k = o} ^ {j} f_{i - 1, k} - sum_{i, o - 1} \Big \} + sum_{i, j}\)。\(sum_{i, j}\) 为定值,考虑 \(j\) 从小到大扫,对于一个新的 \(j\),不会影响 \(sum_{i, o - 1}\),但是会更新一些 \(\max \limits _ {k = o} ^ {j} f_{i - 1, k}\),单调栈即可。数据结构用线段树维护。
我们发现只记端点来刻画一个状态还不够,那么就把 \(f_{i, j}\) 的定义改为第 \(i\) 行选出的区间包含 \(j\) 的最大值,这样,两个区间相交变成了有一个公共点被选中。
类似地,对于 \(f_{i, j}\),我们不妨考虑其和 \(i - 1\) 行的区间交在 \(o\)。转移时,除了 \(o \sim j\) 必选,我们也会贪心向外扩展一部分,记 \(pre_i\),\(suf_i\) 分别表示以 \(i\) 为左 / 右端点的最长连续子段和,有转移:
扫一扫即可。
时间复杂度:\(\Theta(nm)\)。
代码
点击查看水到 $90$ 分的代码
#include <cstdio>
#include <iostream>
using namespace std;
using lint = long long;
const lint inf = 0x3fffffffffffffff;
struct Array {
lint buf[2000010];
lint *val[1000010];
void init(int n, int m) {
val[0] = buf;
for (int i = 1; i <= n; ++i)
val[i] = val[i - 1] + m + 1;
}
lint * operator [] (const int x) {
return val[x];
}
} sum, f;
int n, m;
namespace yzh520 {
lint f[110][110][110]; // i, l ~ r
lint g1[110], g2[110], g3[110][110];
void solve() {
for (int l = 1; l <= m; ++l)
for (int r = l; r <= m; ++r) {
f[1][l][r] = sum[1][r] - sum[1][l - 1];
}
for (int i = 2; i <= n; ++i) {
for (int l = 1; l <= m; ++l) {
g1[l] = -inf;
for (int r = l; r <= m; ++r)
g1[l] = max(g1[l], f[i - 1][l][r]);
}
for (int r = 1; r <= m; ++r) {
g2[r] = -inf;
for (int l = r; l >= 1; --l)
g2[r] = max(g2[r], f[i - 1][l][r]);
}
g3[0][m + 1] = -inf;
g3[0][m] = -inf;
g3[1][m + 1] = -inf;
for (int l = 1; l <= m; ++l)
for (int r = m; r >= l; --r) {
g3[l][r] = f[i - 1][l][r];
g3[l][r] = max(g3[l][r], g3[l - 1][r]);
g3[l][r] = max(g3[l][r], g3[l][r + 1]);
g3[l][r] = max(g3[l][r], g3[l - 1][r + 1]);
}
for (int l = 1; l <= m; ++l) {
lint mx = -inf;
for (int r = l; r <= m; ++r) {
mx = max(mx, g1[r]);
mx = max(mx, g2[r]);
f[i][l][r] = g3[l][r];
f[i][l][r] = max(f[i][l][r], mx);
f[i][l][r] += sum[i][r] - sum[i][l - 1];
}
}
}
lint ans = -inf;
for (int l = 1; l <= m; ++l)
for (int r = l; r <= m; ++r)
ans = max(ans, f[n][l][r]);
printf("%lld\n", ans);
}
}
struct Segment_Tree {
#define lson (idx << 1 )
#define rson (idx << 1 | 1)
struct node {
int l, r;
lint mx, lazy;
} tree[1000010 << 2];
inline void pushup(int idx) {
tree[idx].mx = max(tree[lson].mx, tree[rson].mx);
}
void build(int idx, int l, int r, int i, bool tag) {
tree[idx] = {l, r, 0, 0};
if (l == r) {
if (tag)
tree[idx].mx = sum[i][l];
else
tree[idx].mx = -sum[i][l - 1];
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid, i, tag);
build(rson, mid + 1, r, i, tag);
pushup(idx);
}
inline void pushtag(int idx, lint v) {
tree[idx].mx += v, tree[idx].lazy += v;
}
inline void pushdown(int idx) {
if (!tree[idx].lazy) return;
pushtag(lson, tree[idx].lazy);
pushtag(rson, tree[idx].lazy);
tree[idx].lazy = 0;
}
void modify(int idx, int l, int r, lint v) {
if (tree[idx].l > r || tree[idx].r < l) return;
if (l <= tree[idx].l && tree[idx].r <= r) return pushtag(idx, v);
pushdown(idx);
modify(lson, l, r, v);
modify(rson, l, r, v);
pushup(idx);
}
lint query(int idx, int l, int r) {
if (tree[idx].l > r || tree[idx].r < l) return -inf;
if (l <= tree[idx].l && tree[idx].r <= r) return tree[idx].mx;
pushdown(idx);
return max(query(lson, l, r), query(rson, l, r));
}
#undef lson
#undef rson
} yzh;
int stack[1000010], top;
signed main() {
scanf("%d%d", &n, &m);
sum.init(n, m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%lld", &sum[i][j]);
sum[i][j] += sum[i][j - 1];
}
}
#ifndef XuYueming
if (n * m <= 100) return yzh520::solve(), 0;
#endif
f.init(n, m);
for (int i = 1; i <= n; ++i) {
yzh.build(1, 1, m, i, false);
top = 0, stack[0] = 0;
for (int j = 1; j <= m; ++j)
f[i][j] = -inf;
for (int j = 1; j <= m; ++j) {
while (top && f[i - 1][stack[top]] <= f[i - 1][j]) {
yzh.modify(1, stack[top - 1] + 1, stack[top], -f[i - 1][stack[top]]);
--top;
}
yzh.modify(1, stack[top] + 1, j, f[i - 1][j]);
stack[++top] = j;
f[i][j] = max(f[i][j], yzh.query(1, 1, j) + sum[i][j]);
}
yzh.build(1, 1, m, i, true);
top = 0, stack[0] = m + 1;
for (int j = m; j >= 1; --j) {
while (top && f[i - 1][stack[top]] <= f[i - 1][j]) {
yzh.modify(1, stack[top], stack[top - 1] - 1, -f[i - 1][stack[top]]);
--top;
}
yzh.modify(1, j, stack[top] - 1, f[i - 1][j]);
stack[++top] = j;
f[i][j] = max(f[i][j], yzh.query(1, j, m) - sum[i][j - 1]);
}
}
lint ans = -inf;
for (int i = 1; i <= m; ++i)
ans = max(ans, f[n][i]);
printf("%lld", ans);
return 0;
}
当然还有正解,非常短,对着推出来的式子很好理解:
#include <cstdio>
using lint = long long;
const lint inf = 0x3f3f3f3f3f3f3f3f;
const int MAX = 1 << 26;
char ibuf[MAX], *p = ibuf;
#define getchar() *p++
#define isdigit(c) ('0' <= (c) && (c) <= '9')
inline void read(int &x) {
x = 0; char ch = getchar(), f = 0;
for (; !isdigit(ch); ch = getchar()) f |= ch == '-';
for (; isdigit(ch); ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ 48);
f && (x = -x);
}
inline lint max(lint a, lint b) { return a > b ? a : b;}
template <typename T>
inline void swap(T* (&a), T* (&b)) { T* t = a; a = b; b = t; }
const int N = 1000010;
int n, m, val[N];
lint buf[N << 1], *f = buf, *g = buf + N;
lint sum[N], pre[N], suf[N];
lint mxp[N], mxs[N];
signed main() {
fread(ibuf, 1, MAX, stdin);
read(n), read(m);
mxp[0] = mxs[m + 1] = -inf;
for (int i = 1; i <= n; ++i) {
swap(f, g);
for (int j = 1; j <= m; ++j) {
read(val[j]);
sum[j] = sum[j - 1] + val[j];
pre[j] = max(0, pre[j - 1] + val[j]);
mxp[j] = max(mxp[j - 1], g[j] + pre[j - 1] - sum[j - 1]);
}
for (int j = m; j >= 1; --j) {
suf[j] = max(0, suf[j + 1] + val[j]);
mxs[j] = max(mxs[j + 1], g[j] + suf[j + 1] + sum[j]);
f[j] = max(
mxp[j] + suf[j + 1] + sum[j],
mxs[j] + pre[j - 1] - sum[j - 1]
);
}
}
lint ans = -inf;
for (int i = 1; i <= m; ++i)
ans = max(ans, f[i]);
printf("%lld", ans);
return 0;
}
本文作者:XuYueming,转载请注明原文链接:https://www.cnblogs.com/XuYueming/p/18381604。
若未作特殊说明,本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可。