【解题报告】「NOI 1995」石子合并
一个圆形操场上摆放 \(n\ (1\le n\le 100)\) 堆石子。现要将 \(n\) 堆石子合并成一堆,每次只能选相邻的 \(2\) 堆合并成新的一堆,代价为新的一堆的石子数。求将 \(n\) 堆石子合并成一堆的最小代价和、最大代价和。
“最大代价和”的转移证明
\(\color{blue}{\text{结论}}\):对于 \(f_{i,j}\),最优转移一定从 \(f_{i,j-1}\) 或 \(f_{i+1,j}\) 来。下给出证明。
考虑反证,设最优转移从 \(k\ (i<k<j-1)\) 来,记 \(S_1=\sum\limits_{t=i}^{k} a_t\),\(S_2=\sum\limits_{t=k+1}^{j} a_t\)。
我们可以看成 \([i,k]\) 和 \([k+1,j]\) 先分别内部合并,再两者合并。若 \(S_1\ge S_2\),考虑以下情形:先把 \([i,k]\) 内部合并,然后将这一堆与 \(k+1\) 合并,之后进行 \([k+1,j]\) 的内部合并。这样做的好处是,每次有关 \(k+1\) 的合并(除了原本的 \([i,k]\) 和 \([k+1,j]\) 这一次大合并),都能额外得到 \([i,k]\) 堆的贡献,贡献增量 \(\Delta\ge 2S_1-(S_1+S_2)\ge 0\)。\(S_1<S_2\) 对称同理。\(\square\)
四边形不等式优化
若 \(f\) 满足对于 \(\forall a\le b\le c\le d\),有 \(f_{a,c}+f_{b,d}\le f_{b,c}+f_{a,d}\),则称其满足四边形不等式,简记为\(\color{orange}{交叉 \le 包含法则}\)。
考虑 \(f_{i,j}=\min\limits_{i\le k<j}\left(f_{i,k}+f_{k+1,j}+w_{i,j}\right)\) 型区间 dp,先说\(\color{orange}{适用四边形不等式优化的条件}\):\(\color{red}{w 满足四边形不等式}\)。
下给出证明。先证 \(\color{violet}{f\ 满足四边形不等式}\),只需要证 \(f_{i,j}+f_{i+1,j+1}\le f_{i+1,j}+f_{i,j+1}\)(因为这说明 \(f\) 满足凸性)。
按照区间长度从小到大归纳证明。设 \(f_{i+1,j}\) 的决策点为 \(x\),\(f_{i,j+1}\) 的决策点为 \(y\)。若 \(x\le y\),则:
若 \(x>y\),在 \(f_{i,j}\) 处取决策 \(y\),\(f_{i+1,j+1}\) 处取决策 \(x\) 放缩即可。因此命题成立。
\(\color{orange}{记\ f_{i,j}\ 的决策点为\ s_{i,j}。}\)
再证 \(\color{violet}{f\ 满足决策单调性}\),即证 \(s_{i,j-1}\le s_{i,j}\le s_{i+1,j}\)。对于前者,记 \(x=s_{i,j}\),\(y=s_{i,j-1}\),考虑反证,若 \(x<y\),则:
注意到左式 \(\ge 0\),右式 \(\le 0\),因此 \(x<y\) 不优,故 \(s_{i,j-1}\le s_{i,j}\)。后者证明类似。\(\square\)
回到区间 dp,我们在计算 \(f_{i,j}\) 时,只需在 \(k\in [s_{i,j-1},s_{i+1,j}]\) 中遍历即可。
考虑 \(\color{blue}{f\ 每个主对角方向的斜线}\),遍历次数为 \(\sum\limits_{i}(s_i-s_{i-1})=\mathcal O(n)\),因此总时间复杂度为 \(\mathcal O(n^2)\)。
Garsia - Wachs 算法(针对序列的情形)
算法的流程:不妨 \(a_0=a_{n+1}=+\infty\),找到第一个 \(k\) 满足 \(a_{k-1}\le a_{k+1}\),合并 \(a_{k-1}\) 和 \(a_{k}\),并从 \(k\) 往前找到第一个 \(j\) 满足 \(a_j>a_{k-1}+a_k\),将 \((a_{k-1}+a_k)\) 插到 \(j\) 后面。反复执行此流程,直至 \(|a|=1\),代价和为 \(\sum (a_{k-1}+a_k)\)。
算法的证明鸽了,论文找不到。奇偶位分别开一个平衡树维护,\(a_i\) 增量处理,时间复杂度 \(\mathcal O(n\log n)\)。
\(\color{cyan}{坑}\) 记 \(x=a_{k-1}+a_k\),它被插到了 \(a_j\) 后面成为 \(a_{j+1}\),分析可能会带来的新的影响(容易遗漏!):
- 可能 \(a_{j+1}=a_{j+3}\),要将 \((j+1,j+2)\) 合并;
- 可能 \(a_{j-1}\le a_{j+1}\),要将 \((j-1,j)\) 合并,\(\color{red}{但这还可能导致\ a_{j-3}\le a_j}\),因此我们需 while 循环直到不能合并为止}。
AC 代码(\(\mathcal O(n\log n)\))
// #include <bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cassert>
using namespace std;
const int inf = 2e9 + 5;
const int N = 500005;
// mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// mt19937 rng(2020);
// 两棵 “递减链状” Treap (奇偶位)
int rnd[N], sz[N], val[N], maxx[N], ls[N], rs[N], tot;
inline int newnode(int x) {
++tot, rnd[tot] = rand(), sz[tot] = 1, val[tot] = maxx[tot] = x, ls[tot] = rs[tot] = 0;
return tot;
}
inline void pushup(int x) {
sz[x] = sz[ls[x]] + sz[rs[x]] + 1, maxx[x] = max(val[x], max(maxx[ls[x]], maxx[rs[x]])); // !
}
void split_sz(int u, int rk, int &x, int &y) {
if (!u) {
x = y = 0;
return ;
}
if (sz[ls[u]] >= rk) {
split_sz(ls[u], rk, x, ls[u]);
y = u;
pushup(u);
} else {
split_sz(rs[u], rk - sz[ls[u]] - 1, rs[u], y);
x = u;
pushup(u);
}
}
void split_val(int u, int w, int &x, int &y) {
if (!u) {
x = y = 0;
return ;
}
assert(maxx[u] > w);
if (maxx[rs[u]] > w) {
// printf(">>> 1\n");
split_val(rs[u], w, rs[u], y);
x = u;
pushup(u);
} else if (val[u] > w) {
// printf(">>> 2\n");
x = u, y = rs[u], rs[u] = 0;
pushup(u);
} else {
assert(maxx[ls[u]] > w);
split_val(ls[u], w, x, ls[u]);
y = u;
pushup(u);
}
}
int merge(int x, int y) {
if (!x || !y) return x + y;
if (rnd[x] < rnd[y]) {
rs[x] = merge(rs[x], y);
pushup(x);
return x;
} else {
ls[y] = merge(x, ls[y]);
pushup(y);
return y;
}
}
void dfs(int u) {
if (!u) return ;
dfs(ls[u]);
// printf("%d(%d, %d, ls = %d, rs = %d) ", val[u], u, maxx[u], ls[u], rs[u]);
// cerr << val[u] << ' ';
if(val[u]!=inf)printf("%d ", val[u]);
dfs(rs[u]);
}
int a[N], n, cur, root[2], ans;
// (可行, 删除数)
pair<bool, int> check(int pos) {
if (pos <= 3 || pos + 1 > cur) return {false, 0};
// 找出 a[pos - 1] 和 a[pos + 1]
int type = !(pos & 1), k1, k2, k3, k4;
split_sz(root[type], pos / 2 - 1, k1, k2);
split_sz(k2, 2, k2, k4), split_sz(k2, 1, k2, k3);
if (val[k2] > val[k3]) {
root[type] = merge(merge(k1, k2), merge(k3, k4));
return {false, 0};
}
// cerr << "check pos = " << pos << '\n';
assert(sz[k2] == 1 && sz[k3] == 1);
int hp = val[k2];
root[type] = merge(k1, merge(k3, k4));
// 找出 a[pos]
k1 = k2 = k3 = k4 = 0;
type ^= 1;
split_sz(root[type], (pos + 1) / 2 - 1, k1, k2);
split_sz(k2, 1, k2, k3);
hp += val[k2];
root[type] = merge(k1, k3);
cur -= 2;
ans += hp;
// cerr << "hp = " << hp << '\n';
// cerr << "cur = " << cur << '\n';
// cerr << "odd : "; dfs(root[1]); cerr << '\n';
// cerr << "even: "; dfs(root[0]); cerr << '\n';
// 剔除 a[cur]
int tmp1, tmp2, newpos;
split_sz(root[cur & 1], (cur + 1) / 2 - 1, tmp1, tmp2);
assert(sz[tmp2] == 1);
root[cur & 1] = tmp1;
// cerr << "odd : "; dfs(root[1]); cerr << '\n';
// cerr << "even: "; dfs(root[0]); cerr << '\n';
k1 = k2 = k3 = k4 = 0;
split_val(root[0], hp, k1, k2);
split_val(root[1], hp, k3, k4);
assert(maxx[k2] <= hp && maxx[k4] <= hp);
// printf("sz1 = %d, sz2 = %d, (%d, %d), hp = %d\n", sz[k1], sz[k3], sz[k2], sz[k4], hp);
if (2 * sz[k1] > 2 * sz[k3] - 1) { // 插在 k1 后面
// cerr << "case 1\n";
// cerr << sz[k1] << '\n';
newpos = 2 * sz[k1] + 1;
root[1] = merge(k3, k4);
k3 = k4 = 0, split_sz(root[1], sz[k1], k3, k4);
root[1] = merge(merge(k3, newnode(hp)), k2);
root[0] = merge(k1, k4);
} else { // 插在 k3 后面
newpos = 2 * sz[k3];
root[0] = merge(k1, k2);
k1 = k2 = 0, split_sz(root[0], sz[k3] - 1, k1, k2);
root[0] = merge(merge(k1, newnode(hp)), k4);
root[1] = merge(k3, k2);
}
// 把 a[cur] 加回去
cur++;
root[cur & 1] = merge(root[cur & 1], tmp2);
// cerr << "cur = " << cur << '\n';
// cerr << "odd : "; dfs(root[1]); cerr << '\n';
// cerr << "even: "; dfs(root[0]); cerr << '\n';
int del = 1;
while (1) {
pair<bool, int> pref = check(newpos - 1);
if (!pref.first) break;
del += pref.second, newpos -= pref.second;
}
del += check(newpos + 1).second;
return {true, del};
}
void sc() {
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
// for (int i = 1; i <= n; i++) a[i] = rng() % 5 + 1;
// for (int i = 1; i <= n; i++) printf("%d ", a[i]); puts("");
a[n + 1] = inf;
maxx[0] = 0, tot = root[0] = root[1] = ans = 0;
root[0] = newnode(inf), root[1] = newnode(inf), cur = 2;
for (int i = 1; i <= n + 1; i++) {
cur++;
root[cur & 1] = merge(root[cur & 1], newnode(a[i]));
// printf("now i = %d, cur = %d\n", i, cur);
while (check(cur - 1).first) ; // cur 处可能会破坏 “单调递减” 性质, 在 pushup 处特殊处理
// for (int j = cur - 1; j >= 1; j--) {
// assert(!check(cur - 1));
// }
// printf("i = %d, cur = %d\n", i, cur);
// if(1){dfs(root[1]),puts(""),dfs(root[0]),puts("");}
}
printf("%d\n", ans);
}
int main() {
// freopen("../data.in", "r", stdin);
// freopen("../my.out", "w", stdout);
while (~scanf("%d", &n) && n) sc();
return 0;
}
/*
6
4 2 4 2 5 5
5
3 4 1 2 3
25
3 5 2 1 5 2 4 1 3 1 4 1 1 3 3 5 1 5 1 5 4 4 1 1 1
15
403 500 56 206 54 352 3 232 485 292 473 190 126 465 210
*/