Optimal Binary Search Trees

Binary search trees (BSTs) organize a set of search keys for fast access: the tree keeps the keys in order, so comparing the query with the key at any node either yields a match or directs the search to the left or right sub-tree.

For this problem we are given a sorted set of search keys (0, 1, …, n-1) along with the search frequency count (f0, f1, …, fn-1) of each key. A BST built from these keys lets us search them quickly, but each search has a cost. The searching cost of a key/node in the BST is defined as the level of that key/node multiplied by its frequency, where the level of the root node is 1. The total searching cost of the BST is the sum of the searching costs of all its keys/nodes. Given the keys and their frequencies, the problem is to arrange the keys in a BST that minimizes the total searching cost.

For example:
Keys: {0, 1} and Freq: {10, 20}
Two BSTs can be created from this set of keys:

1) Key 0 is the root and key 1 is its right child:
   Total cost of BST = (level of key0 * freq of key0) +
                       (level of key1 * freq of key1)
                     = (1 * 10) + (2 * 20)
                     = 50
2) Key 1 is the root and key 0 is its left child:
   Total cost of BST = (level of key1 * freq of key1) +
                       (level of key0 * freq of key0)
                     = (1 * 20) + (2 * 10)
                     = 40

Hence, the minimum total searching cost for the given set of keys is 40.
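
The cost definition can be checked mechanically. Below is a minimal C++ sketch, assuming the tree is described only by the level of each key. The function name totalSearchCost and the level-array representation are illustrative choices, not part of the original article.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Total searching cost = sum over all keys of (level of key * frequency of key).
// level[k] is the level of key k in some BST, with the root at level 1.
// (Hypothetical helper for illustration; the name is an assumption.)
long long totalSearchCost(const std::vector<int>& level, const std::vector<int>& freq)
{
    // Each key needs exactly one level and one frequency.
    assert(level.size() == freq.size());

    long long cost = 0;
    for (std::size_t k = 0; k < freq.size(); ++k)
        cost += static_cast<long long>(level[k]) * freq[k];
    return cost;
}
```

For the example above, totalSearchCost({1, 2}, {10, 20}) evaluates to 50 (key 0 as root) and totalSearchCost({2, 1}, {10, 20}) evaluates to 40 (key 1 as root).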

Thought Process:

By definition, the searching cost of a key in the BST is (level of key (L) * freq of key (F)). One way to read this: at each level from 1 down to L, the key contributes F to the total cost, which is why its searching cost is (L * F).

To minimize the total search cost, a simple Greedy approach comes to mind: keep the keys with higher frequency near the top of the tree. We choose the key with the highest frequency as the root. Then, from the keys to its left, we choose the key with the highest frequency as the left child of the root; similarly, we choose the right child from the keys to its right, and we continue this way until the BST is built.

But will this approach build a BST that gives the minimum total search cost? To show that a Greedy approach works, we have to give a proof. To show that it fails, we only need one example where it doesn’t work.

Let’s consider this example:

Keys:  0   1   2   3   4   5   6
Freq:  22  18  20  5   25  2   8

Using the Greedy approach discussed above let’s build a BST and calculate its total cost:
Key ‘4’ has highest frequency: 25, so it will be the root. Keys {0,…,3} will be used to build the left sub-tree and keys {5, 6} will be used to build the right sub-tree.
Among keys {0,…,3}, key ‘0’ has the highest frequency, hence it will be the left child of the root.
Among keys {5, 6}, key ‘6’ has the highest frequency, hence it will be the right child of the root.
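We keep doing this for all the remaining keys, and the final BST looks like this:

Greedy BST: key 4 is the root; keys 0 and 6 are its left and right children; key 2 is the right child of key 0; key 5 is the left child of key 6; keys 1 and 3 are the left and right children of key 2.

If the level of key k is Lk and its frequency is Fk, then: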

Total cost of Greedy BST = (L4 * F4)+(L0 * F0)+(L6 * F6)+
                           (L2 * F2)+(L5 * F5)+(L1 * F1)+
                           (L3 * F3)
                         = (1 * 25)+(2 * 22)+(2 * 8) +
                           (3 * 20)+(3 * 2) +(4 * 18)+
                           (4 * 5)
                         = 243
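
The greedy construction can be sketched in a few lines of C++ to confirm this figure. The names greedyCost and greedyTotalCost are hypothetical, and breaking ties in favour of the leftmost key is an assumption; the example has no ties.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Cost of the greedy sub-tree over keys i..j whose root sits at `level`:
// the most frequent key becomes the root, then both sides are built the same way.
long long greedyCost(const std::vector<int>& freq, int i, int j, int level)
{
    assert(level >= 1);  // the root of the whole tree is at level 1
    if (i > j)
        return 0;        // empty range: no keys, no cost

    // Index of the highest frequency in i..j; max_element returns the first
    // maximum, so ties go to the leftmost key (an assumption of this sketch).
    const int r = static_cast<int>(
        std::max_element(freq.begin() + i, freq.begin() + j + 1) - freq.begin());

    return static_cast<long long>(level) * freq[r]
         + greedyCost(freq, i, r - 1, level + 1)
         + greedyCost(freq, r + 1, j, level + 1);
}

// Total searching cost of the greedy BST over all keys.
long long greedyTotalCost(const std::vector<int>& freq)
{
    return greedyCost(freq, 0, static_cast<int>(freq.size()) - 1, 1);
}
```

Calling greedyTotalCost({22, 18, 20, 5, 25, 2, 8}) gives 243, matching the hand calculation above.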

But is there any other possible BST that has lower total searching cost?
Let’s consider this BST:

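Optimal BST: key 2 is the root; keys 0 and 4 are its left and right children; key 1 is the right child of key 0; keys 3 and 6 are the left and right children of key 4; key 5 is the left child of key 6.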
Total cost of this BST = (1 * 20) + (2 * 22) + (2 * 25) + 
                         (3 * 18) + (3 * 5) + (3 * 8) + 
                         (4 * 2)
                       = 215

This BST has a lower total cost (215) than the BST created using the Greedy approach (243). Hence the Greedy approach fails to solve this problem. But then how do we find the optimal BST?

Solution:

Consider a set of keys {Ki, …, Kj}, and let Min_Total_Cost(i, j) be the total searching cost of the optimal BST for this set.
Suppose we have built the optimal BST for this set and ‘Kr’ is its root, where i <= r <= j. The tree would look like this:

                 Kr
                /  \
     Ki,…,Kr-1      Kr+1,…,Kj

The keys to the left of Kr in the given set form the left sub-tree and the keys to its right form the right sub-tree. If Total_Cost(i, j) is the total searching cost of this BST, then it includes:

  1. The searching cost of the root, which is level of root (1) * frequency of root key,
  2. The total cost of the left sub-tree and the total cost of the right sub-tree (the sub-problems), each computed as if its own root were at level 1,
  3. The extra cost of hanging those sub-trees under Kr: as explained earlier, every key contributes its frequency once per level on its path, and becoming a descendant of Kr pushes each key in the sub-trees one level down. So each of those keys adds its frequency once more, i.e. the sum of the frequencies of all keys in both sub-trees.

Total_Cost(i, j) =  (Level of Kr * Freq of Kr)
                   +(Total searching cost of left sub-tree)
                   +(Total searching cost of right sub-tree)
                   +(Sum of frequency of all the keys in the
                     left sub-tree)
                   +(Sum of frequency of all the keys in the
                     right sub-tree)

                 = Total_Cost(i, r-1) +
                   Total_Cost(r+1, j) +
                   (Sum of frequency of all the keys {Ki,…,
                    Kj})

The last step holds because the level of Kr is 1, so the frequency of Kr plus the two sub-tree frequency sums is the sum over the whole set.

Since we do not know which key is Kr, we try each key in the set as the root of the BST and keep track of the minimum total searching cost among the resulting BSTs.
Using the formula above, we can write Min_Total_Cost(i, j) as:

Min_Total_Cost(i, j) = min ( Min_Total_Cost(i, r-1)
                           + Min_Total_Cost(r+1, j) 
                           + Sum of all Fx for x in
                             {i,..,j} )
                      for all r in {i,..,j}

If i > j, the set of keys is empty, so Min_Total_Cost(i, j) = 0.

This also shows that the problem has optimal substructure (i.e. an optimal solution can be constructed from optimal solutions of subproblems).
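
As a quick check of the recurrence, it can be applied by hand to the earlier two-key example (Keys: {0, 1}, Freq: {10, 20}). The derivation below is only a worked illustration of the formula; the under-brace labels mark which root r each term corresponds to.

```latex
\begin{aligned}
\text{Min\_Total\_Cost}(0,0) &= 0 + 0 + 10 = 10 \\
\text{Min\_Total\_Cost}(1,1) &= 0 + 0 + 20 = 20 \\
\text{Min\_Total\_Cost}(0,1) &= \min\bigl(\underbrace{0 + 20}_{r=0},\ \underbrace{10 + 0}_{r=1}\bigr) + (10 + 20) = 10 + 30 = 40
\end{aligned}
```

The minimum is reached at r = 1, i.e. with key 1 as the root, which agrees with the cost of 40 found by enumerating both trees.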

Recursive Approach:

Using this we can write a recursive implementation:

C++:

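#include <climits>   // INT_MAX
#include <iostream>  // cout, endl
#include <numeric>   // accumulate
using namespace std;

// Returns the minimum total searching cost of a BST built from keys i..j.
int Min_Total_Cost(int freq[], int i, int j)
{
    // An empty range of keys costs nothing.
    if (i > j) 
        return 0;
    
    int min_total_cost = INT_MAX;

    // Every key in i..j adds its frequency once at this level,
    // whichever key becomes the root, so compute the sum only once.
    int sum_freq = accumulate(freq+i, freq+j+1, 0);
    
    // Try each key k as the root and keep the cheapest choice.
    for (int k = i; k <= j; ++k)
    {
        int total_cost = ( Min_Total_Cost(freq, i, k-1) 
                         + Min_Total_Cost(freq, k+1, j)
                         + sum_freq );
        
        if (total_cost < min_total_cost)
            min_total_cost = total_cost;
    }
    
    return min_total_cost;
}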

int getTotalCostOfOptimalBST(int keys[], int freq[], int num_keys)
{
    return Min_Total_Cost(freq, 0, num_keys-1);
}

int main() 
{
    int keys[] = {0, 1, 2};
    int freq[] = {34, 8, 50};
    int n = sizeof(keys) / sizeof(keys[0]);
    
    cout<<"Total cost of Optimal BST:"<<getTotalCostOfOptimalBST(keys, freq, n)<<endl;
    
    return 0;
}

Java:

import java.io.*;

class OptimalBST
{
    static int sum(int freq[], int left_idx, int right_idx)
    {
        int sum = 0;
        for (int i=left_idx; i <= right_idx; ++i)
        {
            sum += freq[i];
        }
        return sum;
    }

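    // Returns the minimum total searching cost of a BST built from keys i..j.
    static int Min_Total_Cost(int freq[], int i, int j)
    {
        // An empty range of keys costs nothing.
        if (i > j) 
            return 0;
        
        int min_total_cost = Integer.MAX_VALUE;

        // The frequency sum is the same for every choice of root, so compute it once.
        int sum_freq = sum(freq, i, j);
        
        // Try each key k as the root and keep the cheapest choice.
        for (int k = i; k <= j; ++k)
        {
            int total_cost = ( Min_Total_Cost(freq, i, k-1) 
                             + Min_Total_Cost(freq, k+1, j)
                             + sum_freq );
            
            if (total_cost < min_total_cost)
                min_total_cost = total_cost;
        }
        
        return min_total_cost;
    }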

    static int getTotalCostOfOptimalBST(int keys[], int freq[], int num_keys)
    {
        return Min_Total_Cost(freq, 0, num_keys-1);
    }

    public static void main (String[] args) 
    {
        int keys[] = {0, 1, 2};
        int freq[] = {34, 8, 50};
        int n = keys.length;
    
        System.out.println("Total cost of Optimal BST:" + 
                            getTotalCostOfOptimalBST(keys, freq, n));
    }
}

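But this implementation has exponential time complexity. To see why, let’s look at the recursive function call tree for 3 keys {0, 1, 2}:

Recursion tree (the freq argument and the empty calls with i > j are omitted):

Min_Total_Cost(0, 2)
    k = 0: Min_Total_Cost(1, 2)
               k = 1: Min_Total_Cost(2, 2)
               k = 2: Min_Total_Cost(1, 1)
    k = 1: Min_Total_Cost(0, 0), Min_Total_Cost(2, 2)
    k = 2: Min_Total_Cost(0, 1)
               k = 0: Min_Total_Cost(1, 1)
               k = 1: Min_Total_Cost(0, 0)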

In this example of a set consisting of 3 keys {0, 1, 2}, we can see that subproblems such as Min_Total_Cost(freq, 2, 2) and Min_Total_Cost(freq, 1, 1) are calculated repeatedly.
Our recursive algorithm solves the same subproblems over and over instead of always generating new ones. These are called overlapping subproblems.

Since both properties required for Dynamic Programming hold (‘optimal substructure’ and ‘overlapping subproblems’), we can use DP for this problem.
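
One way to see the effect of overlapping subproblems is to cache each result the first time it is computed (memoization). The sketch below is a top-down variant of the recursive function, not the article's own listing. The names minTotalCostMemo and optimalCostMemo are hypothetical, and the -1 sentinel assumes that frequencies are non-negative.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Top-down Min_Total_Cost with a cache: memo[i][j] holds the answer for keys
// i..j, or -1 if it has not been computed yet.
long long minTotalCostMemo(const std::vector<int>& freq, int i, int j,
                           std::vector<std::vector<long long>>& memo)
{
    if (i > j)
        return 0;                 // empty range of keys
    if (memo[i][j] != -1)
        return memo[i][j];        // already solved: reuse the stored answer

    // Sum of frequencies of keys i..j, accumulated in long long.
    const long long sum_freq =
        std::accumulate(freq.begin() + i, freq.begin() + j + 1, 0LL);

    long long best = -1;
    for (int r = i; r <= j; ++r)  // try every key r as the root
    {
        const long long cost = minTotalCostMemo(freq, i, r - 1, memo)
                             + minTotalCostMemo(freq, r + 1, j, memo)
                             + sum_freq;
        if (best == -1 || cost < best)
            best = cost;
    }
    return memo[i][j] = best;
}

long long optimalCostMemo(const std::vector<int>& freq)
{
    // The -1 sentinel is only safe when no cost can be negative.
    for (int f : freq)
        assert(f >= 0);

    const int n = static_cast<int>(freq.size());
    std::vector<std::vector<long long>> memo(n, std::vector<long long>(n, -1));
    return minTotalCostMemo(freq, 0, n - 1, memo);
}
```

With the cache in place, each pair (i, j) is solved only once. For the seven-key example, optimalCostMemo({22, 18, 20, 5, 25, 2, 8}) returns 215, the cost of the optimal BST shown earlier.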

Dynamic Programming Solution:

In DP we start from the smallest sub-problems and work up towards the final solution.
We first solve the sub-problem {i=0, j=0}; sub-problems with (i > j) are empty and are skipped. Next we solve {i=1, j=1} and reuse both results to solve {i=0, j=1}, and so on.
Finally we solve the sub-problem {i=0, j=(n-1)}, which gives the final answer.

Solutions to all sub-problems are stored in a 2D array (the DP table) so that they can be reused when required.

C++:

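#include <climits>   // LLONG_MAX
#include <iostream>  // cout, endl
#include <numeric>   // accumulate
#include <vector>
using namespace std;

long long int getTotalCostOfOptimalBST(int keys[], int freq[], int num_keys)
{
    // No keys means an empty tree with zero cost.
    if (num_keys <= 0)
        return 0;

    // DP_Table[i][j] = minimum total searching cost for keys i..j.
    // A vector is used because variable-length arrays are not standard C++.
    vector<vector<long long int>> DP_Table(num_keys,
                                           vector<long long int>(num_keys, 0));
    
    for (int j = 0; j < num_keys; ++j)
    {
        for (int i = j; i >= 0; --i)
        {
            // 0LL makes accumulate add up in long long rather than int.
            long long int min_total_cost = LLONG_MAX,
                          sum_freq = accumulate(freq+i, freq+j+1, 0LL);

            // Try each key k in i..j as the root of this sub-tree.
            for (int k = i; k <= j; ++k)
            {
                long long int total_cost = 0,
                              total_cost_left_subtree = 0,
                              total_cost_right_subtree = 0;
                    
                // Keys i..k-1 form the left sub-tree (empty when k == i).
                if (k > i)
                {
                    total_cost_left_subtree = DP_Table[i][k-1];
                }
                
                // Keys k+1..j form the right sub-tree (empty when k == j).
                if (k < j)
                {
                    total_cost_right_subtree = DP_Table[k+1][j];
                }
                
                total_cost = ( total_cost_left_subtree
                             + total_cost_right_subtree
                             + sum_freq );
                
                if (total_cost < min_total_cost)
                    min_total_cost = total_cost;
            }
            
            DP_Table[i][j] = min_total_cost;
        }
    }
    
    return DP_Table[0][num_keys-1];
}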

int main() 
{
    int keys[] = {0, 1, 2, 3, 4, 5, 6};
    int freq[] = {22, 18, 20, 5, 25, 2, 8};
    int num_keys = (sizeof(keys) / sizeof(keys[0]));

    cout<<"Total cost of Optimal BST:"
            <<getTotalCostOfOptimalBST(keys, freq, num_keys)<<endl;
    
    return 0;
}

Java:

import java.io.*;

class OptimalBST 
{
    static int sum(int freq[], int left_idx, int right_idx)
    {
        int sum = 0;
        for (int i=left_idx; i <= right_idx; ++i)
            sum += freq[i];
        return sum;
    }
    
    static int getTotalCostOfOptimalBST(int keys[], int freq[], int num_keys)
    {
        // No keys means an empty tree with zero cost.
        if (num_keys <= 0)
            return 0;

        // DP_Table[i][j] = minimum total searching cost for keys i..j.
        int DP_Table[][] = new int[num_keys][num_keys];
        
        for (int j = 0; j < num_keys; ++j)
        {
            for (int i = j; i >= 0; --i)
            {
                int min_total_cost = Integer.MAX_VALUE,
                    sum_freq = sum(freq, i, j);
    
                // Try each key k in i..j as the root of this sub-tree.
                for (int k = i; k <= j; ++k)
                {
                    int total_cost = 0,
                        total_cost_left_subtree = 0,
                        total_cost_right_subtree = 0;
                        
                    // Keys i..k-1 form the left sub-tree (empty when k == i).
                    if (k > i)
                        total_cost_left_subtree = DP_Table[i][k-1];
                    
                    // Keys k+1..j form the right sub-tree (empty when k == j).
                    if (k < j)
                        total_cost_right_subtree = DP_Table[k+1][j];
                    
                    total_cost = ( total_cost_left_subtree
                                 + total_cost_right_subtree
                                 + sum_freq );
                    
                    if (total_cost < min_total_cost)
                        min_total_cost = total_cost;
                }
                
                DP_Table[i][j] = min_total_cost;
            }
        }
        return DP_Table[0][num_keys-1];
    }

    public static void main (String[] args) 
    {
        int keys[] = {0, 1, 2, 3, 4, 5, 6};
        int freq[] = {22, 18, 20, 5, 25, 2, 8};
        int num_keys = keys.length;
    
        System.out.println("Total cost of Optimal BST is "
                                         + getTotalCostOfOptimalBST(keys, freq, num_keys));
    }
}

Time complexity: O(n^3), since there are O(n^2) sub-problems {i, j} and each one tries up to n roots
Space complexity: O(n^2)
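
The DP above returns only the minimum cost. Since the problem asks for an arrangement of the keys, the same table can also remember which root achieved the minimum for every range, and the tree can be read back from those choices. The following is a sketch under that idea, not part of the original solution. The names optimalLevels and assignLevels are hypothetical, and a running sum replaces the repeated frequency summation.

```cpp
#include <cassert>
#include <vector>

// Walks the recorded root choices and stores the level of every key in i..j.
static void assignLevels(const std::vector<std::vector<int>>& root, int i, int j,
                         int depth, std::vector<int>& level)
{
    if (i > j)
        return;
    const int r = root[i][j];
    assert(r >= i && r <= j);  // a recorded root always lies inside its range
    level[r] = depth;
    assignLevels(root, i, r - 1, depth + 1, level);
    assignLevels(root, r + 1, j, depth + 1, level);
}

// Returns the level of each key in one optimal BST (the root is at level 1).
std::vector<int> optimalLevels(const std::vector<int>& freq)
{
    const int n = static_cast<int>(freq.size());
    std::vector<std::vector<long long>> cost(n, std::vector<long long>(n, 0));
    std::vector<std::vector<int>> root(n, std::vector<int>(n, -1));

    for (int j = 0; j < n; ++j)
    {
        long long sum_freq = 0;
        for (int i = j; i >= 0; --i)
        {
            sum_freq += freq[i];  // running sum of freq[i..j]
            for (int r = i; r <= j; ++r)
            {
                const long long left  = (r > i) ? cost[i][r - 1] : 0;
                const long long right = (r < j) ? cost[r + 1][j] : 0;
                const long long total = left + right + sum_freq;
                // Keep the first root that gives the smallest cost.
                if (root[i][j] == -1 || total < cost[i][j])
                {
                    cost[i][j] = total;
                    root[i][j] = r;
                }
            }
        }
    }

    std::vector<int> level(n, 0);
    assignLevels(root, 0, n - 1, 1, level);
    return level;
}
```

For Freq: {22, 18, 20, 5, 25, 2, 8} this yields the levels {2, 3, 1, 3, 2, 4, 3} for keys 0 to 6, which is the optimal BST with key 2 as root described earlier. The extra root table keeps the space complexity at O(n^2).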

This article is contributed by Bhushan Agrawal.