[Coding Made Simple] Optimal Binary Search Tree
Given keys and the frequency at which each key is searched, how would you build a binary search tree from these keys
such that the total cost of searching is minimized?
The cost of searching is defined as the sum, over all nodes, of the node's search frequency * its depth; the root node's depth is 1.
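To make the cost definition concrete, here is a minimal sketch that evaluates it for one hand-built tree. The names SearchCostDemo, Node and searchCost are hypothetical, and the keys {10, 12, 20} with frequencies {34, 8, 50} are made-up values chosen only for illustration.

```java
class SearchCostDemo {
    // Hypothetical node type for this illustration only.
    static class Node {
        final int key;
        final int freq;
        Node left;
        Node right;

        Node(int key, int freq) {
            this.key = key;
            this.freq = freq;
        }
    }

    // Sum of freq * depth over the subtree rooted at `node`,
    // where `depth` is the depth of `node` itself (the root has depth 1).
    static int searchCost(Node node, int depth) {
        if (node == null) {
            return 0;
        }
        return node.freq * depth
                + searchCost(node.left, depth + 1)
                + searchCost(node.right, depth + 1);
    }

    public static void main(String[] args) {
        // Tree: 20 at the root, 10 as its left child, 12 as the right child of 10.
        Node root = new Node(20, 50);
        root.left = new Node(10, 34);
        root.left.right = new Node(12, 8);
        // 50 * 1 + 34 * 2 + 8 * 3 = 142
        System.out.println(searchCost(root, 1));
    }
}
```

Placing the most frequently searched key at the root keeps its contribution small, which is the intuition the solutions below exploit.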
Solution 1. Recursion
Algorithm:
1. First sort the input keys in ascending order;
2. From left to right over the sorted keys, pick keys[i] as the root node, then recursively compute the min search cost of
the left subtree and the right subtree;
3. Repeat step 2 for every key and record the depth of each node for the choice that gives the min search cost;
4. Reconstruct the optimal BST from minCostDepths[], which contains the depth of each node; this BST gives the min search cost.
This recursive solution is inefficient because it has overlapping subproblems, as the following example shows.
Given keys {10, 11, 12, 13, 14}:
if 11 is used as root, then the subproblems of {10} and {12, 13, 14} are solved recursively;
if 12 is used as root, then the subproblems of {10, 11} and {13, 14} are solved recursively;
When solving the subproblem {12, 13, 14} in the case of 11 as root, we need to solve the subproblem {13, 14} after selecting 12 as the root of {12, 13, 14}.
The subproblem {13, 14} is solved again later in the case of 12 as root. This is redundant work!
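The redundancy can be measured with a small counting sketch. It mirrors only the branching structure of the recursion (every key in a range is tried as root) and ignores costs and depths; OverlapDemo and countSolves are hypothetical names introduced for this illustration.

```java
class OverlapDemo {
    // Counts how many times a recursion that tries every root in [start, end]
    // ends up solving the exact sub-range [targetStart, targetEnd].
    static int countSolves(int start, int end, int targetStart, int targetEnd) {
        if (start > end) {
            return 0;
        }
        int count = (start == targetStart && end == targetEnd) ? 1 : 0;
        for (int root = start; root <= end; root++) {
            count += countSolves(start, root - 1, targetStart, targetEnd);
            count += countSolves(root + 1, end, targetStart, targetEnd);
        }
        return count;
    }

    public static void main(String[] args) {
        // Keys {10, 11, 12, 13, 14} sit at indices 0..4, so {13, 14} is the range [3, 4].
        System.out.println(countSolves(0, 4, 3, 4)); // prints 4
    }
}
```

With only five keys the range {13, 14} is already solved four times, and the repetition grows quickly as more keys are added.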
import java.util.Arrays;
import java.util.Comparator;

class TreeNode {
    int key;
    int freq;
    TreeNode left;
    TreeNode right;
    TreeNode(int key, int freq) {
        this.key = key;
        this.freq = freq;
        this.left = null;
        this.right = null;
    }
}

public class OptimalBST {
    public TreeNode getOptimalBST(TreeNode[] nodes) {
        if (nodes == null || nodes.length == 0) {
            return null;
        }
        // Sort by key; comparingInt avoids the overflow risk of subtracting keys.
        Arrays.sort(nodes, Comparator.comparingInt((TreeNode node) -> node.key));
        int[] minCostDepths = new int[nodes.length];
        getOptimalBSTRecursive(nodes, minCostDepths, 0, nodes.length - 1, 1);
        return reconstructOptimalBST(nodes, minCostDepths, 0, nodes.length - 1);
    }

    // Returns the min cost of a BST over nodes[start..end] whose root sits at `depth`.
    // On return, minCostDepths[start..end] holds the node depths of that optimal BST.
    private int getOptimalBSTRecursive(TreeNode[] nodes, int[] minCostDepths, int start, int end, int depth) {
        if (start > end) {
            return 0;
        }
        int minCost = Integer.MAX_VALUE;
        // Later trials overwrite minCostDepths for this range, so the best
        // layout found so far is kept in a separate snapshot.
        int[] bestDepths = new int[end - start + 1];
        for (int i = start; i <= end; i++) {
            // The cost must start from scratch for every candidate root.
            int cost = nodes[i].freq * depth;
            cost += getOptimalBSTRecursive(nodes, minCostDepths, start, i - 1, depth + 1);
            cost += getOptimalBSTRecursive(nodes, minCostDepths, i + 1, end, depth + 1);
            if (cost < minCost) {
                minCost = cost;
                minCostDepths[i] = depth;
                System.arraycopy(minCostDepths, start, bestDepths, 0, bestDepths.length);
            }
        }
        // Restore the best layout so the caller sees the optimal depths.
        System.arraycopy(bestDepths, 0, minCostDepths, start, bestDepths.length);
        return minCost;
    }

    // The node with the smallest depth in a range is the root of that range.
    private TreeNode reconstructOptimalBST(TreeNode[] nodes, int[] minCostDepths, int start, int end) {
        if (start > end) {
            return null;
        }
        int minDepth = Integer.MAX_VALUE;
        int minDepthIdx = 0;
        for (int i = start; i <= end; i++) {
            if (minCostDepths[i] < minDepth) {
                minDepth = minCostDepths[i];
                minDepthIdx = i;
            }
        }
        nodes[minDepthIdx].left = reconstructOptimalBST(nodes, minCostDepths, start, minDepthIdx - 1);
        nodes[minDepthIdx].right = reconstructOptimalBST(nodes, minCostDepths, minDepthIdx + 1, end);
        return nodes[minDepthIdx];
    }
}
Solution 2. Dynamic Programming
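Given a list of n nodes sorted in ascending order by key, nodes[0 ... n - 1]:
State: T[i][j] contains the min cost of the BST built from nodes[i....j] and the index of the root of this optimal BST.
Function: For a given range of nodes, the min cost is the sum of all frequencies in the range plus the smallest combined cost of the left and right subtrees over every choice of root k. The frequency sum appears because hanging the two subtrees under the root pushes each of their nodes one level deeper, and the root itself contributes its frequency once.
T[i][j].cost = sum of nodes' freq in nodes[i....j] + min of {T[i][k - 1].cost + T[k + 1][j].cost, for k from i to j}, where an empty range (i > j) costs 0.
Initialization: T[i][i].cost = nodes[i].freq; T[i][i].idx = i;
Answer: T[0][n - 1] contains the min cost of the optimal BST and its root idx.
The whole optimal BST can be constructed recursively based on the root indices stored in the 2D DP array.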
The DP solution avoids the duplicated work in solution 1 at the cost of O(n^2) extra memory.
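As a worked instance of the recurrence, the block below writes it in LaTeX and evaluates it by hand for three sorted keys. The frequencies 34, 8, 50 are made-up values for illustration, C(i, j) is shorthand for T[i][j].cost, and f_m stands for nodes[m].freq.

```latex
\begin{align*}
C(i,j) &= \sum_{m=i}^{j} f_m + \min_{i \le k \le j} \bigl( C(i,k-1) + C(k+1,j) \bigr),
          \qquad C(i,j) = 0 \text{ for } i > j \\
C(0,0) &= 34, \quad C(1,1) = 8, \quad C(2,2) = 50 \\
C(0,1) &= (34 + 8) + \min(0 + 8,\; 34 + 0) = 42 + 8 = 50 \quad (\text{root } k = 0) \\
C(1,2) &= (8 + 50) + \min(0 + 50,\; 8 + 0) = 58 + 8 = 66 \quad (\text{root } k = 2) \\
C(0,2) &= (34 + 8 + 50) + \min(0 + 66,\; 34 + 50,\; 50 + 0) = 92 + 50 = 142 \quad (\text{root } k = 2)
\end{align*}
```

Each range reuses the already computed costs of shorter ranges, which is why the table is filled in order of increasing range length.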
import java.util.Arrays;
import java.util.Comparator;

class BstNode {
    int key;
    int freq;
    BstNode left;
    BstNode right;
    BstNode(int key, int freq) {
        this.key = key;
        this.freq = freq;
        this.left = null;
        this.right = null;
    }
}

public class OptimalBST {
    // One DP cell: min cost of the optimal BST over a range and the index of its root.
    private static class OptimalBstEntry {
        int cost;
        int idx;
        OptimalBstEntry(int c, int i) {
            this.cost = c;
            this.idx = i;
        }
    }

    public BstNode getOptimalBST(BstNode[] nodes) {
        if (nodes == null || nodes.length == 0) {
            return null;
        }
        // Sort by key; comparingInt avoids the overflow risk of subtracting keys.
        Arrays.sort(nodes, Comparator.comparingInt((BstNode node) -> node.key));
        OptimalBstEntry[][] T = new OptimalBstEntry[nodes.length][nodes.length];
        for (int i = 0; i < nodes.length; i++) {
            for (int j = i; j < nodes.length; j++) {
                T[i][j] = new OptimalBstEntry(Integer.MAX_VALUE, -1);
            }
        }
        // Base case: a single node is its own root at depth 1.
        for (int i = 0; i < nodes.length; i++) {
            T[i][i].cost = nodes[i].freq;
            T[i][i].idx = i;
        }
        // Fill the table in order of increasing range length.
        for (int len = 2; len <= nodes.length; len++) {
            for (int startIdx = 0; startIdx <= nodes.length - len; startIdx++) {
                int endIdx = startIdx + len - 1;
                int sum = 0;
                int min = Integer.MAX_VALUE;
                for (int i = startIdx; i <= endIdx; i++) {
                    sum += nodes[i].freq;
                }
                // Try every node in the range as the root.
                for (int rootIdx = startIdx; rootIdx <= endIdx; rootIdx++) {
                    int leftCost = 0, rightCost = 0;
                    if (rootIdx - 1 >= startIdx) {
                        leftCost = T[startIdx][rootIdx - 1].cost;
                    }
                    if (rootIdx + 1 <= endIdx) {
                        rightCost = T[rootIdx + 1][endIdx].cost;
                    }
                    if (leftCost + rightCost < min) {
                        T[startIdx][endIdx].cost = sum + leftCost + rightCost;
                        T[startIdx][endIdx].idx = rootIdx;
                        min = leftCost + rightCost;
                    }
                }
            }
        }
        // Construct the optimal BST from the stored root indices.
        return constructOptimalBst(nodes, T, 0, nodes.length - 1);
    }

    private BstNode constructOptimalBst(BstNode[] nodes, OptimalBstEntry[][] T, int start, int end) {
        if (start > end) {
            return null;
        }
        int rootIdx = T[start][end].idx;
        BstNode root = nodes[rootIdx];
        root.left = constructOptimalBst(nodes, T, start, rootIdx - 1);
        root.right = constructOptimalBst(nodes, T, rootIdx + 1, end);
        return root;
    }
}
Further optimization
When calculating the sum of frequencies in a given range, the above solution always starts from scratch.
For example, when calculating the frequency sum of nodes 0, 1, 2, 3, it adds up all 4 nodes' frequencies. Since we've already
calculated the frequency sum of the first 3 nodes, we can use a prefix sum to remember the previous results. Doing this
takes the sum operation from O(n) to O(1), since only one add operation is needed for the newly included node.
However, this optimization does not change the asymptotic complexity; it only improves the constant factor.
The runtime is still O(n^3), because trying every root for each of the O(n^2) ranges already takes O(n) time per range.
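A minimal sketch of the prefix-sum idea follows. It computes only the min cost, not the tree, and assumes the frequencies are passed in the order of the already sorted keys. PrefixSumCostDemo and minCost are hypothetical names, and the frequencies {34, 8, 50} are made-up values.

```java
class PrefixSumCostDemo {
    // Min search cost for keys already sorted in ascending order, where freq[i]
    // is the search frequency of the i-th key and the root has depth 1.
    static int minCost(int[] freq) {
        int n = freq.length;
        if (n == 0) {
            return 0;
        }
        // prefix[i] = freq[0] + ... + freq[i - 1], so the frequency sum over
        // [i, j] is prefix[j + 1] - prefix[i], computed in O(1).
        int[] prefix = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] + freq[i];
        }

        int[][] cost = new int[n][n];
        for (int i = 0; i < n; i++) {
            cost[i][i] = freq[i];
        }
        for (int len = 2; len <= n; len++) {
            for (int i = 0; i + len - 1 < n; i++) {
                int j = i + len - 1;
                int best = Integer.MAX_VALUE;
                // The root loop is still O(n), so the total stays O(n^3).
                for (int k = i; k <= j; k++) {
                    int left = k > i ? cost[i][k - 1] : 0;
                    int right = k < j ? cost[k + 1][j] : 0;
                    best = Math.min(best, left + right);
                }
                cost[i][j] = best + prefix[j + 1] - prefix[i];
            }
        }
        return cost[0][n - 1];
    }

    public static void main(String[] args) {
        System.out.println(minCost(new int[] {34, 8, 50})); // prints 142
    }
}
```

The only change from the table-filling loop in Solution 2 is that the inner summation loop is replaced by a single subtraction on the prefix array.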
Related Topics
Prefix Sum