Matrix chain multiplication-dynamic programming


What is matrix chain multiplication in general? To read on that please refer to Wiki. However, today’s problem is not about actually multiplying a chain of matrices, but about finding the optimal order in which to multiply them so as to minimize the number of scalar multiplications. Every matrix has dimensions: rows and columns. To be able to multiply two matrices, the number of columns in the first matrix must be equal to the number of rows of the second matrix. If we multiply a matrix of dimension M x N by another matrix of dimension N x P, we get a matrix of dimension M x P.


How many scalar multiplications need to be done to multiply an M x N matrix by an N x P matrix? It’s M x N x P, because each of the M x P entries of the result needs N multiplications.
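The M x N x P count can be checked by instrumenting a plain triple-loop multiplication. The following is a minimal sketch; the class and method names are hypothetical and only serve this illustration.

```java
// Illustrative sketch; ScalarMultiplicationCount is a hypothetical name, not from the article.
class ScalarMultiplicationCount {

    /**
     * Multiplies a (m x n) by b (n x p), writes the product into result (m x p)
     * and returns how many scalar multiplications were performed.
     */
    static long multiplyAndCount(int[][] a, int[][] b, int[][] result) {
        int m = a.length;
        int n = b.length;
        int p = b[0].length;
        long count = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j]; // one scalar multiplication
                    count++;
                }
                result[i][j] = sum;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[][] a = new int[2][3];
        int[][] b = new int[3][4];
        int[][] c = new int[2][4];
        // A 2 x 3 matrix times a 3 x 4 matrix: 2 * 3 * 4 multiplications
        System.out.println(multiplyAndCount(a, b, c));
    }
}
```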

Given N matrices with their dimensions, find the optimal way to multiply these matrices so that the total number of scalar multiplications is minimized.

Matrix chain multiplication : line of thought

Before going further, let’s understand some basics of matrix multiplication

  1. Matrix multiplication is associative, i.e. A * (B * C) = (A * B) * C
  2. It is not commutative, i.e. A * (B * C) is in general not equal to A * (C * B)
  3. To multiply two matrices, they should be compatible, i.e. the number of columns in the first matrix should be equal to the number of rows of the second matrix.

Since matrix multiplication is associative, this problem actually reduces to figuring out how to put parentheses around the matrices so that the total number of scalar multiplications is the least.
Let’s take an example and understand. We have matrices A, B, C and D with dimensions 10 x 20, 20 x 30, 30 x 40 and 40 x 50 respectively.
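To see how much the placement of parentheses matters for this example, the sketch below adds up the M x N x P cost for three of the possible parenthesizations of A, B, C and D. The class and method names are hypothetical and exist only for this illustration.

```java
// Illustrative sketch; names are hypothetical, dimensions are the ones from the example above.
class ParenthesizationCost {

    /** Scalar multiplications needed to multiply an (m x n) matrix by an (n x p) matrix. */
    static long cost(long m, long n, long p) {
        return m * n * p;
    }

    /** ((A x B) x C) x D */
    static long leftToRight() {
        return cost(10, 20, 30)    // A x B     gives a 10 x 30 matrix
             + cost(10, 30, 40)    // (AB) x C  gives a 10 x 40 matrix
             + cost(10, 40, 50);   // (ABC) x D gives a 10 x 50 matrix
    }

    /** A x (B x (C x D)) */
    static long rightToLeft() {
        return cost(30, 40, 50)    // C x D     gives a 30 x 50 matrix
             + cost(20, 30, 50)    // B x (CD)  gives a 20 x 50 matrix
             + cost(10, 20, 50);   // A x (BCD) gives a 10 x 50 matrix
    }

    /** (A x B) x (C x D) */
    static long balanced() {
        return cost(10, 20, 30)    // A x B
             + cost(30, 40, 50)    // C x D
             + cost(10, 30, 50);   // (AB) x (CD)
    }

    public static void main(String[] args) {
        System.out.println("((AB)C)D  : " + leftToRight());  // 38000
        System.out.println("A(B(CD))  : " + rightToLeft());  // 100000
        System.out.println("(AB)(CD)  : " + balanced());     // 81000
    }
}
```

All three orders produce the same 10 x 50 result, but the number of scalar multiplications differs by more than a factor of two.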

How can we solve this problem manually?  Given chain of matrices is as ABCD. There are three ways to split the chain into two parts :  as (A) x (BCD) or as (AB) x (CD) or as (ABC) x (D).

Either way, we have smaller problems to solve now. If we take the first split, the cost of multiplying ABCD is the cost of A (which is 0 for a single matrix) + the cost of (BCD) + the cost of multiplying A x (BCD).

Similarly for rest two splits. Answer will be minimum of all three possible splits.

To get the cost of ABCD, we first need the cost of BCD. (BCD) can be split in two ways: (B) x (CD) or (BC) x (D).

We will continue with (B) x (CD). (CD) splits in only one way, (C) x (D), the cost of which is nothing but M x N x P, where C is a matrix of dimension M x N and D is a matrix of dimension N x P. Here that is 30 x 40 x 50 = 60000.

Similarly, the cost of (BC) is 20 x 30 x 40 = 24000. The figure below shows the detailed calculation for each possible split and finally gives the answer as the minimum over all possible splits.

[Figure: matrix chain multiplication brute force]
Notice that for N matrices, there are N-1 ways to split the chain at the top level, and each part is then split recursively. This manual calculation can easily be implemented as a recursive solution.

Recursive implementation of matrix chain multiplication

package com.company;

/**
 * Created by sangar on 31.12.17.
 */
public class MCM {

    public static int matrixChainMultiplication(int[] P, int i, int j){
        // Trace every call; repeated (i,j) pairs show the overlapping sub-problems
        System.out.println("("+ i + "," + j + ")");

        // A chain of a single matrix needs no multiplication
        if(i==j) return 0;

        int min = Integer.MAX_VALUE;

        // Try every split point k: (Ai..Ak) x (Ak+1..Aj)
        for(int k=i; k<j; k++){
            int count = matrixChainMultiplication(P, i, k)
                    + matrixChainMultiplication(P, k+1, j)
                    + P[i-1]*P[k]*P[j];

            min = Integer.min(count, min);
        }

        return min;
    }

    public static void main(String[] args) {
        int arr[] = new int[] {1, 2, 3, 4, 3};
        int n = arr.length;

        // Matrices are numbered 1 to n-1, so solve the chain from 1 to n-1
        System.out.println("Minimum number of multiplications is "+
                matrixChainMultiplication(arr, 1, n-1));
    }
}

In the implementation, the matrices are given as an array P, where matrix i has dimension P[i-1] x P[i]. A chain of N matrices is therefore described by an array of N+1 numbers.

Matrix chain multiplication : Dynamic programming approach

However, the complexity of the recursive solution is exponential in time, and hence it is of no use for very large inputs. Also, if you look at the calculation tree, there are many sub-problems which are solved again and again. This is called overlapping sub-problems. What can be done to avoid calculating sub-problems again and again? Save the results somewhere.

[Figure: matrix chain multiplication, overlapping subproblems]

If we save the cost of multiplying matrices i to j, we can refer back to it when needed. This technique is called memoization in dynamic programming.
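A minimal sketch of this idea, assuming the same dimension array P as in the recursive solution: the recursion stays the same, but every (i, j) result is stored in a table and reused. The class name McmMemo and the helper solve are hypothetical names introduced for this illustration.

```java
import java.util.Arrays;

// Illustrative sketch; McmMemo, minCost and solve are hypothetical names, not from the article.
class McmMemo {

    /** Returns the minimum number of scalar multiplications for the chain described by p. */
    static int minCost(int[] p) {
        int n = p.length;
        if (n < 3) {
            return 0; // zero or one matrix: nothing to multiply
        }
        int[][] memo = new int[n][n];
        for (int[] row : memo) {
            Arrays.fill(row, -1); // -1 marks "not computed yet"
        }
        return solve(p, 1, n - 1, memo);
    }

    private static int solve(int[] p, int i, int j, int[][] memo) {
        if (i == j) {
            return 0; // single matrix
        }
        if (memo[i][j] != -1) {
            return memo[i][j]; // reuse the saved result
        }
        int best = Integer.MAX_VALUE;
        for (int k = i; k < j; k++) {
            int cost = solve(p, i, k, memo)
                     + solve(p, k + 1, j, memo)
                     + p[i - 1] * p[k] * p[j];
            best = Math.min(best, cost);
        }
        memo[i][j] = best;
        return best;
    }

    public static void main(String[] args) {
        System.out.println(minCost(new int[] {10, 20, 30, 40, 30})); // 30000
    }
}
```

Each (i, j) pair is now solved only once; every later request for it is a table lookup.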

The cost of multiplying matrices Ai to Aj, when the chain is split after Ak, is

Cost(Ai, Aj) = Cost(Ai, Ak) + Cost(Ak+1, Aj) + (P[i-1] * P[k] * P[j])

The idea is to find the k for which Cost(Ai, Aj) becomes minimum. If M[i,j] represents the minimum cost to multiply matrix i through matrix j, then for the best k,

M[i,j] = M[i,k] + M[k+1,j] + (P[i-1] * P[k] * P[j])

When calculating M[i,j], both M[i,k] and M[k+1,j] should already be available; this is called bottom-up filling of the matrix. Also, M[i,i] = 0, as the cost of multiplying a single matrix is 0.

Since we are taking the bottom-up approach to fill our solution matrix, we start by grouping two matrices at a time, then 3, then 4, until we reach the full chain of N matrices. We start with length L = 2 and go on to solve the table entries up to length N. M[1, N] will give us the final cost.

To find minimum cost(i,j), we need to find a K such that expression

Cost (Ai, Aj) = Cost(Ai,Ak) + Cost(Ak+1,Aj )+(P[i-1] * P[k] * P[j])

becomes minimum, hence

M[i,j] = min(M[i,j], M[i,k] + M[k+1,j] + P[i-1] * P[k] * P[j])
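The recurrence gives only the minimum cost. Since the problem asks for the optimal way to multiply, one possible extension (an assumption of this sketch, not part of the implementations in this article) is to remember the best k for every (i, j) in a second table and rebuild the parenthesization from it. The names McmWithSplits, optimalOrder and build are hypothetical.

```java
// Illustrative sketch; the split table and all names are assumptions of this example.
class McmWithSplits {

    /** Returns the optimal parenthesization for the chain described by p, e.g. "((A1 x A2) x A3)". */
    static String optimalOrder(int[] p) {
        int n = p.length;
        if (n < 2) {
            throw new IllegalArgumentException("need at least one matrix");
        }
        int[][] m = new int[n][n];      // m[i][j]: minimum cost for matrices i..j
        int[][] split = new int[n][n];  // split[i][j]: the k that achieved m[i][j]

        for (int len = 2; len < n; len++) {
            for (int i = 1; i < n - len + 1; i++) {
                int j = i + len - 1;
                m[i][j] = Integer.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    int cost = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
                    if (cost < m[i][j]) {
                        m[i][j] = cost;
                        split[i][j] = k; // remember where the best split was
                    }
                }
            }
        }
        return build(split, 1, n - 1);
    }

    /** Recursively rebuilds the parenthesization from the recorded split points. */
    static String build(int[][] split, int i, int j) {
        if (i == j) {
            return "A" + i;
        }
        int k = split[i][j];
        return "(" + build(split, i, k) + " x " + build(split, k + 1, j) + ")";
    }

    public static void main(String[] args) {
        System.out.println(optimalOrder(new int[] {10, 20, 30, 40, 30}));
    }
}
```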

Implementation of matrix chain multiplication using dynamic programming

package com.company;

/**
 * Created by sangar on 31.12.17.
 */
public class MCM {

    public  static  int matriChainMultiplicationDP(int[] P){
        int n = P.length;

        int[][] M = new int[n][n];

        for(int i=0; i<n; i++){
            for(int j=0; j<n; j++){
                M[i][j] = 0;
            }
        }

        for(int L=2; L<n; L++){
            /* For every position i, we check every chain of len L */
            for(int i=1; i<n-L+1; i++){
                int j = i+L-1;
                M[i][j] = Integer.MAX_VALUE;

                /* For matrix i to j, check every split K */
                for(int k=i; k<j; k++){
                    int temp = M[i][k] + M[k+1][j] + P[i-1] * P[k] * P[j];
                    /* Check if the current count is less than minimum */
                    M[i][j] = Integer.min(temp, M[i][j]);
                }
            }
        }

        return M[1][n-1];
    }
    public static void main(String[] args) {
        int arr[] = new int[] {1, 2, 3, 4, 3};
        int n = arr.length;

        System.out.println("Minimum number of multiplications is "+
                matriChainMultiplicationDP(arr));
    }
}

C implementation of matrix chain multiplication

#include<stdlib.h>
#include<stdio.h>

#define MAX_INT 10000000

int matrixMultiplication(int p[], int N){

  int L,i, j, temp;

  int M[N][N];

  for(i=0; i<N; i++){
      for(j=0; j<N; j++){
           M[i][j] = 0;
      }
  }       

  for(L=2; L<N; L++){
    /* For every position i, we check every chain of len L */
    for(i=1; i<N-L+1; i++){
      j = i+L-1;
      M[i][j] = MAX_INT;
      /* For matrix i to j, check every split K */
      for(int k=i; k<j; k++){
        temp = M[i][k] + M[k+1][j] + p[i-1] * p[k] * p[j];
        /* Check if the current count is less than minimum */
        if(temp < M[i][j]){
          M[i][j] = temp;
        }
      }
    }
  }
  return M[1][N-1];
}

/* Driver program to run above code */
int main(){

    int p [] ={10, 20, 30, 40, 30};
    int n = sizeof(p)/sizeof(p[0]);
    
    printf("%d\n", matrixMultiplication(p,n));
    
    return 0;
}

Let’s run through an example to understand how this code works.

P = {10, 20,30,40,30}, 
dimensions of matrix [1] = 10X20, 
dimensions of matrix [2] = 20X30, 
dimensions of matrix [3] = 30X40,
dimensions of matrix [4] = 40X30

The function is called with array P and its size, which is 5.
A table M of size N x N, i.e. M[5][5], is declared, because the final answer is stored in M[1][4]. Row 0 and column 0 are not used.

The diagonal of the table is set to zero, as the cost of multiplying a single matrix is 0.

M[0][0] = M[1][1] = M[2][2] = M[3][3] = M[4][4] = 0

Start the outer for loop with L = 2. The inner for loop runs from i = 1 while i < N-L+1 = 5-2+1 = 4, so i takes the values 1, 2 and 3. j is i+L-1, which is 2 for i = 1.
The innermost loop runs from k = 1 (i) while k < 2 (j), so only k = 1 is tried. M[1,2] is initialised to MAX_INT.

temp = M[1,1] + M[2,2] + P[0] * P[1] * P[2] = 6000.

Since it is less than MAX_INT, M[1,2] = 6000.
Similarly, for i = 2, j = 2+2-1 = 3. The same process is followed and M[2,3] = 24000, and so on.

M[3,4] = P[2] * P[3] * P[4] = 36000.

Coming to L = 3. With i = 1, j = 1+3-1 = 3, and k runs from 1 while k < 3.
For k = 1,

temp = M[1,1] + M[2,3] + P[0] * P[1] * P[3] = 24000 + 8000 = 32000

For k = 2,

temp = M[1,2] + M[3,3] + P[0] * P[2] * P[3] = 6000 + 12000 = 18000

Hence M[1,3] = min(32000, 18000) = 18000.

With i = 2, j = 2+3-1 = 4, and k runs from 2 while k < 4.
For k = 2,

temp = M[2,2] + M[3,4] + P[1] * P[2] * P[4] = 36000 + 18000 = 54000.

For k = 3,

temp = M[2,3] + M[4,4] + P[1] * P[3] * P[4] = 24000 + 24000 = 48000.

Hence M[2,4] = min(54000, 48000) = 48000.
The same process is followed for L = 4 and the remaining entries are filled. Finally, rows 1 to 4 and columns 1 to 4 of the matrix look like this:

0 6000 18000 30000 
0 0 24000 48000 
0 0 0 36000 
0 0 0 0

M[1,N-1] will be the solution to our problem; for this example, M[1,4] = 30000.

Time complexity of matrix chain multiplication using dynamic programming is O(n^3), because there are O(n^2) table entries and each entry tries up to n split points. Space complexity is O(n^2).
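One way to sanity-check how the running time grows is to count how often the innermost loop body of the bottom-up solution runs. The sketch below mirrors the three loops without doing any table work; the class and method names are hypothetical.

```java
// Illustrative sketch; McmIterationCount is a hypothetical name, not from the article.
class McmIterationCount {

    /** Counts the (i, k, j) combinations visited for a chain of the given number of matrices. */
    static long innerIterations(int matrices) {
        int n = matrices + 1; // length of the dimension array P
        long count = 0;
        for (int len = 2; len < n; len++) {
            for (int i = 1; i < n - len + 1; i++) {
                int j = i + len - 1;
                for (int k = i; k < j; k++) {
                    count++; // one evaluation of the recurrence
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // Making the chain 10 times longer makes the count grow roughly 1000-fold
        for (int matrices : new int[] {4, 10, 100}) {
            System.out.println(matrices + " matrices -> " + innerIterations(matrices));
        }
    }
}
```

For 4, 10 and 100 matrices the counts are 10, 165 and 166650, which is cubic growth in the number of matrices.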
Reference
http://www.personal.kent.edu/~rmuhamma/Algorithms/MyAlgorithms/Dynamic/chainMatrixMult.htm


What is dynamic programming?


Dynamic programming (DP) is an approach to solving a larger problem with the help of the results of smaller subproblems. It is a technique used to avoid computing the same subproblem multiple times in a recursive algorithm. A lot of students ask me: how do I know that a problem is a dynamic programming problem? Is there a definite way to arrive at the conclusion that a problem is a dynamic programming problem or not?

The first thing I would recommend you read before going further is this beautiful explanation of dynamic programming to a 4-year-old.

The first thing you will notice about dynamic programming problems (not all of them) is that they are optimization problems: they ask for the minimum or maximum of some entity. For example, find the minimum edit distance between two strings, or find the longest common subsequence. However, problems like the Fibonacci series are not exactly optimization problems; they are more like combinatorial problems. Still, this can be a good hint that a problem can be a DP problem.

Second, you will notice that the problem can be expressed as a recurrence like fn(n) = C + fn(n-k), where k can be anything between 1 and n.
This property is called optimal substructure, where an optimal solution to the subproblems leads to the optimal solution to the larger problem.
Once you get the equation, it is very easy to come up with a recursive solution to the problem. I would advise you to write the recursive solution and try to calculate its complexity. For many such problems it will be exponential in big-O notation.

Then why did recursion work so well with a divide and conquer approach? The key point is that in divide and conquer, a problem is expressed in terms of subproblems that are substantially smaller, say half the size. For instance, mergesort sorts an array of size n by recursively sorting two subarrays of size n/2. Because of this sharp drop in problem size, the full recursion tree has only logarithmic depth and a polynomial number of nodes. In contrast, in a typical dynamic programming formulation, a problem is reduced to subproblems that are only slightly smaller than the original. For instance, fn(j) relies on fn(j − 1). Thus the full recursion tree generally has polynomial depth and an exponential number of nodes.
However, it turns out that most of these nodes are repeats, that there are not too many distinct subproblems among them. Efficiency is therefore obtained by explicitly enumerating the distinct subproblems and solving them in the right order.
Reference

This leads us to the third property, which is overlapping subproblems. Once you draw the execution tree of the recursive solution, you will see that a lot of subproblems are being solved again and again at different levels of recursion.

The intuition behind dynamic programming is that we trade space for time, i.e. to say that instead of calculating all the subproblems taking a lot of time but no space, we take up space to store the results of all the subproblems to save time later. The typical characteristics of a dynamic programming problem are optimization problems, optimal substructure property, overlapping subproblems, trade space for time, implementation via bottom-up/memoization.

Dynamic programming in action

Enough of theory, let’s take an example and see how dynamic programming works on real problems. I will take a very commonly used but effective problem to explain DP in action: the Fibonacci series. The Fibonacci series is a series of integers where each integer is the sum of the previous two integers. For example, 1,1,2,3,5,8,13,21 is a Fibonacci series of eight integers. Now, the question is: given a number n, output the integer at the nth position in the Fibonacci series. For example, for n = 4 the output should be 3, and for n = 6 it should be 8.

First hint: It is a combinatorial problem, so maybe a DP problem. Second, it is already given in the problem that current integer depends on the sum of previous two integers, that means f(n) = f(n-1) + f(n-2). This implies that the solution to subproblems will lead to a solution to the bigger problem which is optimal substructure property.

Next step is to implement the recursive function.

 public int fibonacci (int n) {
    if (n < 2) // base cases: fibonacci(0) = 0, fibonacci(1) = 1
        return n;

    return fibonacci(n-1) + fibonacci(n-2);
 }

Great, next step is to draw the execution tree of this function. It looks like below for n = 6. It is apparent how many times the same problem is solved at different levels.

[Figure: Recursive tree of Fibonacci series function]
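The repetition in the recursion tree can also be measured. The following sketch counts how many times the plain recursive function is invoked for each argument; the class FibCallCounter and its counter array are hypothetical additions for this illustration, and the base cases fibonacci(0) = 0 and fibonacci(1) = 1 are assumed.

```java
// Illustrative sketch; FibCallCounter and its fields are hypothetical, not from the article.
class FibCallCounter {

    static int[] calls; // calls[k] = number of times fib(k) was invoked

    static int fib(int n) {
        calls[n]++;
        if (n < 2) {
            return n; // assumed base cases: fib(0) = 0, fib(1) = 1
        }
        return fib(n - 1) + fib(n - 2);
    }

    /** Runs the plain recursion for n and returns the per-argument call counts. */
    static int[] countCalls(int n) {
        calls = new int[n + 1];
        fib(n);
        return calls;
    }

    public static void main(String[] args) {
        int[] result = countCalls(6);
        for (int k = 6; k >= 0; k--) {
            System.out.println("fib(" + k + ") called " + result[k] + " time(s)");
        }
    }
}
```

For n = 6, fib(2) is evaluated 5 times and fib(1) 8 times, although each of them has only one possible value.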

So, now we know three things about the Fibonacci problem: it is a combinatorial problem, there is optimal substructure, and there are overlapping subproblems. Since dynamic programming trades space for time, we will use extra space to avoid recalculating subproblems.

The first way is to use a cache, which stores the value of fibonacci(n) once it has been calculated. This is called memoization or the top-down approach.

Map<Integer, Integer> cache = new HashMap<>();

public int fibonacci(int n){
    if (n == 0)
       return 0;
    if (n == 1)
        return 1;

    if(cache.containsKey(n))
        return cache.get(n);

    cache.put(n, fibonacci(n - 1) + fibonacci(n - 2));

    return cache.get(n);
}

Another approach is bottom-up, where the smaller problems are solved in an order which helps us with solving the bigger problems. Here also we store results, but in a different way: we fill a table with the solutions of the smaller subproblems first and directly use them to build the solution.

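public int fibonacci(int n){
   if (n < 2)
       return n;

   // fib[i] holds the i-th Fibonacci number, with fib[0] = 0 and fib[1] = 1
   int[] fib = new int[n + 1];
   fib[0] = 0;
   fib[1] = 1;
   for(int i=2; i<=n; i++){
       fib[i] = fib[i-1] + fib[i-2];
   }
   return fib[n];
}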

Above solution requires extra O(n) space, however, the time complexity is also reduced to O(n) with each subproblem solved only once.
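Because each step only reads the previous two values, the table is not strictly needed. As a possible refinement (not part of the solutions above), the sketch below keeps just two variables; the class name is hypothetical, and fibonacci(0) = 0 and fibonacci(1) = 1 are assumed as base cases.

```java
// Illustrative sketch; FibConstantSpace is a hypothetical name, not from the article.
class FibConstantSpace {

    /** Bottom-up Fibonacci that keeps only the last two values instead of a full table. */
    static long fibonacci(int n) {
        if (n < 2) {
            return n; // assumed base cases: fibonacci(0) = 0, fibonacci(1) = 1
        }
        long prev = 0; // fibonacci(i - 2)
        long curr = 1; // fibonacci(i - 1)
        for (int i = 2; i <= n; i++) {
            long next = prev + curr;
            prev = curr;
            curr = next;
        }
        return curr;
    }

    public static void main(String[] args) {
        System.out.println(fibonacci(4)); // 3
        System.out.println(fibonacci(6)); // 8
    }
}
```

This keeps the O(n) running time while using O(1) extra space.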

See the longest increasing subsequence problem for how we applied the same pattern while solving it.

Final thoughts
Where to apply dynamic programming : If your solution is based on optimal substructure and overlapping subproblems, then using the earlier calculated values will be useful, so you do not have to recompute them. It is a bottom-up approach. Suppose you need to calculate fib(n); in that case all you need to do is add the previously calculated values of fib(n-1) and fib(n-2).

Recursion : Basically subdividing your problem into smaller parts to solve it with ease, but keep in mind that it does not avoid recomputation if the same value was already calculated in another recursive call.

Memoization : Basically storing the previously calculated recursion values in a table, which avoids recomputation if a value has already been calculated by some previous call, so any value is calculated only once. Before calculating, we check whether the value has already been calculated; if it has, we return it from the table instead of recomputing it. It is a top-down approach.
Reference: Answer by Endeavour
