CF280D k-Maximum Subsequence Sum

k-Maximum Subsequence Sum - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
Problem - 280D - Codeforces

借鉴题解 CF280D k-Maximum Subsequence Sum - 洛谷专栏 (luogu.com.cn)

思路确实难想,代码复杂程度有紫甚至还高。观察数据 \(k\) 比较小,总数只有 \(10^4\),那么可以直接线段树求区间最大字段和,每次选完标记上不选,询问结束后再取消标记。但是这样会有一个问题,我想了半天没想出来,如 \(3,-1,3\)\(k = 2\),直接求最大字段和,会直接选上 \(3,-1,3\),然而最优答案应该是选两个 \(3\),也就是说,目前最大字段和里面分多个可能还能有更优答案。所以要考虑反悔,把大子段里的更优的找出来。可以直接把选完的子段取反,这样如果还能取,就相当于取了总段里两个(因为肯定不选边界,所以选完后相当于原来的两段)。选的段数增加,恰好符合反悔。

那么总体思想就是,取最大字段和,取完把这段取负数,再继续取最大字段和,重复。当然要记录取反区间和顺序,方便撤销取反。这题就没了。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <vector>

#define x first
#define y second

using namespace std;

typedef pair<int, int> PII;

const int N = 100010;

int n, m;
int g[N];

struct Node
{
    int l, r, mul, sum;
    int maxv, lmaxv, rmaxv, lm, rm, lms, rms;
    int minv, lminv, rminv, lv, rv, lvs, rvs;
}tr[N * 4];

void pushup(Node &u, Node &l, Node &r)
{
    u = {l.l, r.r, 1, l.sum + r.sum};
    int maxv, rmaxv, lmaxv, minv, lminv, rminv;
    rmaxv = max(r.rmaxv, r.sum + l.rmaxv);
    lmaxv = max(l.lmaxv, l.sum + r.lmaxv);
    maxv = max({lmaxv, rmaxv, l.maxv, r.maxv, l.rmaxv + r.lmaxv});
    
    rminv = min(r.rminv, r.sum + l.rminv);
    lminv = min(l.lminv, l.sum + r.lminv);
    minv = min({lminv, rminv, l.minv, r.minv, l.rminv + r.lminv});
    
    u.maxv = maxv;
    u.rmaxv = rmaxv;
    u.lmaxv = lmaxv;
    if (lmaxv == l.lmaxv) u.lms = l.lms;
    if (lmaxv == l.sum + r.lmaxv) u.lms = r.lms;
    
    if (rmaxv == r.rmaxv) u.rms = r.rms;
    if (rmaxv == r.sum + l.rmaxv) u.rms = l.rms;
    
    if (maxv == lmaxv) u.lm = l.l, u.rm = u.lms;
    if (maxv == rmaxv) u.rm = r.r, u.lm = u.rms;
    if (maxv == l.maxv) u.lm = l.lm, u.rm = l.rm;
    if (maxv == r.maxv) u.lm = r.lm, u.rm = r.rm;
    if (maxv == l.rmaxv + r.lmaxv) u.lm = l.rms, u.rm = r.lms;
    
    // 同样思路
    u.minv = minv;
    u.rminv = rminv;
    u.lminv = lminv;
    
    if (lminv == l.lminv) u.lvs = l.lvs;
    if (lminv == l.sum + r.lminv) u.lvs = r.lvs;
    
    if (rminv == r.rminv) u.rvs = r.rvs;
    if (rminv == r.sum + l.rminv) u.rvs = l.rvs;
    
    if (minv == lminv) u.lv = l.l, u.rv = u.lvs;
    if (minv == rminv) u.rv = r.r, u.lv = u.rvs;
    if (minv == l.minv) u.lv = l.lv, u.rv = l.rv;
    if (minv == r.minv) u.lv = r.lv, u.rv = r.rv;
    if (minv == l.rminv + r.lminv) u.lv = l.rvs, u.rv = r.lvs;
}

void pushup(int u)
{
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void get(Node &u) // 取反
{
    u.mul = -u.mul;
    u.sum = -u.sum;
    
    u.maxv = -u.maxv;
    u.lmaxv = -u.lmaxv;
    u.rmaxv = -u.rmaxv;
    
    u.minv = -u.minv;
    u.lminv = -u.lminv;
    u.rminv = -u.rminv;
    
    swap(u.minv, u.maxv);
    swap(u.lmaxv, u.lminv);
    swap(u.rmaxv, u.rminv);
    swap(u.lm, u.lv), swap(u.rm, u.rv);
    swap(u.lms, u.lvs), swap(u.rms, u.rvs);
}

void pushdown(Node &u, Node &l, Node &r)
{
    if (u.mul == -1)
    {
        get(l);
        get(r);
    }
    u.mul = 1;
}

void pushdown(int u)
{
    pushdown(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r)
{
    if (l == r) tr[u] = {l, r, 1, g[l], g[l], g[l], g[l], l, r, l, r, g[l], g[l], g[l], l, l, l, l};
    else 
    {
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int l, int r)
{
    if (l <= tr[u].l && tr[u].r <= r) get(tr[u]);
    else
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid) modify(u << 1, l, r);
        if (r > mid) modify(u << 1 | 1, l, r);
        pushup(u);
    }
}

void modify1(int u, int x, int k)
{
    if (tr[u].l == tr[u].r) tr[u] = {x, x, 1, k, k, k, k, x, x, x, x, k, k, k, x, x, x, x};
    else 
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (x <= mid) modify1(u << 1, x, k);
        else modify1(u << 1 | 1, x, k);
        pushup(u);
    }
}

Node query(int u, int l, int r)
{
    if (l <= tr[u].l && tr[u].r <= r) return tr[u];
    else 
    {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (r <= mid) return query(u << 1, l, r);
        else if (l > mid) return query(u << 1 | 1, l, r);
        else 
        {
            Node res, A, B;
            A = query(u << 1, l, r);
            B = query(u << 1 | 1, l, r);
            pushup(res, A, B);
            return res;
        }
    }
}


int main()
{
    cin >> n;
    for (int i = 1; i <= n; i ++ ) scanf("%d", &g[i]);
    
    cin >> m;
    build(1, 1, n);
    while (m -- )
    {
        int op, l, r, k, x;
        scanf("%d", &op);
        // cout << op << endl;
        if (op) 
        {
            scanf("%d%d%d", &l, &r, &k);
            
            vector<PII> res;
            int ans = 0;
            for (int i = 1; i <= k; i ++ )
            {
                Node t = query(1, l, r);
                if (t.maxv < 0) break;
                ans += t.maxv;
                modify(1, t.lm, t.rm);
                res.push_back({t.lm, t.rm});
            }
            reverse(res.begin(), res.end());
            for (auto t : res) modify(1, t.x, t.y); // 撤销影响
            cout << ans << endl;
        }
        else 
        {
            scanf("%d%d", &x, &k);
            modify1(1, x, k);
        }
    }
    
    return 0;
}
posted @ 2024-11-26 14:16  blind5883  阅读(1)  评论(0编辑  收藏  举报