BJOI2017 喷式水战改
题目链接。
Description
维护一个序列,支持操作:
- 每次在 \(P_i\) 位置后插入一段 \(X_i\) 单位的燃料,这一段有三个模式,对应的能量分别是 \(A_i, B_i, C_i\)。然后将这个序列分成四段(一段可以为空),权值分别是 \(ABCA\),最后求最大总能量。
Solution
首先我们发现一个性质,就是说一段其实在最优解下的状态是相同的,否则可以把状态价值高的蔓延到低的,会更优。
如果不考虑查询,可以把每一段看做一个大小为 \(X_i\) 的点,这个插入操作在时间复杂度能接受的范围内其实是一个平衡树的操作。因为每次插入最坏情况下会分裂一个点,所以点数最多 \(2n\)。我们可以考虑是否能在维护平衡树的时候同步维护答案。
最大总能量显然是 DP,而这道题的 DP 可以写出线性 DP和区间 DP 两种,考虑如果插入一个元素,如果是线性 \(DP\) ,这个元素后面的所有都要重新算一遍,复杂度爆炸。而区间 DP 能够满足我们的要求的。
因为平衡树满足 BST 的性质,所以每个节点的子树可以看做一段区间,每次修改,可以修改的过程同时维护每个节点所在子树区间的答案即可。
状态设计
设 \(f_{i,j}\) 为一个节点所在的子树所形成的区间,状态区间是 \([i, j]\) 所搞成的最大总能量。
初始状态
考虑每个点初始的答案。
$f_{i, j} = X_i \times $ \([i, j]\) 状态中最大的单位权值。
状态转移
考虑一段区间的合并,设左边的为 \(A.f\),右边的是 \(B.f\),答案是 \(C.f\)
有 \(C.f_{i, j} = \max(A.f_{i, k} + B.f_{k, j} )\) 。
在真正实现的时候,先让 $A = $ 左儿子, $B = $当前节点,合并后再合并右儿子即可,合并顺序不影响答案。
时间复杂度
因为每次合并的时候复杂度\(O(4 ^ 3)\),所以总复杂度 \(O(64NlogN)\)
Code
实现下来用的是 Fhq-Treap
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int n, idx, rt;
LL last = 0;
struct F{
LL w[4][4];
F(){}
F (int a, int b, int c, int v) {
memset(w, 0, sizeof w);
w[0][0] = w[3][3] = (LL)a * v, w[1][1] = (LL)b * v, w[2][2] = (LL)c * v;
for (int i = 0; i < 4; i++)
for (int j = i + 1; j < 4; j++) w[i][j] = max(w[i][j - 1], w[j][j]);
}
F operator + (const F &b) const {
F c; memset(c.w, 0, sizeof c.w);
for (int i = 0; i < 4; i++)
for (int j = 0; j < 4; j++)
for (int k = i; k <= j; k++) c.w[i][j] = max(c.w[i][j], w[i][k] + b.w[k][j]);
return c;
}
} val[N << 2], sum[N << 2];
struct T{
int l, r, rnd, sz, len, a, b, c;
LL tot;
} t[N << 2];
int getNode(int a, int b, int c, int len) {
t[++idx] = (T) { 0, 0, rand(), 1, len, a, b, c, len};
val[idx] = sum[idx] = F(a, b, c, len);
return idx;
}
void pushup(int p) {
t[p].sz = t[t[p].l].sz + t[t[p].r].sz + 1;
t[p].tot = t[t[p].l].tot + t[t[p].r].tot + t[p].len;
sum[p] = val[p];
if (t[p].l) sum[p] = sum[t[p].l] + sum[p];
if (t[p].r) sum[p] = sum[p] + sum[t[p].r];
}
int merge(int A, int B) {
if (!A || !B) return A + B;
if (t[A].rnd < t[B].rnd) {
t[A].r = merge(t[A].r, B);
pushup(A);
return A;
} else {
t[B].l = merge(A, t[B].l);
pushup(B);
return B;
}
}
// 按 tot 的 size 分裂,让 x 的 tot 总和 <= k
void split1(int p, LL k, int &x, int &y) {
if (!p) { x = y = 0; return; }
if (t[t[p].l].tot + t[p].len <= k) {
x = p;
split1(t[p].r, k - (t[t[p].l].tot + t[p].len), t[p].r, y);
} else {
y = p;
split1(t[p].l, k, x, t[p].l);
}
pushup(p);
}
// 按 size 分裂,让 x 的 sz 总和 <= k
void split2(int p, int k, int &x, int &y) {
if (!p) { x = y = 0; return; }
if (t[t[p].l].sz + 1 <= k) {
x = p;
split2(t[p].r, k - (t[t[p].l].sz + 1), t[p].r, y);
} else {
y = p;
split2(t[p].l, k, x, t[p].l);
}
pushup(p);
}
int main() {
int x, y, z;
scanf("%d", &n);
while (n--) {
LL p; int a, b, c, v; scanf("%lld%d%d%d%d", &p, &a, &b, &c, &v);
split1(rt, p, x, y); split2(y, 1, y, z);
int w = getNode(a, b, c, v), l = p - t[x].tot;
if (l) t[w].l = getNode(t[y].a, t[y].b, t[y].c, l);
if (t[y].len - l) t[w].r = getNode(t[y].a, t[y].b, t[y].c, t[y].len - l);
pushup(w);
rt = merge(x, merge(w, z));
printf("%lld\n", sum[rt].w[0][3] - last);
last = sum[rt].w[0][3];
}
return 0;
}