【学习笔记】Max 卷积 & 闵可夫斯基和

Max-Add 卷积 / 闵可夫斯基和

形如 \(\displaystyle f_{i} = \max_{k=0}^i\{g_k + h_{i - k}\}\) 的卷积形式,我们称它为 Max-Add 卷积。

如果 \((i, f_i)\) 能够形成一个凸包(即 \(f_i\) 为凸函数),那么 Max-Add 卷积实际上就是两个凸包的闵可夫斯基和。

考虑对 \(f_i\) 进行差分,由于是凸函数,这个差分得到的数组是单调的。同时,差分后 Max-Add 卷积就变成了从 \(g\)\(h\) 中分别选一个前缀,满足一共选 \(i\) 个数,最大化前缀和。很显然的贪心就是我们直接选前 \(i\) 大的就行,因为差分后得到的数组有单调性。

所以我们卷积后的差分数组其实就是原差分数组进行了归并排序。

vector<int> max_add_convolution(vector<int> a, vector<int> b) {
    for (int i = a.size() - 1; i >= 1; i--)
        a[i] -= a[i - 1];
    for (int i = b.size() - 1; i >= 1; i--)
        b[i] -= b[i - 1];
    vector<int> c(a.size() + b.size() - 1);
    c[0] = a[0] + b[0];
    merge(a.begin() + 1, a.end(), b.begin() + 1, b.end(), c.begin() + 1, greater<>());
    for (int i = 1; i < a.size() + b.size() - 1; i++)
        c[i] += c[i - 1];
    return c;
}

优化 DP

有一类 DP 的形式形如 \(\displaystyle f_{i, j} = \max_{k < j} \{f_{i - 1, k} + a_i\}\),且满足 \(f_i\) 是凸函数,贡献与 \(j\) 没有关系,我们可以将这种 DP 改成区间 DP,这样转移就变成 Max-Add 卷积的形式了。

这样我们可以分治,然后每次将左右两边 Max-Add 卷积起来,这样复杂度就是 \(O(n \log n)\) 的。

2022 省选联测14 加减

可以发现 \(j\) 为奇数和偶数的时候分别为凸函数。

于是我们可以对奇数、偶数与第一个数为 + 还是 - 分开维护。

例如:第一个数为 + 的奇数可以由第一个数为 + 的偶数和第一个数为 + 的奇数合并得到,也可以由第一个数为 + 的奇数和第一个数为 - 的偶数合并得到,两者取 \(\max\) 即可。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 500005;
typedef long long ll;
int n, a[MAXN];
vector<ll> Merge(vector<ll> a, vector<ll> b, bool ext) {
    vector<ll> c(a.size() + b.size() - 1);
    c[0] = a[0] + b[0];
    merge(a.begin() + 1, a.end(), b.begin() + 1, b.end(), c.begin() + 1, greater<>());
    if (ext) {
        reverse(c.begin(), c.end());
        c.push_back(0);
        reverse(c.begin(), c.end());
    }
    return c;
}
vector<ll> Max(vector<ll> a, vector<ll> b) {
    for (int i = 1; i < a.size(); i++)
        a[i] += a[i - 1];
    for (int i = 1; i < b.size(); i++)
        b[i] += b[i - 1];
    int len = max(a.size(), b.size());
    while (a.size() < len) a.push_back(LLONG_MIN);
    while (b.size() < len) b.push_back(LLONG_MIN);
    for (int i = 0; i < a.size(); i++)
        a[i] = max(a[i], b[i]);
    for (int i = a.size() - 1; i >= 1; i--)
        a[i] -= a[i - 1];
    return a;
}
vector<vector<ll>> solve(int l = 1, int r = n) {
    if (l == r) {
        return {{ a[l] }, { -a[l] }, { 0 }, { 0 }};
    }
    ll mid = (l + r) >> 1;
    vector<vector<ll>> ret(4), L = solve(l, mid), R = solve(mid + 1, r);
    ll len = r - l + 1;
    ret[0] = Max(Merge(L[0], R[3], 0), Merge(L[2], R[0], 0));
    ret[1] = Max(Merge(L[1], R[2], 0), Merge(L[3], R[1], 0));
    ret[2] = Max(Merge(L[2], R[2], 0), Merge(L[0], R[1], 1));
    ret[3] = Max(Merge(L[3], R[3], 0), Merge(L[1], R[0], 1));
    return ret;
}
int main() {
    freopen("jia.in", "r", stdin);
    freopen("jia.out", "w", stdout);
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    auto ans = solve();
    for (int i = 1; i < ans[0].size(); i++) {
        ans[0][i] += ans[0][i - 1];
    }
    for (int i = 1; i < ans[2].size(); i++) {
        ans[2][i] += ans[2][i - 1];
    }
    for (int i = 1; i <= n; i++) {
        if (i & 1) {
            printf("%lld ", ans[0][i / 2]);
        } else {
            printf("%lld ", ans[2][i / 2]);
        }
    }
    return 0;
}

Gym - 103202L Forged in the Barrens

直接从区间 DP 去考虑,我们设 \(f[0/1/2][0/1/2]\) 为区间左边是否有一个 +/- 或者没有,区间右边是否有一个 +/- 或者没有,然后就转移就行了。我们可以将第二维加在 + 上,用 + 代表一个区间。

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;
typedef long long ll;
int n, a[MAXN];
vector<ll> Merge(vector<ll> a, vector<ll> b) {
    vector<ll> c(a.size() + b.size() - 1);
    c[0] = a[0] + b[0];
    merge(a.begin() + 1, a.end(), b.begin() + 1, b.end(), c.begin() + 1, greater<>());
    return c;
}

vector<ll> Max(vector<ll> a, vector<ll> b) {
    for (int i = 1; i < a.size(); i++)
        a[i] += a[i - 1];
    for (int i = 1; i < b.size(); i++)
        b[i] += b[i - 1];
    int len = max(a.size(), b.size());
    while (a.size() < len) a.push_back(LLONG_MIN);
    while (b.size() < len) b.push_back(LLONG_MIN);
    for (int i = 0; i < a.size(); i++)
        a[i] = max(a[i], b[i]);
    for (int i = a.size() - 1; i >= 1; i--)
        a[i] -= a[i - 1];
    return a;
}
const long long INF = 1e15;
vector<vector<ll>> solve(int l = 1, int r = n) {
    if (l == r) {
        return {
            { 0, 0 }, { -INF, INF + a[l] }, { -a[l] },
            { -INF, INF + a[l] }, { -INF }, { -INF },
            { -a[l] }, { -INF }, { -INF }
        };
    }
    ll mid = (l + r) >> 1;
    vector<vector<ll>> ret(9), L = solve(l, mid), R = solve(mid + 1, r);
    for (int ll = 0; ll < 3; ll++) {
        for (int rr = 0; rr < 3; rr++) {
            ret[ll + rr * 3] = Max(Max(Merge(L[ll], R[rr * 3]), Merge(L[ll + 3], R[rr * 3 + 2])), Merge(L[ll + 6], R[rr * 3 + 1]));
            ret[ll + rr * 3] = Max(ret[ll + rr * 3], Max(L[ll + rr * 3], R[ll + rr * 3]));
        }
    }
    // printf("merge(%d, %d):\n", l, r);
    // for (int i = 0; i < 9; i++) {
    //     printf("  ret[%d]: ", i);
    //     long long sum = 0;
    //     for (ll j : ret[i]) sum += j, printf("%lld ", sum);
    //     printf("\n");
    // }
    return ret;
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
    }
    auto ans = solve();
    for (int i = 1; i < ans[0].size(); i++) {
        ans[0][i] += ans[0][i - 1];
    }
    for (int i = 1; i <= n; i++) {
        printf("%lld\n", ans[0][i]);
    }
    return 0;
}

感谢 _ICEY_ dalao 的帮助。

Gym - 104128H Factories Once More

较为进阶且比较基础的应用。

\(f_{u, i}\) 表示 \(u\) 子树内的权值和。那么转移有:

\[f_{u, i} = \max_{j = 0}^i f'_{u, i - j} + f_{v, j} + w_{u, v} j (k-j) \]

后者是一个凸函数,那么就可以做闵可夫斯基和。

可以拿平衡树维护差分数组,支持区间加等差数列(\(w_{u, v} j (k-j)\) 的差分为等差数列),插入一个数,然后树上启发式合并,复杂度 \(O(n \log^2 n)\)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1000005;
int n, k;
vector<pair<int, int>> e[MAXN];
mt19937 Rand(chrono::system_clock::now().time_since_epoch().count());
struct Treap {
    int lc[MAXN], rc[MAXN], rnd[MAXN];
    long long val[MAXN], k[MAXN], b[MAXN];
    int siz[MAXN];
    int tot;
    stack<int> s;
    int newNode(long long v) {
        int p = s.empty() ? ++tot : s.top();
        if (!s.empty()) s.pop();
        k[p] = b[p] = lc[p] = rc[p] = 0;
        val[p] = v, siz[p] = 1, rnd[p] = Rand();
        return p;
    }
    void pushUp(int p) {
        siz[p] = siz[lc[p]] + siz[rc[p]] + 1;
    }
    void tag(int p, long long K, long long B) {
        val[p] += K * (siz[lc[p]] + 1) + B;
        k[p] += K, b[p] += B;
    }
    void pushDown(int p) {
        if (k[p] || b[p]) {
            if (lc[p]) tag(lc[p], k[p], b[p]);
            if (rc[p]) tag(rc[p], k[p], k[p] * (siz[lc[p]] + 1) + b[p]);
            k[p] = b[p] = 0;
        }
    }
    void split(long long v, int p, int &x, int &y) {
        if (!p) x = y = 0;
        else {
            pushDown(p);
            if (v > val[p]) {
                y = p;
                split(v, lc[p], x, lc[p]);
            } else {
                x = p;
                split(v, rc[p], rc[p], y);
            }
            pushUp(p);
        }
    }
    int merge(int x, int y) {
        if (!x || !y) return x + y;
        pushDown(x), pushDown(y);
        if (rnd[x] < rnd[y]) {
            rc[x] = merge(rc[x], y);
            pushUp(x);
            return x;
        } else {
            lc[y] = merge(x, lc[y]);
            pushUp(y);
            return y;
        }
    }
    void flatten(int p, vector<long long> &v) {
        s.push(p);
        pushDown(p);
        if (lc[p]) flatten(lc[p], v);
        v.push_back(val[p]);
        if (rc[p]) flatten(rc[p], v);
    }
    void insert(int &p, long long v) {
        int x, y; split(v, p, x, y);
        p = merge(merge(x, newNode(v)), y);
    }
} t;
int root[MAXN];
int siz[MAXN];
void dfs(int u, int pre) {
    siz[u] = 1;
    int s = 0;
    for (auto p : e[u]) if (p.first != pre) {
        int v = p.first, w = p.second;
        dfs(v, u);
        t.tag(root[v], -2ll * w, 1ll * w * (k + 1));
        siz[u] += siz[v];
        if (siz[v] > siz[s]) s = v;
    }
    if (s) root[u] = root[s];
    t.insert(root[u], 0);
    for (auto p : e[u]) if (p.first != pre && p.first != s) {
        int v = p.first;
        vector<long long> val;
        t.flatten(root[v], val);
        for (long long w : val) {
            t.insert(root[u], w);
        }
    }
    assert(siz[u] == t.siz[root[u]]);
}
int main() {
    // freopen("H.in", "r", stdin);
    scanf("%d%d", &n, &k);
    for (int i = 1; i < n; i++) {
        int u, v, w; scanf("%d%d%d", &u, &v, &w);
        e[u].push_back({v, w});
        e[v].push_back({u, w});
    }
    int rt = 1;
    dfs(rt, 0);
    long long ans = 0;
    vector<long long> val;
    t.flatten(root[rt], val);
    for (int i = 1; i <= k; i++) {
        ans += val[i - 1];
    }
    printf("%lld\n", ans);
    // printf("tot = %d\n", t.tot);
    return 0;
}
/*
6 3
1 2 3
2 3 2
2 4 1
1 5 2
5 6 3

*/
posted @ 2023-01-10 19:14  APJifengc  阅读(2194)  评论(3编辑  收藏  举报