dp-最优二叉搜索树

最优二叉搜索树

问题描述

最优二叉搜索树(Optimal Binary Search Tree,Optimal BST)问题,形式化定义:给定一个n个不同关键字的已排序的序列K=<k1, k2, ..., kn>(k1<k2<...<kn),用这些关键字构造一棵二叉搜索树 —— 对每个关键字ki,都有一个概率pi表示其搜索频率。对于不在K中的搜索值构造n+1个”伪关键字“d0, d1, d2, ..., dn —— 伪关键字di表示所有在ki和ki+1之间的值(i=1,2,...,n-1,d0表示所有小于k1的值,dn表示所有大于kn的值),每个伪关键字di对应一个概率qi(搜索频率)。

例如,对一个n=5的关键字集合及如下的搜索频率,构造二叉搜索树。

i 0 1 2 3 4 5
pi 0 0.15 0.1 0.05 0.1 0.2
qi 0.05 0.1 0.05 0.05 0.05 0.1

如下为两种可能的结构,左侧的期望代价为2.8,右侧为最优2.75

另外,上述例子有两种最优,期望代价均为2.75,另一种是:(并且它的深度更小)

                k4
        k2               k5
    k1         k3     d4    d5
d0    d1    d2   d3

问题分析

分析这个问题,要从观察子树特征开始。一个二叉搜索树的任意子树,必须包含连续关键字 ki...kj 和伪关键字 di-1...dj.
而每个子树都有且只有一个根节点,根节点k将子树分为两部分: <i...k-1> 和 <k+1...j>,分别为左右子树。
所以问题转化为了递归地求解连续节点的根节点问题,假设函数f为期望代价:

f(i...j) = min(
    f(i-1)+f(i+1...n)+c[i...j],         // i为根节点
    f(i)+f(i+2...n)+c[i...j],           // i+1为根节点
    f(i..i+1)+f(i+3...n)+c[i...j],      // i+2为根节点
    ...
    f(i..k-1)+f(k+1...n)+c[i...j],      // k为根节点
    ...
    f(i...j-1)+f(j+1)+c[i...j]          // j为根节点
)

对某种已知结构的二叉树,每个节点的搜索代价为其搜索概率乘以深度+1(根节点深度为0,所以加1),然后将树中所有节点代价相加,得到树的期望搜索代价。
当将两个子树放在其同一个根节点上组成一个新的树时,所有节点深度都增加1,所以上述递推关系中,每种情况都增加了一个额外的代价c[i...j],即ki到kj所有节点的搜索概率之和:
c[i...j] = sum(ki...kj) + sum(di-1...dj)

思路

e[i,j] 为<i...j>构成的子树最小期望代价
w[i,j] 为<i...j>构成的子树所有节点概率相加,相当于上面的c[i...j]
root[i,j] 为<i...j>构成的子树最小期望对应的根节点id

为了处理边界问题,e和w长度都增加1,即第一维长度为n+1

OPTIMAL-BST(p, q, n)
    let e[1 : n + 1, 0 : n], w[1 : n + 1, 0 : n], and root[1 : n, 1 : n] be new tables
    for i = 1 to n + 1
        e[i, i - 1] = q_{i-1}
        w[i, i - 1] = q_{i-1}
    for l = 1 to n
        for i = 1 to n - l + 1
            j = i + l - 1
            e[i, j] = ∞
            w[i, j] = w[i, j - 1] + p_j + q_j
            for r = i to j
                t = e[i, r - 1] + e[r + 1, j] + w[i, j]
                if t < e[i, j]
                    e[i, j] = t
                    root[i, j] = r
    return e and root

程序

// optimal binary search tree
// 最优二叉搜索树
#include <iostream>
#include <float.h>
#include <vector>
#include <climits>
#include <time.h>

#define N 7

using namespace std;

void solve();
template<typename T>
void print_mat(T** mat, int len1, int len2);

int main(int argc, char** argv) {
    solve();


    return 0;
}

void solve() {
    //int keys[N+1]{0,1,2,3,4,5};
    // 下标0的元素为0,是为了和weights2对齐,无意义
    //double weights1[N+1]{0,0.15,0.1,0.05,0.1,0.2};
    //double weights2[N+1]{0.05,0.1,0.05,0.05,0.05,0.1};

    int keys[N+1]{0,1,2,3,4,5,6,7};
    double weights1[N+1]{0,0.04,0.06,0.08,0.02,0.1,0.12,0.14};
    double weights2[N+1]{0.06,0.06,0.06,0.06,0.05,0.05,0.05,0.05};

    // 存放权重(期望),只有右上和对角下一行有值
    //   对角下一行wei[i][i-1]表示weights2[i-1]
    double** wei = new double*[N+1];
    // wei1[i][j] : sum(weights1[i...j]) + sum(weights2[i-1,...,j])
    double** wei1 = new double*[N+1];
    // 右上角和对角存放根节点id,左下代表深度
    int** rid = new int*[N+1];
    for (int i = 0; i < N+1; ++i) {
        wei[i] = new double[N+1]{0};
        wei1[i] = new double[N+1]{0};
        rid[i] = new int[N+1]{0};
        if (i > 0) {
            wei[i][i-1] = weights2[i-1];
            wei[i][i] = weights1[i] + 2*(weights2[i-1] + weights2[i]);
            rid[i][i] = i;
            wei1[i][i-1] = weights2[i-1];
            wei1[i][i] = weights1[i] + weights2[i-1] + weights2[i];
        }
        // 相邻两个节点组成的树深度为2(包括伪关键字,根节点深度为0)
        //   剩余位置用一个比较大的值占位
        for (int j = 0; j < i; ++j) {
            rid[i][j] = i==j+1 ? 2 : N*2;
        }
    }

    for (int diff = 1; diff < N; ++diff) {
        for (int i = 1; i < N+1-diff; ++i) {
            int j = i + diff;
            wei1[i][j] = wei1[i][j-1] + weights1[j] + weights2[j];
            wei[i][j] = DBL_MAX;
            for (int k = i; k <= j; ++k) {      // 旧方案
            // 习题15.5-4, rid[i][j-1] <= rid[i][j] <= rid[i+1][j]
            //   上一层循环的时间复杂度变为 O(n),总时间复杂度变为 O(n^2)
            //   这样虽然时间复杂度降低了,但深度的计算变得不准确;旧方案的深度是准确的
            // for (int k = rid[i][j-1]; k <= rid[i+1][j]; ++k) {
                double w = wei1[i][j] + wei[i][k-1] + (k<N ? wei[k+1][j] : weights2[j]);
                int deep = rid[j][i];
                if (j == i+1) {
                    deep = 2;
                } else {
                    if (k <= i+1) {
                        deep = rid[k+1][i] + 1;
                    } else if (j <= k+1) {
                        deep = rid[j][k-1] + 1;
                    } else {
                        deep = max(rid[k-1][i], rid[j][k+1]) + 1;
                    }
                }
                if (w <= wei[i][j]) {
                    wei[i][j] = w;
                    rid[i][j] = k;
                    if (deep < rid[j][i]) {
                        rid[j][i] = deep;
                        rid[i][j] = k;
                    }
                }
                cout << i << "," << j << "," << k << "\tw: " << w << "\td: " << deep << endl;
            }
        }
    }
    cout << "wei:\n";
    print_mat(wei, N+1, N+1);
    cout << "wei1:\n";
    print_mat(wei1, N+1, N+1);
    cout << "rid:\n";
    print_mat(rid, N+1, N+1);

    // print structure
    cout << "structure:\n";
    vector<pair<int, int>> stru{{1, N}};
    pair<int, int> p;
    int id, count = 1;
    while (!stru.empty()) {
        p = stru[0];
        id = rid[p.first][p.second];
        cout << keys[id];
        if (p.first < id) {
            stru.push_back({p.first, id-1});
        } else {
            cout << "(L)";      // 左子树为空
        }
        if (id < p.second) {
            stru.push_back({id+1, p.second});
        } else {
            cout << "(R)";      // 右子树为空
        }
        stru.erase(stru.begin());
        cout << "\t";
        if (--count == 0) {
            cout << endl;
            count = stru.size();
        }
    }

    for (int i = 0; i < N+1; ++i) {
        delete[] wei[i];
        delete[] wei1[i];
        delete[] rid[i];
    }
    delete[] wei;
    delete[] wei1;
    delete[] rid;
}

template<typename T>
void print_mat(T** mat, int len1, int len2) {
    for (int i = 0; i < len1; ++i) {
        for (int j = 0; j < len2; ++j) {
            cout << mat[i][j] << "\t";
        }
        cout << endl;
    }
}

测试:

$ g++ -o bst bst.cpp && ./bst
1,2,1	w: 0.64	d: 2
1,2,2	w: 0.62	d: 2
2,3,2	w: 0.7	d: 2
2,3,3	w: 0.68	d: 2
3,4,3	w: 0.57	d: 2
3,4,4	w: 0.64	d: 2
4,5,4	w: 0.64	d: 2
4,5,5	w: 0.57	d: 2
5,6,5	w: 0.74	d: 2
5,6,6	w: 0.72	d: 2
6,7,6	w: 0.8	d: 2
6,7,7	w: 0.78	d: 2
1,3,1	w: 1.16	d: 3
1,3,2	w: 1.02	d: 4
1,3,3	w: 1.1	d: 3
2,4,2	w: 1.02	d: 3
2,4,3	w: 0.93	d: 4
2,4,4	w: 1.12	d: 3
3,5,3	w: 1.05	d: 3
3,5,4	w: 1.04	d: 4
3,5,5	w: 1.04	d: 3
4,6,4	w: 1.23	d: 3
4,6,5	w: 1.01	d: 4
4,6,6	w: 1.07	d: 3
5,7,5	w: 1.39	d: 3
5,7,6	w: 1.2	d: 4
5,7,7	w: 1.33	d: 3
1,4,1	w: 1.48	d: 3
1,4,2	w: 1.34	d: 4
1,4,3	w: 1.35	d: 4
1,4,4	w: 1.56	d: 3
2,5,2	w: 1.64	d: 3
2,5,3	w: 1.41	d: 4
2,5,4	w: 1.52	d: 4
2,5,5	w: 1.52	d: 3
3,6,3	w: 1.66	d: 3
3,6,4	w: 1.63	d: 4
3,6,5	w: 1.48	d: 4
3,6,6	w: 1.68	d: 3
4,7,4	w: 1.9	d: 3
4,7,5	w: 1.66	d: 4
4,7,6	w: 1.55	d: 4
4,7,7	w: 1.7	d: 3
1,5,1	w: 2.11	d: 3
1,5,2	w: 1.96	d: 4
1,5,3	w: 1.83	d: 3
1,5,4	w: 1.96	d: 4
1,5,5	w: 2.03	d: 3
2,6,2	w: 2.25	d: 3
2,6,3	w: 2.02	d: 4
2,6,4	w: 2.11	d: 3
2,6,5	w: 1.96	d: 4
2,6,6	w: 2.17	d: 3
3,7,3	w: 2.39	d: 3
3,7,4	w: 2.3	d: 4
3,7,5	w: 2.13	d: 3
3,7,6	w: 2.16	d: 4
3,7,7	w: 2.31	d: 3
1,6,1	w: 2.83	d: 3
1,6,2	w: 2.57	d: 4
1,6,3	w: 2.44	d: 4
1,6,4	w: 2.55	d: 4
1,6,5	w: 2.47	d: 4
1,6,6	w: 2.69	d: 3
2,7,2	w: 3.09	d: 3
2,7,3	w: 2.75	d: 4
2,7,4	w: 2.78	d: 4
2,7,5	w: 2.61	d: 4
2,7,6	w: 2.65	d: 4
2,7,7	w: 2.91	d: 3
1,7,1	w: 3.67	d: 3
1,7,2	w: 3.41	d: 4
1,7,3	w: 3.17	d: 4
1,7,4	w: 3.22	d: 4
1,7,5	w: 3.12	d: 4
1,7,6	w: 3.17	d: 4
1,7,7	w: 3.49	d: 3
wei:
0	0	0	0	0	0	0	0
0.06	0.28	0.62	1.02	1.34	1.83	2.44	3.12
0	0.06	0.3	0.68	0.93	1.41	1.96	2.61
0	0	0.06	0.32	0.57	1.04	1.48	2.13
0	0	0	0.06	0.24	0.57	1.01	1.55
0	0	0	0	0.05	0.3	0.72	1.2
0	0	0	0	0	0.05	0.32	0.78
0	0	0	0	0	0	0.05	0.34
wei1:
0	0	0	0	0	0	0	0
0.06	0.16	0.28	0.42	0.49	0.64	0.81	1
0	0.06	0.18	0.32	0.39	0.54	0.71	0.9
0	0	0.06	0.2	0.27	0.42	0.59	0.78
0	0	0	0.06	0.13	0.28	0.45	0.64
0	0	0	0	0.05	0.2	0.37	0.56
0	0	0	0	0	0.05	0.22	0.41
0	0	0	0	0	0	0.05	0.24
rid:
0	0	0	0	0	0	0	0
2	1	2	2	2	3	3	5
14	2	2	3	3	3	5	5
14	3	2	3	3	5	5	5
14	3	3	2	4	5	5	6
14	3	3	3	2	5	6	6
14	3	3	3	3	2	6	7
14	3	3	3	3	3	2	7
structure:
5
2	7(R)
1(L)(R)	3(L)	6(L)(R)
4(L)(R)
time cost: 0.32 ms

新方案的时间是 0.255 ms,但深度计算错误

posted @ 2023-08-13 16:36  keep-minding  阅读(180)  评论(0编辑  收藏  举报