Matrix chain multiplication-dynamic programming


What is matrix chain multiplication in general? For background, please refer to Wikipedia. However, today’s problem is not about actually multiplying a chain of matrices, but about finding the optimal order in which to multiply them, so as to minimize the number of scalar multiplications.

To be able to multiply two matrices, it is required that the number of columns in the first matrix is equal to the number of rows of the second matrix.

If we multiply a matrix of dimension M x N by another matrix of dimension N x P, we get a matrix of dimension M x P.


How many scalar multiplications are needed to multiply an M x N matrix by an N x P matrix? It’s M x N x P: the result has M x P entries, and each entry needs N multiplications.
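The M x N x P count can be checked with a small sketch that multiplies two matrices using the schoolbook triple loop and counts each scalar multiplication. The class and method names here are illustrative and not part of the original code.

```java
class ScalarMultiplicationCount {

    // Multiplies a (m x n) by b (n x p) with the schoolbook triple loop,
    // stores the product in result (m x p) and returns how many scalar
    // multiplications were performed.
    static long multiplyAndCount(int[][] a, int[][] b, int[][] result) {
        int m = a.length;
        int n = b.length;      // must equal a[0].length for compatibility
        int p = b[0].length;
        long count = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < p; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += a[i][k] * b[k][j];
                    count++;   // one scalar multiplication
                }
                result[i][j] = sum;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[][] a = new int[2][3];
        int[][] b = new int[3][4];
        int[][] c = new int[2][4];
        // A 2 x 3 matrix times a 3 x 4 matrix: 2 * 3 * 4 multiplications
        System.out.println(multiplyAndCount(a, b, c));
    }
}
```

For a 2 x 3 matrix multiplied by a 3 x 4 matrix, the counter ends at 2 * 3 * 4 = 24.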

Given N matrices and their dimensions, find the optimal way to multiply these matrices, in order to minimize the total number of scalar multiplications.

Matrix chain multiplication: line of thought

Before going further, let’s understand some basics of matrix multiplication:

  1. Matrix multiplication is associative, i.e. A * (B * C) = (A * B) * C
  2. It is not commutative, i.e. A * B is in general not equal to B * A
  3. To multiply two matrices, they should be compatible, i.e. the number of columns in the first matrix should be equal to the number of rows of the second matrix.

Since matrix multiplication is associative, this problem actually reduces to figuring out how to put parentheses around the matrices so that the total number of scalar multiplications is the least.
Let’s take an example and understand. We have matrices A, B, C and D with dimensions 10 x 20, 20 x 30, 30 x 40 and 40 x 50 respectively.

How can we solve this problem manually? The given chain of matrices is ABCD. There are three ways to split the chain into two parts: (A) x (BCD), (AB) x (CD) or (ABC) x (D).

Either way, we now have smaller problems to solve. If we take the first split, the cost of multiplying ABCD is the cost of computing A (zero, since it is a single matrix) + the cost of computing (BCD) + the cost of multiplying A by the result of (BCD).

The same applies to the other two splits. The answer will be the minimum of all three possible splits.

To get the cost of ABCD, we first solve for the cost of BCD. (BCD) can be split into two parts: (B) x (CD) or (BC) x (D).

We will continue with (B) x (CD). The chain (CD) can be split in only one way, (C) x (D), and its cost is simply M x N x P, where C is a matrix of dimension M x N and D is a matrix of dimension N x P.

Similarly, the cost of (BC) is M x N x P, where B is a matrix of dimension M x N and C is a matrix of dimension N x P. The figure below shows the detailed calculation of each possible split; the final answer is the minimum over all possible splits.

[Figure: matrix chain multiplication brute force]
Notice that for N matrices, there are N-1 ways to split the chain. This manual calculation can easily be implemented as a recursive solution.
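
To make the manual calculation concrete, the sketch below evaluates all five complete parenthesizations of the example chain ABCD (10 x 20, 20 x 30, 30 x 40, 40 x 50). The class and method names are illustrative assumptions, and the dimensions are hard-coded on purpose so that each line mirrors one hand calculation.

```java
class ParenthesizationCosts {

    // Cost of multiplying an (m x n) matrix by an (n x p) matrix.
    static int cost(int m, int n, int p) {
        return m * n * p;
    }

    // Costs of the five full parenthesizations of ABCD, in the order
    // ((AB)C)D, (A(BC))D, (AB)(CD), A((BC)D), A(B(CD)).
    static int[] allCosts() {
        // A: 10 x 20, B: 20 x 30, C: 30 x 40, D: 40 x 50
        int ab = cost(10, 20, 30);   // AB is 10 x 30
        int bc = cost(20, 30, 40);   // BC is 20 x 40
        int cd = cost(30, 40, 50);   // CD is 30 x 50
        return new int[] {
            ab + cost(10, 30, 40) + cost(10, 40, 50),  // ((AB)C)D
            bc + cost(10, 20, 40) + cost(10, 40, 50),  // (A(BC))D
            ab + cd + cost(10, 30, 50),                // (AB)(CD)
            bc + cost(20, 40, 50) + cost(10, 20, 50),  // A((BC)D)
            cd + cost(20, 30, 50) + cost(10, 20, 50)   // A(B(CD))
        };
    }

    public static void main(String[] args) {
        for (int c : allCosts()) {
            System.out.println(c);
        }
    }
}
```

The five totals are 38000, 52000, 81000, 74000 and 100000, so for this chain multiplying left to right, ((AB)C)D, is the cheapest order.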

Recursive implementation

package com.company;

/**
 * Created by sangar on 31.12.17.
 */
public class MCM {

    public static int matrixChainMultiplication(int[] P, int i, int j){
        int count = 0;
        int min = Integer.MAX_VALUE;

        System.out.println("("+ i + "," + j + ")");
        if(i==j) return 0; // A chain of a single matrix needs no multiplication

        for(int k=i; k<j; k++){
            System.out.println("Parent : ("+ i + "," + j + ")");
            count = matrixChainMultiplication(P,i, k)
                    + matrixChainMultiplication(P, k+1, j)
                    +   P[i-1]*P[k]*P[j];

            min =  Integer.min(count, min);
        }

        return min;
    }

    public static void main(String[] args) {
        int arr[] = new int[] {1, 2, 3, 4, 3};
        int n = arr.length;

        // Matrices are numbered 1..n-1; matrix i has dimensions arr[i-1] x arr[i]
        System.out.println("Minimum number of multiplications is "+
                matrixChainMultiplication(arr, 1, n-1));
    }
}

In the implementation, the matrices are given as an array P, where matrix i has dimensions P[i-1] x P[i]. A chain of n matrices is therefore described by an array of n+1 numbers.
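
As a small illustration of this representation, the sketch below reads the dimensions of each matrix back out of P. The helper names `dimensionsOf` and `matrixCount` are hypothetical and only used here.

```java
class DimensionArray {

    // Matrix i (numbered from 1) has P[i-1] rows and P[i] columns.
    static String dimensionsOf(int[] P, int i) {
        return P[i - 1] + " x " + P[i];
    }

    // An array of length n describes a chain of n - 1 matrices.
    static int matrixCount(int[] P) {
        return P.length - 1;
    }

    public static void main(String[] args) {
        int[] P = {10, 20, 30, 40, 50};
        for (int i = 1; i <= matrixCount(P); i++) {
            System.out.println("matrix " + i + ": " + dimensionsOf(P, i));
        }
    }
}
```

Adjacent matrices share one entry of P, which is why compatibility (columns of matrix i equal rows of matrix i+1) holds automatically in this encoding.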

However, the time complexity of this implementation of matrix chain multiplication is exponential, and hence it is of no use for very large inputs. Also, if you look at the calculation tree, there are many subproblems that are solved again and again. This is called overlapping subproblems. What can be done to avoid calculating subproblems again and again? Save the results somewhere, using memoization.
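
One way to see the overlap is to count how often the plain recursion is entered and how many distinct (i, j) pairs it ever sees. The sketch below does that; the class name and the counting helper are assumptions made for illustration, and the recursion itself follows the same recurrence as the code above.

```java
import java.util.HashSet;
import java.util.Set;

class RecursionCallCounter {

    private static long calls;
    private static Set<String> distinct;

    // Same recurrence as the plain recursive solution, instrumented to
    // record every call and every distinct (i, j) pair.
    private static int solve(int[] P, int i, int j) {
        calls++;
        distinct.add(i + "," + j);
        if (i == j) return 0;
        int min = Integer.MAX_VALUE;
        for (int k = i; k < j; k++) {
            int count = solve(P, i, k) + solve(P, k + 1, j)
                    + P[i - 1] * P[k] * P[j];
            min = Math.min(min, count);
        }
        return min;
    }

    // Returns {total calls, distinct subproblems} for the chain described by P.
    static long[] countCalls(int[] P) {
        calls = 0;
        distinct = new HashSet<>();
        solve(P, 1, P.length - 1);
        return new long[] {calls, distinct.size()};
    }

    public static void main(String[] args) {
        long[] r = countCalls(new int[] {10, 20, 30, 40, 30});
        System.out.println("calls = " + r[0] + ", distinct subproblems = " + r[1]);
    }
}
```

For a chain of four matrices the recursion is entered 27 times, although only 10 distinct (i, j) pairs exist. The gap widens quickly as the chain grows, which is what makes caching the results worthwhile.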

[Figure: time complexity of matrix chain multiplication]

If we save the cost of multiplying matrices i to j, we can look it up when it is needed again. This technique is called memoization in dynamic programming.
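
A minimal sketch of a memoized (top-down) version of the recursive function is shown below. It assumes P describes at least one matrix (P.length >= 2) and uses -1 as the "not yet computed" marker; the class and method names are illustrative.

```java
import java.util.Arrays;

class MemoizedMCM {

    // Entry point: P has length n and describes matrices 1..n-1.
    // Assumes P.length >= 2.
    static int minCost(int[] P) {
        int n = P.length;
        int[][] memo = new int[n][n];
        for (int[] row : memo) {
            Arrays.fill(row, -1);   // -1 marks "not computed yet"
        }
        return solve(P, 1, n - 1, memo);
    }

    private static int solve(int[] P, int i, int j, int[][] memo) {
        if (i == j) return 0;                      // single matrix, no cost
        if (memo[i][j] != -1) return memo[i][j];   // reuse the saved result

        int min = Integer.MAX_VALUE;
        for (int k = i; k < j; k++) {
            int count = solve(P, i, k, memo)
                    + solve(P, k + 1, j, memo)
                    + P[i - 1] * P[k] * P[j];
            min = Math.min(min, count);
        }
        memo[i][j] = min;                          // save before returning
        return min;
    }

    public static void main(String[] args) {
        System.out.println(minCost(new int[] {1, 2, 3, 4, 3}));
    }
}
```

Each (i, j) pair is now solved once and then served from the memo table, so the work is bounded by the number of table entries times the number of split points tried per entry.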

The cost of multiplying matrices Ai to Aj, when the chain is split after Ak, is

Cost(Ai, Aj) = Cost(Ai, Ak) + Cost(Ak+1, Aj) + (P[i-1] * P[k] * P[j])

The idea is to find the k for which Cost(Ai, Aj) becomes minimum. If M[i,j] represents the cost to multiply matrix i through matrix j, then for a given split k,

M[i,j] = M[i,k] + M[k+1,j] + (P[i-1] * P[k] * P[j])

When calculating M[i,j], the values M[i,k] and M[k+1,j] should already be available; this is called bottom-up filling of a matrix. Also, M[i,i] = 0, as the cost of multiplying a single matrix is 0.

Since we are taking the bottom-up approach to fill our solution matrix, we start by grouping two matrices at a time, then three, then four, until we reach the full chain of N matrices. We start with length L = 2 and go on to solve the table entry for length N. M[1, N] then gives us the final cost.

To find the minimum cost(i,j), we need to find the k, with i <= k < j, for which the expression

Cost(Ai, Aj) = Cost(Ai, Ak) + Cost(Ak+1, Aj) + (P[i-1] * P[k] * P[j])

becomes minimum; hence, for every such k, we update

M[i,j] = min(M[i,j], M[i,k] + M[k+1,j] + P[i-1] * P[k] * P[j])

Matrix chain multiplication leetcode implementation

package com.company;

/**
 * Created by sangar on 31.12.17.
 */
public class MCM {

    public  static  int matriChainMultiplicationDP(int[] P){
        int n = P.length;

        int[][] M = new int[n][n];

        for(int i=0; i<n; i++){
            for(int j=0; j<n; j++){
                M[i][j] = 0;
            }
        }

        for(int L=2; L<n; L++){
            /* For every position i, we check every chain of len L */
            for(int i=1; i<n-L+1; i++){
                int j = i+L-1;
                M[i][j] = Integer.MAX_VALUE;

                /* For matrix i to j, check every split K */
                for(int k=i; k<j; k++){
                    int temp = M[i][k] + M[k+1][j] + P[i-1] * P[k] * P[j];
                    /* Check if the current count is less than minimum */
                    M[i][j] = Integer.min(temp, M[i][j]);
                }
            }
        }

        return M[1][n-1];
    }
    public static void main(String[] args) {
        int arr[] = new int[] {1, 2, 3, 4, 3};
        int n = arr.length;

        System.out.println("Minimum number of multiplications is "+
                matriChainMultiplicationDP(arr));
    }
}

Let’s run through an example to understand how this code works.

P = {10, 20,30,40,30}, 
dimensions of matrix [1] = 10X20, 
dimensions of matrix [2] = 20X30, 
dimensions of matrix [3] = 30X40,
dimensions of matrix [4] = 40X30

The important thing to understand here is what M[i][j] represents. In our case, M[i][j] represents the minimum cost to multiply the chain of matrices from matrix i to matrix j.
With this representation, we can safely say that M[i][i] is 0, as there is no cost to multiply only one matrix.

Start with the outer for loop at L = 2, that is, with chains of two matrices, and then move on to longer chains.
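
The order in which the table is filled can be traced with the sketch below, which runs the same three loops as the code above and prints each entry once it is final. The class name and the print format are assumptions for illustration.

```java
class TableTrace {

    // Fills the cost table bottom-up and prints every entry as soon as
    // its minimum over all splits is known.
    static int[][] fill(int[] P) {
        int n = P.length;
        int[][] M = new int[n][n];   // M[i][i] stays 0
        for (int L = 2; L < n; L++) {
            for (int i = 1; i < n - L + 1; i++) {
                int j = i + L - 1;
                M[i][j] = Integer.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    int temp = M[i][k] + M[k + 1][j] + P[i - 1] * P[k] * P[j];
                    M[i][j] = Math.min(temp, M[i][j]);
                }
                System.out.println("L=" + L + "  M[" + i + "][" + j + "] = " + M[i][j]);
            }
        }
        return M;
    }

    public static void main(String[] args) {
        fill(new int[] {10, 20, 30, 40, 30});
    }
}
```

For P = {10, 20, 30, 40, 30} the trace is:

L=2  M[1][2] = 6000
L=2  M[2][3] = 24000
L=2  M[3][4] = 36000
L=3  M[1][3] = 18000
L=3  M[2][4] = 48000
L=4  M[1][4] = 30000

For example, M[1][3] is the smaller of 0 + 24000 + 10 x 20 x 40 = 32000 (split after matrix 1) and 6000 + 0 + 10 x 30 x 40 = 18000 (split after matrix 2).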


M[1][N-1], where N is the length of the array P (so the chain has N-1 matrices), is the solution to the matrix chain multiplication problem.
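
The table M only gives the minimum cost. If the actual order of multiplication is also wanted, one common extension is to remember the best split point for every (i, j) and rebuild the parenthesization from it. The sketch below is such an extension and is not part of the original code; the split table S and the names used are assumptions.

```java
class OptimalParenthesization {

    // Returns the cheapest parenthesization of the chain described by P,
    // with matrices written as A1, A2, ...
    static String solve(int[] P) {
        int n = P.length;
        int[][] M = new int[n][n];   // minimum cost for matrices i..j
        int[][] S = new int[n][n];   // best split point k for matrices i..j
        for (int L = 2; L < n; L++) {
            for (int i = 1; i < n - L + 1; i++) {
                int j = i + L - 1;
                M[i][j] = Integer.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    int temp = M[i][k] + M[k + 1][j] + P[i - 1] * P[k] * P[j];
                    if (temp < M[i][j]) {
                        M[i][j] = temp;
                        S[i][j] = k;   // remember where the chain was split
                    }
                }
            }
        }
        return build(S, 1, n - 1);
    }

    // Rebuilds the expression for matrices i..j from the split table.
    private static String build(int[][] S, int i, int j) {
        if (i == j) return "A" + i;
        return "(" + build(S, i, S[i][j]) + " x " + build(S, S[i][j] + 1, j) + ")";
    }

    public static void main(String[] args) {
        System.out.println(solve(new int[] {10, 20, 30, 40, 30}));
    }
}
```

For P = {10, 20, 30, 40, 30} this prints (((A1 x A2) x A3) x A4), which matches the cost of 30000 found in the table.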

The time complexity of matrix chain multiplication using dynamic programming is O(n^3): there are O(n^2) table entries, and each entry tries up to n split points. The space complexity is O(n^2).

Reference
http://www.personal.kent.edu/~rmuhamma/Algorithms/MyAlgorithms/Dynamic/chainMatrixMult.htm
