最大矩阵区间 题解

题意简述

给定 nm 列矩阵 A。对于每一行 i,选择非空区间 [li,ri],满足 i[1,n)[li,ri][li+1,ri+1] 相交,即 max{li,li+1}min{ri,ri+1}。求所有选出区间的 Ai,j 值之和的最大值,即 maxi=1nj=liriAi,j

题目分析

一眼 DP。有一个 naive 的想法是记 fi,l,r 表示前 i 行,第 i 行选出了 [l,r] 的价值和,转移可以优化到 Θ(nm2),只列出转移方程,不展开。记 sumi,j=k=1jAi,k

初始:f1,l,r=sum1,rsum1,l1

g1j 表示 i1 行,左端点是 j 的最大值,即 g1j=maxo=jmfi1,j,o

类似 g2j 表示右端点是 j,即 g2j=maxo=1jfi1,o,j

g3l,r 表示包含 [l,r] 的最大值,即 g3l,r=maxp=1lmaxq=rmfi1,p,q。可以从大区间向小区间递推:g3l,r=max{fi1,l,r,g3l1,r,g3l,r+1,g3l1,r+1}

fi,l,r 的转移:

fi,l,r=max{maxo=lrg1o,maxo=lrg2o,g3l,r}+sumi,rsumi,l1

上述 DP 的瓶颈显然在记录 [l,r] 上了。有一个 naive 的想法,我们只记一个端点 j,至于是左右端点我们不关心,因为我们发现似乎上一个区间的某一个端点在我们这个区间里,这两个区间一定是相交的。但是前者是后者充分不必要条件,考虑 i1l,r,如果有 l<lr<r,这种完全被包含的情况,我们没考虑到。但是事实证明,应该能骗到好多分。转移方程为:

fi,j=max{maxo=1j{maxk=ojfi1,k+k=ojAi,k},maxo=jm{maxk=jofi1,k+k=joAi,k}}=max{maxo=1j{maxk=ojfi1,k+sumi,jsumi,o1},maxo=jm{maxk=jofi1,k+sumi,osumi,j1}}=max{maxo=1j{maxk=ojfi1,ksumi,o1}+sumi,j,maxo=jm{maxk=jofi1,k+sumi,o}sumi,j1}

不妨对前一半讨论:maxo=1j{maxk=ojfi1,ksumi,o1}+sumi,jsumi,j 为定值,考虑 j 从小到大扫,对于一个新的 j,不会影响 sumi,o1,但是会更新一些 maxk=ojfi1,k,单调栈即可。数据结构用线段树维护。

我们发现只记端点来刻画一个状态还不够,那么就把 fi,j 的定义改为第 i 行选出的区间包含 j 的最大值,这样,两个区间相交变成了有一个公共点被选中。

类似地,对于 fi,j,我们不妨考虑其和 i1 行的区间交在 o。转移时,除了 oj 必选,我们也会贪心向外扩展一部分,记 preisufi 分别表示以 i 为左 / 右端点的最长连续子段和,有转移:

fi,j=max{maxo=1j{fi1,o+preo1+sufj+1+sumi,jsumi,o1},maxo=jm{fi1,o+sufo+1+prej1+sumi,osumi,j1}}=max{maxo=1j{fi1,o+preo1sumi,o1}+sufj+1+sumi,j,maxo=jm{fi1,o+sufo+1+sumi,o}+prej1sumi,j1}

扫一扫即可。

时间复杂度:Θ(nm)

代码

点击查看水到 90 分的代码
#include <cstdio>
#include <iostream>
using namespace std;
using lint = long long;
const lint inf = 0x3fffffffffffffff;
struct Array {
lint buf[2000010];
lint *val[1000010];
void init(int n, int m) {
val[0] = buf;
for (int i = 1; i <= n; ++i)
val[i] = val[i - 1] + m + 1;
}
lint * operator [] (const int x) {
return val[x];
}
} sum, f;
int n, m;
namespace yzh520 {
lint f[110][110][110]; // i, l ~ r
lint g1[110], g2[110], g3[110][110];
void solve() {
for (int l = 1; l <= m; ++l)
for (int r = l; r <= m; ++r) {
f[1][l][r] = sum[1][r] - sum[1][l - 1];
}
for (int i = 2; i <= n; ++i) {
for (int l = 1; l <= m; ++l) {
g1[l] = -inf;
for (int r = l; r <= m; ++r)
g1[l] = max(g1[l], f[i - 1][l][r]);
}
for (int r = 1; r <= m; ++r) {
g2[r] = -inf;
for (int l = r; l >= 1; --l)
g2[r] = max(g2[r], f[i - 1][l][r]);
}
g3[0][m + 1] = -inf;
g3[0][m] = -inf;
g3[1][m + 1] = -inf;
for (int l = 1; l <= m; ++l)
for (int r = m; r >= l; --r) {
g3[l][r] = f[i - 1][l][r];
g3[l][r] = max(g3[l][r], g3[l - 1][r]);
g3[l][r] = max(g3[l][r], g3[l][r + 1]);
g3[l][r] = max(g3[l][r], g3[l - 1][r + 1]);
}
for (int l = 1; l <= m; ++l) {
lint mx = -inf;
for (int r = l; r <= m; ++r) {
mx = max(mx, g1[r]);
mx = max(mx, g2[r]);
f[i][l][r] = g3[l][r];
f[i][l][r] = max(f[i][l][r], mx);
f[i][l][r] += sum[i][r] - sum[i][l - 1];
}
}
}
lint ans = -inf;
for (int l = 1; l <= m; ++l)
for (int r = l; r <= m; ++r)
ans = max(ans, f[n][l][r]);
printf("%lld\n", ans);
}
}
struct Segment_Tree {
#define lson (idx << 1 )
#define rson (idx << 1 | 1)
struct node {
int l, r;
lint mx, lazy;
} tree[1000010 << 2];
inline void pushup(int idx) {
tree[idx].mx = max(tree[lson].mx, tree[rson].mx);
}
void build(int idx, int l, int r, int i, bool tag) {
tree[idx] = {l, r, 0, 0};
if (l == r) {
if (tag)
tree[idx].mx = sum[i][l];
else
tree[idx].mx = -sum[i][l - 1];
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid, i, tag);
build(rson, mid + 1, r, i, tag);
pushup(idx);
}
inline void pushtag(int idx, lint v) {
tree[idx].mx += v, tree[idx].lazy += v;
}
inline void pushdown(int idx) {
if (!tree[idx].lazy) return;
pushtag(lson, tree[idx].lazy);
pushtag(rson, tree[idx].lazy);
tree[idx].lazy = 0;
}
void modify(int idx, int l, int r, lint v) {
if (tree[idx].l > r || tree[idx].r < l) return;
if (l <= tree[idx].l && tree[idx].r <= r) return pushtag(idx, v);
pushdown(idx);
modify(lson, l, r, v);
modify(rson, l, r, v);
pushup(idx);
}
lint query(int idx, int l, int r) {
if (tree[idx].l > r || tree[idx].r < l) return -inf;
if (l <= tree[idx].l && tree[idx].r <= r) return tree[idx].mx;
pushdown(idx);
return max(query(lson, l, r), query(rson, l, r));
}
#undef lson
#undef rson
} yzh;
int stack[1000010], top;
signed main() {
scanf("%d%d", &n, &m);
sum.init(n, m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%lld", &sum[i][j]);
sum[i][j] += sum[i][j - 1];
}
}
#ifndef XuYueming
if (n * m <= 100) return yzh520::solve(), 0;
#endif
f.init(n, m);
for (int i = 1; i <= n; ++i) {
yzh.build(1, 1, m, i, false);
top = 0, stack[0] = 0;
for (int j = 1; j <= m; ++j)
f[i][j] = -inf;
for (int j = 1; j <= m; ++j) {
while (top && f[i - 1][stack[top]] <= f[i - 1][j]) {
yzh.modify(1, stack[top - 1] + 1, stack[top], -f[i - 1][stack[top]]);
--top;
}
yzh.modify(1, stack[top] + 1, j, f[i - 1][j]);
stack[++top] = j;
f[i][j] = max(f[i][j], yzh.query(1, 1, j) + sum[i][j]);
}
yzh.build(1, 1, m, i, true);
top = 0, stack[0] = m + 1;
for (int j = m; j >= 1; --j) {
while (top && f[i - 1][stack[top]] <= f[i - 1][j]) {
yzh.modify(1, stack[top], stack[top - 1] - 1, -f[i - 1][stack[top]]);
--top;
}
yzh.modify(1, j, stack[top] - 1, f[i - 1][j]);
stack[++top] = j;
f[i][j] = max(f[i][j], yzh.query(1, j, m) - sum[i][j - 1]);
}
}
lint ans = -inf;
for (int i = 1; i <= m; ++i)
ans = max(ans, f[n][i]);
printf("%lld", ans);
return 0;
}

当然还有正解,非常短,对着推出来的式子很好理解:

#include <cstdio>
using lint = long long;
const lint inf = 0x3f3f3f3f3f3f3f3f;
const int MAX = 1 << 26;
char ibuf[MAX], *p = ibuf;
#define getchar() *p++
#define isdigit(c) ('0' <= (c) && (c) <= '9')
inline void read(int &x) {
x = 0; char ch = getchar(), f = 0;
for (; !isdigit(ch); ch = getchar()) f |= ch == '-';
for (; isdigit(ch); ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ 48);
f && (x = -x);
}
inline lint max(lint a, lint b) { return a > b ? a : b;}
template <typename T>
inline void swap(T* (&a), T* (&b)) { T* t = a; a = b; b = t; }
const int N = 1000010;
int n, m, val[N];
lint buf[N << 1], *f = buf, *g = buf + N;
lint sum[N], pre[N], suf[N];
lint mxp[N], mxs[N];
signed main() {
fread(ibuf, 1, MAX, stdin);
read(n), read(m);
mxp[0] = mxs[m + 1] = -inf;
for (int i = 1; i <= n; ++i) {
swap(f, g);
for (int j = 1; j <= m; ++j) {
read(val[j]);
sum[j] = sum[j - 1] + val[j];
pre[j] = max(0, pre[j - 1] + val[j]);
mxp[j] = max(mxp[j - 1], g[j] + pre[j - 1] - sum[j - 1]);
}
for (int j = m; j >= 1; --j) {
suf[j] = max(0, suf[j + 1] + val[j]);
mxs[j] = max(mxs[j + 1], g[j] + suf[j + 1] + sum[j]);
f[j] = max(
mxp[j] + suf[j + 1] + sum[j],
mxs[j] + pre[j - 1] - sum[j - 1]
);
}
}
lint ans = -inf;
for (int i = 1; i <= m; ++i)
ans = max(ans, f[i]);
printf("%lld", ans);
return 0;
}
posted @   XuYueming  阅读(24)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 地球OL攻略 —— 某应届生求职总结
· 周边上新:园子的第一款马克杯温暖上架
· Open-Sora 2.0 重磅开源!
· 提示词工程——AI应用必不可少的技术
· .NET周刊【3月第1期 2025-03-02】
点击右上角即可分享
微信分享提示