【CF280D】 k-Maximum Subsequence Sum ,线段树模拟费用流

昨天考试被教育了一波。为了学习一下\(T3\)的科技,我就找到了这个远古时期的\(cf\)题(虽然最后\(T3\)还是不会写吧\(QAQ\)

顾名思义,这个题目其实可以建成一个费用流的模型。我们用流量来限制区间个数,用费用强迫它每次每次选择最大的区间就可以啦。但是因为询问很多,复杂度似乎不行,于是就有了这种神奇的科技——线段树模拟费用流。

在原先的费用流模型里,我们有正反两种边,而反向边的意义就在于,在每一次增广的时候可以反悔以前的操作,把局部最优向更大范围的局部更优优化。

参考反向边的原理,我们可以想象出来:如果对这个区间,我们每次都取用最大子区间,并在取用这个最大子区间以后将其价值变为负数,不就可以模拟费用流的行为了嘛?这样做的复杂度是\(O(NMlogN)\)的,可以解决更大数据范围的问题。

算法很好理解,关键是千万不要把代码写挂\(QwQ\),真的不是很好调啊\(TwT\)


#include<bits/stdc++.h>
using namespace std;

struct dat {
    int s; //sum of sequence
	int lmx, lmxp; //left -> max_val && it's pos
	int lmn, lmnp; //left -> min_val && it's pos
	int rmx, rmxp; //righ -> max_val && it's pos
	int rmn, rmnp; //righ -> min_val && it's pos
	int smx, smxl, smxr; //sub -> max_val && it's pos (l, r)
	int smn, smnl, smnr; //sub -> min_val && it's pos (l, r)
    dat (int pos = 0, int val = 0){
        lmxp = lmnp = rmxp = rmnp = smxl = smxr = smnl = smnr = pos;
        s = lmx = lmn = rmx = rmn = smx = smn = val;
		//对单个的点进行数据更新
	}
}T[400010];

dat operator + (dat l, dat r){
    dat u;
    u.s = l.s+r.s; //先更新关于和的数据
    if (l.lmx > l.s + r.lmx) { //max_left's pos 是否越过 mid
        u.lmx = l.lmx;
        u.lmxp = l.lmxp;
    } else {
        u.lmx = l.s + r.lmx;
        u.lmxp = r.lmxp;
    }
    if (r.rmx > r.s + l.rmx) { //max_righ's pos 是否越过 mid
        u.rmx = r.rmx;
        u.rmxp = r.rmxp;
    } else {
        u.rmx = r.s + l.rmx;
        u.rmxp = l.rmxp;
    }
    if (l.lmn < l.s + r.lmn) { //min_left's pos 是否越过 mid
        u.lmn = l.lmn;
        u.lmnp = l.lmnp;
    } else {
        u.lmn = l.s + r.lmn;
        u.lmnp = r.lmnp;
    }
    if (r.rmn < r.s + l.rmn) { //min_righ's pos 是否越过 mid
        u.rmn = r.rmn;
        u.rmnp = r.rmnp;
    } else {
        u.rmn = r.s + l.rmn;
        u.rmnp = l.rmnp;
    }
    if (l.smx > r.smx) { //最大子段 in left / righ
        u.smx = l.smx;
        u.smxl = l.smxl;
        u.smxr = l.smxr;
    } else {
        u.smx = r.smx;
        u.smxl = r.smxl;
        u.smxr = r.smxr;
    }
    if (l.rmx + r.lmx > u.smx){ //最大子段是否越过 mid
        u.smx = l.rmx + r.lmx;
        u.smxl = l.rmxp;
        u.smxr = r.lmxp;
    }
    if (l.smn < r.smn) { //最小子段 in left / righ
        u.smn = l.smn;
        u.smnl = l.smnl;
        u.smnr = l.smnr;
    } else {
        u.smn = r.smn;
        u.smnl = r.smnl;
        u.smnr = r.smnr;
    }
    if (l.rmn + r.lmn < u.smn) { //最小子段是否跨过 mid
        u.smn = l.rmn + r.lmn;
        u.smnl = l.rmnp;
        u.smnr = r.lmnp;
    }
    return u;
}

#define ls (x << 1)
#define rs (x << 1 | 1)

void pushup (int x) {
    T[x] = T[ls] + T[rs];
}

int a[100010];

void build (int l, int r, int x) {
    if (l == r) {
        T[x] = dat (l, a[l]);
        return;
    }
    int mid = (l + r) >> 1;
    build (l, mid, ls);
    build (mid + 1, r, rs);
    pushup (x);
}

int f[400010];

void rev (int x) {
    dat &u = T[x];
	// max 变成 min
    swap (u.lmx, u.lmn); 
    swap (u.lmxp, u.lmnp);
    swap (u.rmx, u.rmn);
    swap (u.rmxp, u.rmnp);
    swap (u.smx, u.smn);
    swap (u.smxl, u.smnl);
    swap (u.smxr, u.smnr);
    f[x] ^= 1;
    u.lmx *= -1;
    u.lmn *= -1;
    u.rmx *= -1;
    u.rmn *= -1;
    u.smx *= -1;
    u.smn *= -1;
    u.s *= -1;
}
void pushdown (int x) {
    if(f[x]) {
        rev (ls);
        rev (rs);
        f[x] = 0;
    }
}

void modify (int p, int v, int l, int r, int x) {
    if (l == r) {
        T[x] = dat (l, v);
        return;
    }
    pushdown (x);
    int mid = (l + r) >> 1;
    if (p <= mid) {
        modify (p, v, l, mid, ls);
    } else {
        modify (p, v, mid + 1, r, rs);
	}
	pushup (x);
}

void reverse (int L, int R, int l, int r, int x) {
	//其实就是取用啦
    if (L <= l && r <= R) return rev (x);
    pushdown (x);
    int mid = (l + r) >> 1;
    if (L <= mid) reverse (L, R, l, mid, ls);
    if (mid < R) reverse (L, R, mid + 1, r, rs);
    pushup (x);
}

dat query (int L, int R, int l, int r, int x) {
	//求[l, r]区间内的最大值嘛
    if (L <= l && r <= R) return T[x];
    pushdown (x);
    int mid = (l + r) >> 1;
    if (R <= mid) return query (L, R, l, mid, ls); //如果区间全在左边
    if (mid < L) return query (L, R, mid + 1, r, rs); //如果区间全在右边
    return query (L, R, l, mid, ls) + query (L, R, mid + 1, r, rs); //跨 mid 了 QwQ
}

int L[30], R[30], top;
int n, m, x, y, k, opt;
    
int main () {
	cin >> n;
    for (int i = 1; i <= n; ++i) {
		cin >> a[i];
	}
	build (1, n, 1);
	cin >> m;
    for (int i = 1; i <= m; ++i) {
		cin >> opt >> x >> y;
        if (opt == 0) {
            modify (x, y, 1, n, 1); //把点 x 的值改为 y
        } else {
			cin >> k;  //在[x, y]之间取 k 段的最大值
            int ans = 0;
            for (int j = 1; j <= k; ++j) {
                dat t = query (x, y, 1, n, 1);
                if (t.smx <= 0) break;
				//选至多 k 段, 可以少选 !
                ans += t.smx;
                L[++top] = t.smxl, R[top] = t.smxr;
                reverse (L[top], R[top], 1, n, 1);
            }
            while (top) {
                reverse (L[top], R[top], 1, n, 1);
                top = top - 1;
            }
			cout << ans << endl;
        }
    }
}

posted @ 2019-02-17 18:53  maomao9173  阅读(285)  评论(0编辑  收藏  举报