CF573E Bear and Bowling【DP,平衡树】

给定长为 \(n\) 的序列 \(a_1,\cdots,a_n\),对于 \(a\) 的一个长为 \(m\) 的子序列 \(b\)(可以为空),定义其权值为 \(\sum\limits_{i=1}^m i \cdot b_i\)。求权值最大的子序列的权值。

\(n \leq 10^5\)\(|a_i| \leq 10^7\),时限 \(\text{6.0s}\)


首先有一个傻逼 DP,设 \(f_{i,j}\) 为前 \(i\) 个数中选择了 \(j\) 个数的最大权值,容易得到转移:

\[f_{i,j} \gets \max\{ f_{i-1,j},f_{i-1,j-1} + a_i \times j\} \]

然后我就不会了,至于有题解说 \(f_i\) 凸,这显然是不对的。重要的结论是:\(\forall i\)\(\exists k_i \in [0,i]\) 满足 \(\forall j < k_i,f_{i,j} = f_{i-1,j}\)\(\forall j \geq k_i,f_{i,j} = f_{i-1,j-1} + a_i \times j\)。这等价于 \(\forall j < k_i,f_{i-1,j} \geq f_{i-1,j-1} + a_i \times j\)\(\forall j \geq k_i,f_{i-1,j} \leq f_{i-1,j-1} + a_i \times j\)

\(g_{i,j} = f_{i,j} - f_{i,j-1}\),稍作变形容易得到 \(\forall j < k_i,\frac{g_{i-1,j}}{j} \geq a_i\)\(\forall j \geq k_i,\frac{g_{i-1,j}}{j} \leq a_i\)

显然,如果 \(\frac{g_{i,j}}{j}\) 关于 \(j\) 单调不增,那么结论就一定成立了。虽然这并不充分,但事实上确实是这样——考虑归纳证明,\(g_1\) 显然合法,现在假设 \(g_{i-1}\) 合法,此时 \(k_i\) 存在:

对于 \(j < k_i\),显然 \(g_{i,j}\) 不变。对于 \(k_i\),由于 \(j < k_i,\frac{g_{i-1,j}}{j} \leq a_i\)\(\frac{g_{i,k_i}}{k_i} = \frac{k_i a_i}{k_i} = a_i\),单调性也不变。对于 \(j > k_i\),有 \(\frac{g_{i-1,j}}{j} \geq \frac{g_{i-1,j+1}}{j+1}\),即 \(g_{i-1,j} \geq \frac{j}{j+1} \cdot g_{i-1,j+1}\)。同时,我们有 \(g_{i,j+1} = g_{i-1,j} + a_i,g_{i,j+2} = g_{i-1,j+1}+a_i\)

考虑我们要证的结论:\(\frac{g_{i,j+1}}{j+1} \geq \frac{g_{i,j+2}}{j+2}\),稍作整理容易得到 \(g_{i - 1,j} \geq \frac{j+1}{j+2} \cdot g_{i-1,j+1} - \frac{1}{j+2} \cdot a_i\)。而

\[\begin{aligned} \frac{j}{j+1} \cdot g_{i-1,j+1} - (\frac{j+1}{j+2} \cdot g_{i-1,j+1} - \frac{1}{j+2} \cdot a_i) &= \frac{((j+2) \cdot j - (j+1)^2) \cdot g_{i-1,j+1} + (j+1) \cdot a_i}{(j+1)(j+2)} \\ & = \frac{(j+1) \cdot a_i - f_{i-1,j+1}}{(j+1)(j+2)} \geq 0 \end{aligned} \]

又由 \(g_{i-1,j} \geq \frac{j}{j+1} \cdot g_{i-1,j+1}\) 可知结论成立。

现在考虑如何维护这个 DP,容易发现我们需要支持的操作就是在 \(k_i\) 的位置插入一个数,然后对后缀区间加等差数列,平衡树维护即可。找 \(k_i\) 我偷懒写了 \(O(n \log^2 n)\) 的,因为时限比较松所以也过了。

但是问题来了,这种神秘结论到底该怎么想到呢,感觉除了打表没有其他好的方法啊,有没有大神教育我

Code
/*
最黯淡的一个 梦最为炽热
万千孤单焰火 让这虚构灵魂鲜活
至少在这一刻 热爱不问为何
存在为将心声响彻
*/
#include <bits/stdc++.h>
#define pii pair<int, int>
#define mp(x, y) make_pair(x, y)
#define pb push_back
#define eb emplace_back
#define fi first
#define se second
#define int long long
#define mem(x, v) memset(x, v, sizeof(x))
#define mcpy(x, y, n) memcpy(x, y, sizeof(int) * (n))
#define lob lower_bound
#define upb upper_bound
using namespace std;

inline int read() {
	int x = 0, w = 1;char ch = getchar();
	while (ch > '9' || ch < '0') { if (ch == '-')w = -1;ch = getchar(); }
	while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
	return x * w;
}

const int MN = 1e5 + 5;
const int Mod = 1e9 + 7;
const int inf = 1e13;

inline int min(int x, int y) { return x < y ? x : y; }
inline int max(int x, int y) { return x > y ? x : y; }

inline int qPow(int a, int b = Mod - 2, int ret = 1) {
    while (b) {
        if (b & 1) ret = ret * a % Mod;
        a = a * a % Mod, b >>= 1;
    }
    return ret;
}

#define dbg

int N, a[MN];

#define ls(x) lc[x]
#define rs(x) rc[x]
int rt, tot;
int sz[MN], lc[MN], rc[MN], k[MN], b[MN], val[MN], w[MN];
inline void make(int v) { sz[++tot] = 1, val[tot] = v, w[tot] = rand(); }
inline void pushup(int x) { sz[x] = sz[ls(x)] + sz[rs(x)] + 1; }
inline void pushtag(int x, int _k, int _b) { 
    if (x) k[x] += _k, b[x] += _b, val[x] += _k * (sz[ls(x)] + 1) + _b;
}
inline void pushdown(int x) {
    if (k[x] || b[x]) {
        pushtag(ls(x), k[x], b[x]), pushtag(rs(x), k[x], k[x] * (sz[ls(x)] + 1) + b[x]);
        k[x] = b[x] = 0;
    }
}
inline int merge(int x, int y) {
    if (!x || !y) return x + y;
    pushdown(x), pushdown(y);
    if (w[x] < w[y]) return ls(y) = merge(x, ls(y)), pushup(y), y;
    return rs(x) = merge(rs(x), y), pushup(x), x;
}
inline void split(int o, int k, int &x, int &y) {
    if (!o) x = y = 0;
    else {
        pushdown(o);
        if (k > sz[ls(o)]) x = o, split(rs(o), k - sz[ls(o)] - 1, rs(x), y), pushup(x);
        else y = o, split(ls(o), k, x, ls(y)), pushup(y);
    }
}
inline int qry(int o, int k) {
    pushdown(o);
    if (k <= sz[ls(o)]) return qry(ls(o), k);
    if (k > sz[ls(o)] + 1) return qry(rs(o), k - sz[ls(o)] - 1);
    return val[o];
}
inline int getans(int o) {
    if (!o) return 0;
    pushdown(o);
    return max(max(getans(ls(o)), getans(rs(o))), val[o]);
}
inline void ins(int k, int v) {
    int rt1, rt2;
    split(rt, k, rt1, rt2);
    make(v);
    rt = merge(rt1, merge(tot, rt2));
}
inline void upd(int p, int k, int b) {
    int rt1, rt2;
    split(rt, p, rt1, rt2);
    pushtag(rt2, k, b);
    rt = merge(rt1, rt2);
}

signed main(void) {
    srand(time(0));
    N = read();
    rt = 1;
    make(0);
    for (int i = 1; i <= N; i++) a[i] = read();
    for (int i = 1, l, r, p; i <= N; i++) {
        l = 0, r = p = i - 1;
        while (l <= r) {
            int mid = (l + r) >> 1;
            if (qry(rt, mid + 1) + a[i] * (mid + 1) > qry(rt, mid + 2)) r = mid - 1, p = mid;
            else l = mid + 1;
        }
        int w = qry(rt, p + 1);
        ins(p + 1, w);
        upd(p + 1, a[i], a[i] * p);
    }
    printf("%lld\n", getans(rt));
    return 0;
}
posted @ 2022-09-01 20:52  came11ia  阅读(31)  评论(0编辑  收藏  举报