【解题报告】「NOI 1995」石子合并

一个圆形操场上摆放 \(n\ (1\le n\le 100)\) 堆石子。现要将 \(n\) 堆石子合并成一堆,每次只能选相邻的 \(2\) 堆合并成新的一堆,代价为新的一堆的石子数。求将 \(n\) 堆石子合并成一堆的最小代价和、最大代价和。

“最大代价和”的转移证明

\(\color{blue}{\text{结论}}\):对于 \(f_{i,j}\),最优转移一定从 \(f_{i,j-1}\)\(f_{i+1,j}\) 来。下给出证明。

考虑反证,设最优转移从 \(k\ (i<k<j-1)\) 来,记 \(S_1=\sum\limits_{t=i}^{k} a_t\)\(S_2=\sum\limits_{t=k+1}^{j} a_t\)

我们可以看成 \([i,k]\)\([k+1,j]\)分别内部合并,再两者合并。若 \(S_1\ge S_2\),考虑以下情形:先把 \([i,k]\) 内部合并,然后将这一堆与 \(k+1\) 合并,之后进行 \([k+1,j]\) 的内部合并。这样做的好处是,每次有关 \(k+1\) 的合并(除了原本的 \([i,k]\)\([k+1,j]\) 这一次大合并),都能额外得到 \([i,k]\) 堆的贡献,贡献增量 \(\Delta\ge 2S_1-(S_1+S_2)\ge 0\)\(S_1<S_2\) 对称同理。\(\square\)

四边形不等式优化

\(f\) 满足对于 \(\forall a\le b\le c\le d\),有 \(f_{a,c}+f_{b,d}\le f_{b,c}+f_{a,d}\),则称其满足四边形不等式,简记为\(\color{orange}{交叉 \le 包含法则}\)

考虑 \(f_{i,j}=\min\limits_{i\le k<j}\left(f_{i,k}+f_{k+1,j}+w_{i,j}\right)\) 型区间 dp,先说\(\color{orange}{适用四边形不等式优化的条件}\)\(\color{red}{w 满足四边形不等式}\)

下给出证明。先证 \(\color{violet}{f\ 满足四边形不等式}\),只需要证 \(f_{i,j}+f_{i+1,j+1}\le f_{i+1,j}+f_{i,j+1}\)(因为这说明 \(f\) 满足凸性)。

按照区间长度从小到大归纳证明。设 \(f_{i+1,j}\) 的决策点为 \(x\)\(f_{i,j+1}\) 的决策点为 \(y\)。若 \(x\le y\),则:

\[\begin{aligned} f_{i,j}+f_{i+1,j+1}&\le \left(\color{olive}{f_{i,x}}+f_{x+1,j}+\color{cyan}{w_{i,j}}\right)+\left(\color{olive}{f_{i+1,y}}+f_{y+1,j+1}+\color{cyan}{w_{i+1,j+1}}\right)\\ &\le \color{olive}{\left(f_{i,y}+f_{i+1,x}\right)}+\left(f_{x+1,j}+f_{y+1,j+1}\right)+\color{cyan}{\left(w_{i,j+1}+w_{i+1,j}\right)}\\ &=f_{i+1,j}+f_{i,j+1} \end{aligned} \]

\(x>y\),在 \(f_{i,j}\) 处取决策 \(y\)\(f_{i+1,j+1}\) 处取决策 \(x\) 放缩即可。因此命题成立。

\(\color{orange}{记\ f_{i,j}\ 的决策点为\ s_{i,j}。}\)

再证 \(\color{violet}{f\ 满足决策单调性}\),即证 \(s_{i,j-1}\le s_{i,j}\le s_{i+1,j}\)。对于前者,记 \(x=s_{i,j}\)\(y=s_{i,j-1}\),考虑反证,若 \(x<y\),则:

\[\begin{aligned} && f_{x+1,j-1}+f_{y+1,j}&\le f_{y+1,j-1}+f_{x+1,j}\\ &\iff& \left(\color{teal}{f_{i,x}}+f_{x+1,j-1}+\color{teal}{w_{i,j-1}}\right)+\left(\color{teal}{f_{i,y}}+f_{y+1,j}+\color{teal}{w_{i,j}}\right)&\le \left(\color{teal}{f_{i,y}}+f_{y+1,j-1}+\color{teal}{w_{i,j-1}}\right)+\left(\color{teal}{f_{i,x}}+f_{x+1,j}+\color{teal}{w_{i,j}}\right)\\ &\iff& (f_{i,j-1},x)-(f_{i,j-1},y)&\le (f_{i,j},x)-(f_{i,j},y) \end{aligned} \]

注意到左式 \(\ge 0\),右式 \(\le 0\),因此 \(x<y\) 不优,故 \(s_{i,j-1}\le s_{i,j}\)。后者证明类似。\(\square\)

回到区间 dp,我们在计算 \(f_{i,j}\) 时,只需在 \(k\in [s_{i,j-1},s_{i+1,j}]\) 中遍历即可。

考虑 \(\color{blue}{f\ 每个主对角方向的斜线}\),遍历次数为 \(\sum\limits_{i}(s_i-s_{i-1})=\mathcal O(n)\),因此总时间复杂度为 \(\mathcal O(n^2)\)

Garsia - Wachs 算法(针对序列的情形)

算法的流程:不妨 \(a_0=a_{n+1}=+\infty\),找到第一个 \(k\) 满足 \(a_{k-1}\le a_{k+1}\),合并 \(a_{k-1}\)\(a_{k}\),并从 \(k\) 往前找到第一个 \(j\) 满足 \(a_j>a_{k-1}+a_k\),将 \((a_{k-1}+a_k)\) 插到 \(j\) 后面。反复执行此流程,直至 \(|a|=1\),代价和为 \(\sum (a_{k-1}+a_k)\)

算法的证明鸽了,论文找不到。奇偶位分别开一个平衡树维护,\(a_i\) 增量处理,时间复杂度 \(\mathcal O(n\log n)\)

\(\color{cyan}{坑}\)\(x=a_{k-1}+a_k\),它被插到了 \(a_j\) 后面成为 \(a_{j+1}\),分析可能会带来的新的影响(容易遗漏!):

  • 可能 \(a_{j+1}=a_{j+3}\),要将 \((j+1,j+2)\) 合并;
  • 可能 \(a_{j-1}\le a_{j+1}\),要将 \((j-1,j)\) 合并,\(\color{red}{但这还可能导致\ a_{j-3}\le a_j}\),因此我们需 while 循环直到不能合并为止}。

AC 代码(\(\mathcal O(n\log n)\)

// #include <bits/stdc++.h>
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <cassert>
using namespace std;

const int inf = 2e9 + 5;
const int N = 500005;

// mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// mt19937 rng(2020);
// 两棵 “递减链状” Treap (奇偶位)
int rnd[N], sz[N], val[N], maxx[N], ls[N], rs[N], tot;
inline int newnode(int x) {
    ++tot, rnd[tot] = rand(), sz[tot] = 1, val[tot] = maxx[tot] = x, ls[tot] = rs[tot] = 0;
    return tot;
}
inline void pushup(int x) {
    sz[x] = sz[ls[x]] + sz[rs[x]] + 1, maxx[x] = max(val[x], max(maxx[ls[x]], maxx[rs[x]])); // !
}
void split_sz(int u, int rk, int &x, int &y) {
    if (!u) {
        x = y = 0;
        return ;
    }
    if (sz[ls[u]] >= rk) {
        split_sz(ls[u], rk, x, ls[u]);
        y = u;
        pushup(u);
    } else {
        split_sz(rs[u], rk - sz[ls[u]] - 1, rs[u], y);
        x = u;
        pushup(u);
    }
}
void split_val(int u, int w, int &x, int &y) {
    if (!u) {
        x = y = 0;
        return ;
    }
    assert(maxx[u] > w);
    if (maxx[rs[u]] > w) {
        // printf(">>> 1\n");
        split_val(rs[u], w, rs[u], y);
        x = u;
        pushup(u);
    } else if (val[u] > w) {
        // printf(">>> 2\n");
        x = u, y = rs[u], rs[u] = 0;
        pushup(u);
    } else {
        assert(maxx[ls[u]] > w);
        split_val(ls[u], w, x, ls[u]);
        y = u;
        pushup(u);
    }
}
int merge(int x, int y) {
    if (!x || !y) return x + y;
    if (rnd[x] < rnd[y]) {
        rs[x] = merge(rs[x], y);
        pushup(x);
        return x;
    } else {
        ls[y] = merge(x, ls[y]);
        pushup(y);
        return y;
    }
}

void dfs(int u) {
    if (!u) return ;
    dfs(ls[u]);
    // printf("%d(%d, %d, ls = %d, rs = %d) ", val[u], u, maxx[u], ls[u], rs[u]);
    // cerr << val[u] << ' ';
    if(val[u]!=inf)printf("%d ", val[u]);
    dfs(rs[u]);    
}

int a[N], n, cur, root[2], ans;

// (可行, 删除数)
pair<bool, int> check(int pos) {
    if (pos <= 3 || pos + 1 > cur) return {false, 0};
    // 找出 a[pos - 1] 和 a[pos + 1]
    int type = !(pos & 1), k1, k2, k3, k4;
    split_sz(root[type], pos / 2 - 1, k1, k2);
    split_sz(k2, 2, k2, k4), split_sz(k2, 1, k2, k3);
    if (val[k2] > val[k3]) {
        root[type] = merge(merge(k1, k2), merge(k3, k4));
        return {false, 0};
    }
    // cerr << "check pos = " << pos << '\n';
    assert(sz[k2] == 1 && sz[k3] == 1);
    int hp = val[k2];
    root[type] = merge(k1, merge(k3, k4));
    
    // 找出 a[pos]
    k1 = k2 = k3 = k4 = 0;
    type ^= 1;
    split_sz(root[type], (pos + 1) / 2 - 1, k1, k2);
    split_sz(k2, 1, k2, k3);
    hp += val[k2];
    root[type] = merge(k1, k3);
    cur -= 2;
    
    ans += hp;
    // cerr << "hp = " << hp << '\n';
    // cerr << "cur = " << cur << '\n';
    // cerr << "odd : "; dfs(root[1]); cerr << '\n';
    // cerr << "even: "; dfs(root[0]); cerr << '\n';
    
    // 剔除 a[cur]
    int tmp1, tmp2, newpos;
    split_sz(root[cur & 1], (cur + 1) / 2 - 1, tmp1, tmp2);
    assert(sz[tmp2] == 1);
    root[cur & 1] = tmp1;
    
    // cerr << "odd : "; dfs(root[1]); cerr << '\n';
    // cerr << "even: "; dfs(root[0]); cerr << '\n';

    k1 = k2 = k3 = k4 = 0;
    split_val(root[0], hp, k1, k2);
    split_val(root[1], hp, k3, k4);
    assert(maxx[k2] <= hp && maxx[k4] <= hp);
    // printf("sz1 = %d, sz2 = %d, (%d, %d), hp = %d\n", sz[k1], sz[k3], sz[k2], sz[k4], hp);
    if (2 * sz[k1] > 2 * sz[k3] - 1) { // 插在 k1 后面
        // cerr << "case 1\n";
        // cerr << sz[k1] << '\n';
        newpos = 2 * sz[k1] + 1;
        root[1] = merge(k3, k4);
        k3 = k4 = 0, split_sz(root[1], sz[k1], k3, k4);
        root[1] = merge(merge(k3, newnode(hp)), k2);
        root[0] = merge(k1, k4);
    } else { // 插在 k3 后面
        newpos = 2 * sz[k3];
        root[0] = merge(k1, k2);
        k1 = k2 = 0, split_sz(root[0], sz[k3] - 1, k1, k2);
        root[0] = merge(merge(k1, newnode(hp)), k4);
        root[1] = merge(k3, k2);
    }

    // 把 a[cur] 加回去
    cur++;
    root[cur & 1] = merge(root[cur & 1], tmp2);

    // cerr << "cur = " << cur << '\n';
    // cerr << "odd : "; dfs(root[1]); cerr << '\n';
    // cerr << "even: "; dfs(root[0]); cerr << '\n';

    int del = 1;
    while (1) {
        pair<bool, int> pref = check(newpos - 1);
        if (!pref.first) break;
        del += pref.second, newpos -= pref.second; 
    }
    del += check(newpos + 1).second;
    return {true, del};
}

void sc() {
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    // for (int i = 1; i <= n; i++) a[i] = rng() % 5 + 1;
    // for (int i = 1; i <= n; i++) printf("%d ", a[i]); puts("");

    a[n + 1] = inf;

    maxx[0] = 0, tot = root[0] = root[1] = ans = 0;
    root[0] = newnode(inf), root[1] = newnode(inf), cur = 2;
    for (int i = 1; i <= n + 1; i++) {
        cur++;
        root[cur & 1] = merge(root[cur & 1], newnode(a[i]));
        // printf("now i = %d, cur = %d\n", i, cur);
        while (check(cur - 1).first) ; // cur 处可能会破坏 “单调递减” 性质, 在 pushup 处特殊处理
        // for (int j = cur - 1; j >= 1; j--) {
        //     assert(!check(cur - 1));
        // }
        // printf("i = %d, cur = %d\n", i, cur);
        // if(1){dfs(root[1]),puts(""),dfs(root[0]),puts("");}
    }
    printf("%d\n", ans);
}

int main() {
    // freopen("../data.in", "r", stdin);
    // freopen("../my.out", "w", stdout);
    while (~scanf("%d", &n) && n) sc();
    return 0;
}

/*
6
4 2 4 2 5 5
5
3 4 1 2 3
25
3 5 2 1 5 2 4 1 3 1 4 1 1 3 3 5 1 5 1 5 4 4 1 1 1 
15
403 500 56 206 54 352 3 232 485 292 473 190 126 465 210 
*/
posted @ 2024-06-24 13:53  wlzhouzhuan  阅读(149)  评论(0编辑  收藏  举报