【学习笔记】Max 卷积 & 闵可夫斯基和
Max-Add 卷积 / 闵可夫斯基和
形如 \(\displaystyle f_{i} = \max_{k=0}^i\{g_k + h_{i - k}\}\) 的卷积形式,我们称它为 Max-Add 卷积。
如果 \((i, f_i)\) 能够形成一个凸包(即 \(f_i\) 为凸函数),那么 Max-Add 卷积实际上就是两个凸包的闵可夫斯基和。
考虑对 \(f_i\) 进行差分,由于是凸函数,这个差分得到的数组是单调的。同时,差分后 Max-Add 卷积就变成了从 \(g\) 与 \(h\) 中分别选一个前缀,满足一共选 \(i\) 个数,最大化前缀和。很显然的贪心就是我们直接选前 \(i\) 大的就行,因为差分后得到的数组有单调性。
所以我们卷积后的差分数组其实就是原差分数组进行了归并排序。
vector<int> max_add_convolution(vector<int> a, vector<int> b) {
for (int i = a.size() - 1; i >= 1; i--)
a[i] -= a[i - 1];
for (int i = b.size() - 1; i >= 1; i--)
b[i] -= b[i - 1];
vector<int> c(a.size() + b.size() - 1);
c[0] = a[0] + b[0];
merge(a.begin() + 1, a.end(), b.begin() + 1, b.end(), c.begin() + 1, greater<>());
for (int i = 1; i < a.size() + b.size() - 1; i++)
c[i] += c[i - 1];
return c;
}
优化 DP
有一类 DP 的形式形如 \(\displaystyle f_{i, j} = \max_{k < j} \{f_{i - 1, k} + a_i\}\),且满足 \(f_i\) 是凸函数,贡献与 \(j\) 没有关系,我们可以将这种 DP 改成区间 DP,这样转移就变成 Max-Add 卷积的形式了。
这样我们可以分治,然后每次将左右两边 Max-Add 卷积起来,这样复杂度就是 \(O(n \log n)\) 的。
2022 省选联测14 加减
可以发现 \(j\) 为奇数和偶数的时候分别为凸函数。
于是我们可以对奇数、偶数与第一个数为 + 还是 - 分开维护。
例如:第一个数为 + 的奇数可以由第一个数为 + 的偶数和第一个数为 + 的奇数合并得到,也可以由第一个数为 + 的奇数和第一个数为 - 的偶数合并得到,两者取 \(\max\) 即可。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 500005;
typedef long long ll;
int n, a[MAXN];
vector<ll> Merge(vector<ll> a, vector<ll> b, bool ext) {
vector<ll> c(a.size() + b.size() - 1);
c[0] = a[0] + b[0];
merge(a.begin() + 1, a.end(), b.begin() + 1, b.end(), c.begin() + 1, greater<>());
if (ext) {
reverse(c.begin(), c.end());
c.push_back(0);
reverse(c.begin(), c.end());
}
return c;
}
vector<ll> Max(vector<ll> a, vector<ll> b) {
for (int i = 1; i < a.size(); i++)
a[i] += a[i - 1];
for (int i = 1; i < b.size(); i++)
b[i] += b[i - 1];
int len = max(a.size(), b.size());
while (a.size() < len) a.push_back(LLONG_MIN);
while (b.size() < len) b.push_back(LLONG_MIN);
for (int i = 0; i < a.size(); i++)
a[i] = max(a[i], b[i]);
for (int i = a.size() - 1; i >= 1; i--)
a[i] -= a[i - 1];
return a;
}
vector<vector<ll>> solve(int l = 1, int r = n) {
if (l == r) {
return {{ a[l] }, { -a[l] }, { 0 }, { 0 }};
}
ll mid = (l + r) >> 1;
vector<vector<ll>> ret(4), L = solve(l, mid), R = solve(mid + 1, r);
ll len = r - l + 1;
ret[0] = Max(Merge(L[0], R[3], 0), Merge(L[2], R[0], 0));
ret[1] = Max(Merge(L[1], R[2], 0), Merge(L[3], R[1], 0));
ret[2] = Max(Merge(L[2], R[2], 0), Merge(L[0], R[1], 1));
ret[3] = Max(Merge(L[3], R[3], 0), Merge(L[1], R[0], 1));
return ret;
}
int main() {
freopen("jia.in", "r", stdin);
freopen("jia.out", "w", stdout);
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
auto ans = solve();
for (int i = 1; i < ans[0].size(); i++) {
ans[0][i] += ans[0][i - 1];
}
for (int i = 1; i < ans[2].size(); i++) {
ans[2][i] += ans[2][i - 1];
}
for (int i = 1; i <= n; i++) {
if (i & 1) {
printf("%lld ", ans[0][i / 2]);
} else {
printf("%lld ", ans[2][i / 2]);
}
}
return 0;
}
Gym - 103202L Forged in the Barrens
直接从区间 DP 去考虑,我们设 \(f[0/1/2][0/1/2]\) 为区间左边是否有一个 +/- 或者没有,区间右边是否有一个 +/- 或者没有,然后就转移就行了。我们可以将第二维加在 + 上,用 + 代表一个区间。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;
typedef long long ll;
int n, a[MAXN];
vector<ll> Merge(vector<ll> a, vector<ll> b) {
vector<ll> c(a.size() + b.size() - 1);
c[0] = a[0] + b[0];
merge(a.begin() + 1, a.end(), b.begin() + 1, b.end(), c.begin() + 1, greater<>());
return c;
}
vector<ll> Max(vector<ll> a, vector<ll> b) {
for (int i = 1; i < a.size(); i++)
a[i] += a[i - 1];
for (int i = 1; i < b.size(); i++)
b[i] += b[i - 1];
int len = max(a.size(), b.size());
while (a.size() < len) a.push_back(LLONG_MIN);
while (b.size() < len) b.push_back(LLONG_MIN);
for (int i = 0; i < a.size(); i++)
a[i] = max(a[i], b[i]);
for (int i = a.size() - 1; i >= 1; i--)
a[i] -= a[i - 1];
return a;
}
const long long INF = 1e15;
vector<vector<ll>> solve(int l = 1, int r = n) {
if (l == r) {
return {
{ 0, 0 }, { -INF, INF + a[l] }, { -a[l] },
{ -INF, INF + a[l] }, { -INF }, { -INF },
{ -a[l] }, { -INF }, { -INF }
};
}
ll mid = (l + r) >> 1;
vector<vector<ll>> ret(9), L = solve(l, mid), R = solve(mid + 1, r);
for (int ll = 0; ll < 3; ll++) {
for (int rr = 0; rr < 3; rr++) {
ret[ll + rr * 3] = Max(Max(Merge(L[ll], R[rr * 3]), Merge(L[ll + 3], R[rr * 3 + 2])), Merge(L[ll + 6], R[rr * 3 + 1]));
ret[ll + rr * 3] = Max(ret[ll + rr * 3], Max(L[ll + rr * 3], R[ll + rr * 3]));
}
}
// printf("merge(%d, %d):\n", l, r);
// for (int i = 0; i < 9; i++) {
// printf(" ret[%d]: ", i);
// long long sum = 0;
// for (ll j : ret[i]) sum += j, printf("%lld ", sum);
// printf("\n");
// }
return ret;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
auto ans = solve();
for (int i = 1; i < ans[0].size(); i++) {
ans[0][i] += ans[0][i - 1];
}
for (int i = 1; i <= n; i++) {
printf("%lld\n", ans[0][i]);
}
return 0;
}
感谢 _ICEY_ dalao 的帮助。
Gym - 104128H Factories Once More
较为进阶且比较基础的应用。
设 \(f_{u, i}\) 表示 \(u\) 子树内的权值和。那么转移有:
后者是一个凸函数,那么就可以做闵可夫斯基和。
可以拿平衡树维护差分数组,支持区间加等差数列(\(w_{u, v} j (k-j)\) 的差分为等差数列),插入一个数,然后树上启发式合并,复杂度 \(O(n \log^2 n)\)。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1000005;
int n, k;
vector<pair<int, int>> e[MAXN];
mt19937 Rand(chrono::system_clock::now().time_since_epoch().count());
struct Treap {
int lc[MAXN], rc[MAXN], rnd[MAXN];
long long val[MAXN], k[MAXN], b[MAXN];
int siz[MAXN];
int tot;
stack<int> s;
int newNode(long long v) {
int p = s.empty() ? ++tot : s.top();
if (!s.empty()) s.pop();
k[p] = b[p] = lc[p] = rc[p] = 0;
val[p] = v, siz[p] = 1, rnd[p] = Rand();
return p;
}
void pushUp(int p) {
siz[p] = siz[lc[p]] + siz[rc[p]] + 1;
}
void tag(int p, long long K, long long B) {
val[p] += K * (siz[lc[p]] + 1) + B;
k[p] += K, b[p] += B;
}
void pushDown(int p) {
if (k[p] || b[p]) {
if (lc[p]) tag(lc[p], k[p], b[p]);
if (rc[p]) tag(rc[p], k[p], k[p] * (siz[lc[p]] + 1) + b[p]);
k[p] = b[p] = 0;
}
}
void split(long long v, int p, int &x, int &y) {
if (!p) x = y = 0;
else {
pushDown(p);
if (v > val[p]) {
y = p;
split(v, lc[p], x, lc[p]);
} else {
x = p;
split(v, rc[p], rc[p], y);
}
pushUp(p);
}
}
int merge(int x, int y) {
if (!x || !y) return x + y;
pushDown(x), pushDown(y);
if (rnd[x] < rnd[y]) {
rc[x] = merge(rc[x], y);
pushUp(x);
return x;
} else {
lc[y] = merge(x, lc[y]);
pushUp(y);
return y;
}
}
void flatten(int p, vector<long long> &v) {
s.push(p);
pushDown(p);
if (lc[p]) flatten(lc[p], v);
v.push_back(val[p]);
if (rc[p]) flatten(rc[p], v);
}
void insert(int &p, long long v) {
int x, y; split(v, p, x, y);
p = merge(merge(x, newNode(v)), y);
}
} t;
int root[MAXN];
int siz[MAXN];
void dfs(int u, int pre) {
siz[u] = 1;
int s = 0;
for (auto p : e[u]) if (p.first != pre) {
int v = p.first, w = p.second;
dfs(v, u);
t.tag(root[v], -2ll * w, 1ll * w * (k + 1));
siz[u] += siz[v];
if (siz[v] > siz[s]) s = v;
}
if (s) root[u] = root[s];
t.insert(root[u], 0);
for (auto p : e[u]) if (p.first != pre && p.first != s) {
int v = p.first;
vector<long long> val;
t.flatten(root[v], val);
for (long long w : val) {
t.insert(root[u], w);
}
}
assert(siz[u] == t.siz[root[u]]);
}
int main() {
// freopen("H.in", "r", stdin);
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
e[u].push_back({v, w});
e[v].push_back({u, w});
}
int rt = 1;
dfs(rt, 0);
long long ans = 0;
vector<long long> val;
t.flatten(root[rt], val);
for (int i = 1; i <= k; i++) {
ans += val[i - 1];
}
printf("%lld\n", ans);
// printf("tot = %d\n", t.tot);
return 0;
}
/*
6 3
1 2 3
2 3 2
2 4 1
1 5 2
5 6 3
*/