最大矩阵区间 题解

题意简述

给定 \(n\) 行 \(m\) 列的矩阵 \(A\)。对于每一行 \(i\),选择非空区间 \([l_i, r_i]\),满足 \(\forall i \in [1, n)\),\([l_i, r_i]\) 与 \([l_{i + 1}, r_{i + 1}]\) 相交,即 \(\max \{ l_i, l_{i + 1} \} \leq \min \{ r_i, r_{i+1} \}\)。求所有选出区间的 \(A_{i, j}\) 值之和的最大值,即 \(\max \sum \limits _ {i = 1} ^ n \sum \limits _ {j = l_i} ^ {r_i} A_{i, j}\)。

题目分析

一眼 DP。有一个 naive 的想法是记 \(f_{i, l, r}\) 表示前 \(i\) 行,第 \(i\) 行选出了 \([l, r]\) 的价值和,转移可以优化到 \(\Theta(nm^2)\),只列出转移方程,不展开。记 \(sum_{i, j} = \sum \limits _ {k = 1} ^ {j} A_{i, k}\)

初始:\(f_{1, l, r} = sum_{1, r} - sum_{1, l - 1}\)

\(g1_{j}\) 表示 \(i - 1\) 行,左端点是 \(j\) 的最大值,即 \(g1_{j} = \max \limits _ {o = j} ^ m f_{i - 1, j, o}\)

类似 \(g2_{j}\) 表示右端点是 \(j\),即 \(g2_{j} = \max \limits _ {o = 1} ^ j f_{i - 1, o, j}\)

\(g3_{l, r}\) 表示包含 \([l, r]\) 的最大值,即 \(g3_{l, r} = \max \limits _ {p = 1} ^ l \max \limits _ {q = r} ^ m f_{i - 1, p, q}\)。可以从大区间向小区间递推:\(g3_{l, r} = \max \{ f_{i - 1, l, r}, g3_{l - 1, r}, g3_{l, r + 1}, g3_{l - 1, r + 1} \}\)

\(f_{i, l, r}\) 的转移:

\[f_{i, l, r} = \max \{ \max \limits _ {o = l} ^ r g1_o, \max \limits _ {o = l} ^ r g2_o, g3_{l, r} \} + sum_{i, r} - sum_{i, l - 1} \]

上述 DP 的瓶颈显然在记录 \([l, r]\) 上了。有一个 naive 的想法,我们只记一个端点 \(j\),至于是左右端点我们不关心,因为我们发现似乎上一个区间的某一个端点在我们这个区间里,这两个区间一定是相交的。但是前者只是后者的充分不必要条件:考虑第 \(i - 1\) 行的区间 \([l', r']\),如果有 \(l' \lt l \leq r \lt r'\),即当前区间 \([l, r]\) 被上一行的区间完全包含的情况,我们没考虑到。但是事实证明,应该能骗到好多分。转移方程为:

\[\begin{aligned} f_{i, j} &= \max \Bigg \{ \max \limits _ {o = 1} ^ {j} \Big \{ \max _ {k = o} ^ {j} f_{i - 1, k} + \sum _ {k = o} ^ j A_{i, k} \Big \}, \max \limits _ {o = j} ^ {m} \Big \{ \max _ {k = j} ^ {o} f_{i - 1, k} + \sum _ {k = j} ^ o A_{i, k} \Big \} \Bigg \} \\ &= \max \Bigg \{ \max \limits _ {o = 1} ^ {j} \Big \{ \max _ {k = o} ^ {j} f_{i - 1, k} + sum_{i, j} - sum_{i, o - 1} \Big \}, \max \limits _ {o = j} ^ {m} \Big \{ \max _ {k = j} ^ {o} f_{i - 1, k} + sum_{i, o} - sum_{i, j - 1} \Big \} \Bigg \} \\ &= \max \Bigg \{ \max \limits _ {o = 1} ^ {j} \Big \{ \max _ {k = o} ^ {j} f_{i - 1, k} - sum_{i, o - 1} \Big \} + sum_{i, j}, \max \limits _ {o = j} ^ {m} \Big \{ \max _ {k = j} ^ {o} f_{i - 1, k} + sum_{i, o} \Big \} - sum_{i, j - 1} \Bigg \} \\ \end{aligned} \]

不妨对前一半讨论:\(\max \limits _ {o = 1} ^ {j} \Big \{ \max \limits _ {k = o} ^ {j} f_{i - 1, k} - sum_{i, o - 1} \Big \} + sum_{i, j}\)。其中 \(sum_{i, j}\) 为定值,考虑 \(j\) 从小到大扫,对于一个新的 \(j\),不会影响 \(sum_{i, o - 1}\),但是会更新一些 \(\max \limits _ {k = o} ^ {j} f_{i - 1, k}\),用单调栈找出被更新的 \(o\) 的区间即可。花括号内的值用线段树维护,支持区间加和区间求最大值。

我们发现只记端点来刻画一个状态还不够,那么就把 \(f_{i, j}\) 的定义改为第 \(i\) 行选出的区间包含 \(j\) 的最大值,这样,两个区间相交变成了有一个公共点被选中。

类似地,对于 \(f_{i, j}\),我们不妨考虑其和第 \(i - 1\) 行的区间交在 \(o\)。转移时,除了 \(o \sim j\) 必选,我们也会贪心向外扩展一部分。记 \(pre_k\)、\(suf_k\) 分别表示第 \(i\) 行中以 \(k\) 为右 / 左端点的最大连续子段和(允许为空,即不小于 \(0\)),有转移:

\[\begin{aligned} f_{i, j} &= \small \max \Bigg \{ \max _ {o = 1} ^ j \{ f_{i - 1, o} + pre_{o - 1} + suf_{j + 1} + sum_{i, j} - sum_{i, o - 1} \}, \max _ {o = j} ^ m \{ f_{i - 1, o} + suf_{o + 1} + pre_{j - 1} + sum_{i, o} - sum_{i, j - 1} \} \Bigg \} \\ &= \small \max \Bigg \{ \max _ {o = 1} ^ j \{ f_{i - 1, o} + pre_{o - 1} - sum_{i, o - 1} \} + suf_{j + 1} + sum_{i, j}, \max _ {o = j} ^ m \{ f_{i - 1, o} + suf_{o + 1} + sum_{i, o} \} + pre_{j - 1} - sum_{i, j - 1} \Bigg \} \\ \end{aligned} \]

扫一扫即可。

时间复杂度:\(\Theta(nm)\)

代码

点击查看水到 $90$ 分的代码
#include <cstdio>
#include <iostream>
using namespace std;

// 64-bit signed integer used for every prefix sum and DP value in this program.
using lint = long long;

// Large positive sentinel; -inf is the identity element when taking maxima.
const lint inf = 0x3fffffffffffffff;

struct Array {
    lint buf[2000010];
    lint *val[1000010];
    void init(int n, int m) {
        val[0] = buf;
        for (int i = 1; i <= n; ++i)
            val[i] = val[i - 1] + m + 1;
    }
    lint * operator [] (const int x) {
        return val[x];
    }
} sum, f;

// Matrix dimensions read in main: n rows, m columns.
int n, m;

namespace yzh520 {
    lint f[110][110][110];  // i, l ~ r
    lint g1[110], g2[110], g3[110][110];
    void solve() {
        for (int l = 1; l <= m; ++l)
        for (int r = l; r <= m; ++r) {
            f[1][l][r] = sum[1][r] - sum[1][l - 1];
        }
        for (int i = 2; i <= n; ++i) {
            for (int l = 1; l <= m; ++l) {
                g1[l] = -inf;
                for (int r = l; r <= m; ++r)
                    g1[l] = max(g1[l], f[i - 1][l][r]);
            }
            for (int r = 1; r <= m; ++r) {
                g2[r] = -inf;
                for (int l = r; l >= 1; --l)
                    g2[r] = max(g2[r], f[i - 1][l][r]);
            }
            g3[0][m + 1] = -inf;
            g3[0][m] = -inf;
            g3[1][m + 1] = -inf;
            for (int l = 1; l <= m; ++l)
            for (int r = m; r >= l; --r) {
                g3[l][r] = f[i - 1][l][r];
                g3[l][r] = max(g3[l][r], g3[l - 1][r]);
                g3[l][r] = max(g3[l][r], g3[l][r + 1]);
                g3[l][r] = max(g3[l][r], g3[l - 1][r + 1]);
            }
            for (int l = 1; l <= m; ++l) {
                lint mx = -inf;
                for (int r = l; r <= m; ++r) {
                    mx = max(mx, g1[r]);
                    mx = max(mx, g2[r]);
                    f[i][l][r] = g3[l][r];
                    f[i][l][r] = max(f[i][l][r], mx);
                    f[i][l][r] += sum[i][r] - sum[i][l - 1];
                }
            }
        }
        
        lint ans = -inf;
        for (int l = 1; l <= m; ++l)
        for (int r = l; r <= m; ++r)
            ans = max(ans, f[n][l][r]);
        printf("%lld\n", ans);
    }
}

struct Segment_Tree {
    #define lson (idx << 1    )
    #define rson (idx << 1 | 1)
    
    struct node {
        int l, r;
        lint mx, lazy;
    } tree[1000010 << 2];
    
    inline void pushup(int idx) {
        tree[idx].mx = max(tree[lson].mx, tree[rson].mx);
    }
    
    void build(int idx, int l, int r, int i, bool tag) {
        tree[idx] = {l, r, 0, 0};
        if (l == r) {
            if (tag)
                tree[idx].mx = sum[i][l];
            else
                tree[idx].mx = -sum[i][l - 1];
            return;
        }
        int mid = (l + r) >> 1;
        build(lson, l, mid, i, tag);
        build(rson, mid + 1, r, i, tag);
        pushup(idx);
    }
    
    inline void pushtag(int idx, lint v) {
        tree[idx].mx += v, tree[idx].lazy += v;
    }
    
    inline void pushdown(int idx) {
        if (!tree[idx].lazy) return;
        pushtag(lson, tree[idx].lazy);
        pushtag(rson, tree[idx].lazy);
        tree[idx].lazy = 0;
    }
    
    void modify(int idx, int l, int r, lint v) {
        if (tree[idx].l > r || tree[idx].r < l) return;
        if (l <= tree[idx].l && tree[idx].r <= r) return pushtag(idx, v);
        pushdown(idx);
        modify(lson, l, r, v);
        modify(rson, l, r, v);
        pushup(idx);
    }
    
    lint query(int idx, int l, int r) {
        if (tree[idx].l > r || tree[idx].r < l) return -inf;
        if (l <= tree[idx].l && tree[idx].r <= r) return tree[idx].mx;
        pushdown(idx);
        return max(query(lson, l, r), query(rson, l, r));
    }
    
    #undef lson
    #undef rson
} yzh;

int stack[1000010], top;

signed main() {
    scanf("%d%d", &n, &m);
    sum.init(n, m);
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) {
            scanf("%lld", &sum[i][j]);
            sum[i][j] += sum[i][j - 1];
        }
    }
    #ifndef XuYueming
    if (n * m <= 100) return yzh520::solve(), 0;
    #endif
    f.init(n, m);
    for (int i = 1; i <= n; ++i) {
        yzh.build(1, 1, m, i, false);
        top = 0, stack[0] = 0;
        for (int j = 1; j <= m; ++j)
            f[i][j] = -inf;
        for (int j = 1; j <= m; ++j) {
            while (top && f[i - 1][stack[top]] <= f[i - 1][j]) {
                yzh.modify(1, stack[top - 1] + 1, stack[top], -f[i - 1][stack[top]]);
                --top;
            }
            yzh.modify(1, stack[top] + 1, j, f[i - 1][j]);
            stack[++top] = j;
            f[i][j] = max(f[i][j], yzh.query(1, 1, j) + sum[i][j]);
        }
        yzh.build(1, 1, m, i, true);
        top = 0, stack[0] = m + 1;
        for (int j = m; j >= 1; --j) {
            while (top && f[i - 1][stack[top]] <= f[i - 1][j]) {
                yzh.modify(1, stack[top], stack[top - 1] - 1, -f[i - 1][stack[top]]);
                --top;
            }
            yzh.modify(1, j, stack[top] - 1, f[i - 1][j]);
            stack[++top] = j;
            f[i][j] = max(f[i][j], yzh.query(1, j, m) - sum[i][j - 1]);
        }
    }
    lint ans = -inf;
    for (int i = 1; i <= m; ++i)
        ans = max(ans, f[n][i]);
    printf("%lld", ans);
    return 0;
}

当然还有正解,非常短,对着推出来的式子很好理解:

#include <cstdio>

// 64-bit signed integer used for all sums and DP values.
using lint = long long;
// Large positive sentinel; main stores -inf in mxp[0] and mxs[m + 1].
const lint inf = 0x3f3f3f3f3f3f3f3f;
// Maximum number of input bytes that main loads with a single fread.
const int MAX = 1 << 26;

// Whole-input buffer and read cursor. The buffer has one byte more than MAX:
// since at most MAX bytes are ever loaded, that last byte keeps its static
// zero value, so the data is always NUL-terminated.
char ibuf[MAX + 1], *p = ibuf;
#define getchar() *p++
#define isdigit(c) ('0' <= (c) && (c) <= '9')

// Parses the next (optionally negative) decimal integer from ibuf into x.
// Any '-' seen while skipping to the first digit marks the number negative.
// If the input is exhausted first, x is left at 0 and the cursor stays on the
// terminating NUL, so repeated calls never walk past the end of the buffer.
inline void read(int &x) {
    x = 0;
    char ch = getchar(), f = 0;
    for (; !isdigit(ch); ch = getchar()) {
        if (ch == '\0') {
            --p;  // stay on the terminator
            return;
        }
        f |= ch == '-';
    }
    for (; isdigit(ch); ch = getchar())
        x = x * 10 + (ch ^ 48);
    // The loop consumed one character past the number; if that was the
    // terminator, step back onto it.
    if (ch == '\0') --p;
    if (f) x = -x;
}

// Returns the larger of two 64-bit values (b when they are equal).
inline lint max(lint a, lint b) {
    if (a > b) return a;
    return b;
}
// Exchanges two pointers in place; main uses it to flip the row buffers.
template <typename T>
inline void swap(T* (&a), T* (&b)) {
    T* const saved = b;
    b = a;
    a = saved;
}

// Capacity of every per-column array below.
const int N = 1000010;

// n, m: matrix dimensions; val[j]: entry j of the row currently being processed.
int n, m, val[N];
// Two row buffers carved out of buf. main swaps f and g at the start of each
// row, so g holds the previous row's DP values while f receives the new ones.
lint buf[N << 1], *f = buf, *g = buf + N;
// Per-row helpers, recomputed for every row in main:
//   sum[j]: prefix sum val[1..j]
//   pre[j]: max(0, pre[j - 1] + val[j]) -- best possibly-empty run ending at j
//   suf[j]: max(0, suf[j + 1] + val[j]) -- best possibly-empty run starting at j
lint sum[N], pre[N], suf[N];
//   mxp[j]: running max over o <= j of g[o] + pre[o - 1] - sum[o - 1]
//   mxs[j]: running max over o >= j of g[o] + suf[o + 1] + sum[o]
lint mxp[N], mxs[N];

signed main() {
    fread(ibuf, 1, MAX, stdin);
    read(n), read(m);
    mxp[0] = mxs[m + 1] = -inf;
    for (int i = 1; i <= n; ++i) {
        swap(f, g);
        for (int j = 1; j <= m; ++j) {
            read(val[j]);
            sum[j] = sum[j - 1] + val[j];
            pre[j] = max(0, pre[j - 1] + val[j]);
            mxp[j] = max(mxp[j - 1], g[j] + pre[j - 1] - sum[j - 1]);
        }
        for (int j = m; j >= 1; --j) {
            suf[j] = max(0, suf[j + 1] + val[j]);
            mxs[j] = max(mxs[j + 1], g[j] + suf[j + 1] + sum[j]);
            f[j] = max(
                mxp[j] + suf[j + 1] + sum[j],
                mxs[j] + pre[j - 1] - sum[j - 1]
            );
        }
    }
    lint ans = -inf;
    for (int i = 1; i <= m; ++i)
        ans = max(ans, f[i]);
    printf("%lld", ans);
    return 0;
}
posted @ 2024-08-26 22:19  XuYueming  阅读(15)  评论(0编辑  收藏  举报