[NOI2009] 二叉查找树

[NOI2009] 二叉查找树

今天搞了一天区间 \(DP\) , 希望能有点进步.

这个题仍然是区间 \(DP\) , 比今天做的前几个要难一些, 因为它需要有一些性质的提炼.

乍一看, 莫得思路, 但是我们可以发现几个关键字眼: 改为任何实数 , 这说明什么, 说明那个必须保证所有数不同的限制不存在, 因为我们可以使得两个数无限接近而不相等, 为了方便处理, 我们就可以当成相等来处理.

然后, 我们发现这棵树其实是一颗 \(Treap\) , 满足二叉查找树的性质, 也满足堆性质, 一棵标准的 \(Treap\) , \(Treap\) 有一个性质: 中序遍历为这个序列从小到大的排序.

由于堆性质是按照数据值来排的, 而数据值又是不变的, 那么我们的中序遍历也是不变的, 我们只能改变它的权值, 就相当于是 \(Treap\) 的旋转操作.

那这个问题就简单多了, 我们设 \(f[i][j][k]\) 表示子树 \([i, j]\) 的根节点值为 \(k\) 时最小代价. 转移我们枚举子树 \([i, j]\) 的根节点 \(t\) , 那么 \(f[i][j][k] = \min(f[i][j][k], f[i][t - 1][k] + f[t + 1][j][k] + Sum(i, j) + K)\) , 如果 \(val[t] \ge k\) , 那么我们也可以不修改, 转移: \(f[i][j][k] = \min(f[i][j][k], f[i][t - 1][v] - f[t + 1][j][v] + Sum(i, j))\) , 这里的 \(v\) 就是 \(val[t]\) .

由于值的范围比较大, 而数据的个数很少, 我们又只关心相对关系, 那么我们就可以离散化.

复杂度 \(O(n^4)\) .

\(code:\)

#include <bits/stdc++.h>
using namespace std;
int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (!isdigit(ch)) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {
        x = (x << 1) + (x << 3) + (ch ^ 48);
        ch = getchar();
    }
    return x * f;
}
const int N = 75, M = 4e5 + 5, inf = 1 << 30;
int n, K, f[N][N][N], sum[N], mp[M], cnt;
struct Node {
    int dat, val, fre;
} a[N];
bool cmp(Node x, Node y) {
    return x.dat < y.dat;
}
int Sum(int l, int r) {
    return sum[r] - sum[l - 1];
}
int main() {
    n = read(), K = read();
    for (int i = 1; i <= n; i++) a[i].dat = read();
    for (int i = 1; i <= n; i++) {
        a[i].val = read(); mp[a[i].val] = 1;
    }
    for (int i = 1; i <= n; i++) a[i].fre = read();
    for (int i = 1; i < M; i++) if (mp[i]) mp[i] = ++cnt;
    for (int i = 1; i <= n; i++) a[i].val = mp[a[i].val];
    sort(a + 1, a + n + 1, cmp);
    for (int i = 1; i <= n; i++) sum[i] = sum[i - 1] + a[i].fre;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            f[i][i][j] = a[i].fre;
            if (a[i].val < j) f[i][i][j] += K;
        }
    }
    for (int l = 2; l <= n; l++) {
        for (int i = 1; i + l - 1 <= n; i++) {
            int j = i + l - 1;
            for (int k = 1; k <= n; k++) {
                f[i][j][k] = inf;
                for (int t = i; t <= j; t++) {
                    f[i][j][k] = min(f[i][j][k], f[i][t - 1][k] + f[t + 1][j][k] + Sum(i, j) + K);
                    int v = a[t].val;
                    if (v >= k) f[i][j][k] = min(f[i][j][k], f[i][t - 1][v] + f[t + 1][j][v] + Sum(i, j));
                }
            }
        }
    }
    printf("%d", f[1][n][1]);
    return 0;
}
posted @ 2021-08-23 21:39  sshadows  阅读(32)  评论(0编辑  收藏  举报