CF280D k-Maximum Subsequence Sum
k-Maximum Subsequence Sum - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
Problem - 280D - Codeforces
借鉴题解 CF280D k-Maximum Subsequence Sum - 洛谷专栏 (luogu.com.cn)
思路确实难想,代码复杂程度有紫甚至还高。观察数据 \(k\) 比较小,总数只有 \(10^4\),那么可以直接线段树求区间最大字段和,每次选完标记上不选,询问结束后再取消标记。但是这样会有一个问题,我想了半天没想出来,如 \(3,-1,3\),\(k = 2\),直接求最大字段和,会直接选上 \(3,-1,3\),然而最优答案应该是选两个 \(3\),也就是说,目前最大字段和里面分多个可能还能有更优答案。所以要考虑反悔,把大子段里的更优的找出来。可以直接把选完的子段取反,这样如果还能取,就相当于取了总段里两个(因为肯定不选边界,所以选完后相当于原来的两段)。选的段数增加,恰好符合反悔。
那么总体思想就是,取最大字段和,取完把这段取负数,再继续取最大字段和,重复。当然要记录取反区间和顺序,方便撤销取反。这题就没了。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <vector>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
const int N = 100010;
int n, m;
int g[N];
struct Node
{
int l, r, mul, sum;
int maxv, lmaxv, rmaxv, lm, rm, lms, rms;
int minv, lminv, rminv, lv, rv, lvs, rvs;
}tr[N * 4];
void pushup(Node &u, Node &l, Node &r)
{
u = {l.l, r.r, 1, l.sum + r.sum};
int maxv, rmaxv, lmaxv, minv, lminv, rminv;
rmaxv = max(r.rmaxv, r.sum + l.rmaxv);
lmaxv = max(l.lmaxv, l.sum + r.lmaxv);
maxv = max({lmaxv, rmaxv, l.maxv, r.maxv, l.rmaxv + r.lmaxv});
rminv = min(r.rminv, r.sum + l.rminv);
lminv = min(l.lminv, l.sum + r.lminv);
minv = min({lminv, rminv, l.minv, r.minv, l.rminv + r.lminv});
u.maxv = maxv;
u.rmaxv = rmaxv;
u.lmaxv = lmaxv;
if (lmaxv == l.lmaxv) u.lms = l.lms;
if (lmaxv == l.sum + r.lmaxv) u.lms = r.lms;
if (rmaxv == r.rmaxv) u.rms = r.rms;
if (rmaxv == r.sum + l.rmaxv) u.rms = l.rms;
if (maxv == lmaxv) u.lm = l.l, u.rm = u.lms;
if (maxv == rmaxv) u.rm = r.r, u.lm = u.rms;
if (maxv == l.maxv) u.lm = l.lm, u.rm = l.rm;
if (maxv == r.maxv) u.lm = r.lm, u.rm = r.rm;
if (maxv == l.rmaxv + r.lmaxv) u.lm = l.rms, u.rm = r.lms;
// 同样思路
u.minv = minv;
u.rminv = rminv;
u.lminv = lminv;
if (lminv == l.lminv) u.lvs = l.lvs;
if (lminv == l.sum + r.lminv) u.lvs = r.lvs;
if (rminv == r.rminv) u.rvs = r.rvs;
if (rminv == r.sum + l.rminv) u.rvs = l.rvs;
if (minv == lminv) u.lv = l.l, u.rv = u.lvs;
if (minv == rminv) u.rv = r.r, u.lv = u.rvs;
if (minv == l.minv) u.lv = l.lv, u.rv = l.rv;
if (minv == r.minv) u.lv = r.lv, u.rv = r.rv;
if (minv == l.rminv + r.lminv) u.lv = l.rvs, u.rv = r.lvs;
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void get(Node &u) // 取反
{
u.mul = -u.mul;
u.sum = -u.sum;
u.maxv = -u.maxv;
u.lmaxv = -u.lmaxv;
u.rmaxv = -u.rmaxv;
u.minv = -u.minv;
u.lminv = -u.lminv;
u.rminv = -u.rminv;
swap(u.minv, u.maxv);
swap(u.lmaxv, u.lminv);
swap(u.rmaxv, u.rminv);
swap(u.lm, u.lv), swap(u.rm, u.rv);
swap(u.lms, u.lvs), swap(u.rms, u.rvs);
}
void pushdown(Node &u, Node &l, Node &r)
{
if (u.mul == -1)
{
get(l);
get(r);
}
u.mul = 1;
}
void pushdown(int u)
{
pushdown(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r) tr[u] = {l, r, 1, g[l], g[l], g[l], g[l], l, r, l, r, g[l], g[l], g[l], l, l, l, l};
else
{
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r) get(tr[u]);
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r);
if (r > mid) modify(u << 1 | 1, l, r);
pushup(u);
}
}
void modify1(int u, int x, int k)
{
if (tr[u].l == tr[u].r) tr[u] = {x, x, 1, k, k, k, k, x, x, x, x, k, k, k, x, x, x, x};
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify1(u << 1, x, k);
else modify1(u << 1 | 1, x, k);
pushup(u);
}
}
Node query(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r) return tr[u];
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else
{
Node res, A, B;
A = query(u << 1, l, r);
B = query(u << 1 | 1, l, r);
pushup(res, A, B);
return res;
}
}
}
int main()
{
cin >> n;
for (int i = 1; i <= n; i ++ ) scanf("%d", &g[i]);
cin >> m;
build(1, 1, n);
while (m -- )
{
int op, l, r, k, x;
scanf("%d", &op);
// cout << op << endl;
if (op)
{
scanf("%d%d%d", &l, &r, &k);
vector<PII> res;
int ans = 0;
for (int i = 1; i <= k; i ++ )
{
Node t = query(1, l, r);
if (t.maxv < 0) break;
ans += t.maxv;
modify(1, t.lm, t.rm);
res.push_back({t.lm, t.rm});
}
reverse(res.begin(), res.end());
for (auto t : res) modify(1, t.x, t.y); // 撤销影响
cout << ans << endl;
}
else
{
scanf("%d%d", &x, &k);
modify1(1, x, k);
}
}
return 0;
}