Example: Matrix Mul - rFronteddu/general_wiki GitHub Wiki

Given a chain of matrices, determine the order (parenthesization) in which to multiply them so that the total number of scalar multiplications is minimized.
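The order matters because multiplying a $p \times q$ matrix by a $q \times r$ matrix with the standard algorithm takes $p \cdot q \cdot r$ scalar multiplications. The following is a minimal sketch comparing the two possible orders for three matrices. The dimensions (10 × 30, 30 × 5, 5 × 60) and the names `ParenthesizationCost`, `mulCost`, `leftFirst` and `rightFirst` are illustrative assumptions, not part of the original page.

```java
// Compares the two ways of parenthesizing A1 * A2 * A3.
// Illustrative dimensions (assumption): A1 is 10x30, A2 is 30x5, A3 is 5x60.
class ParenthesizationCost {
    // Scalar multiplications for a (p x q) * (q x r) product with the standard algorithm
    static long mulCost(long p, long q, long r) {
        return p * q * r;
    }

    // (A1 A2) A3: 10x30 * 30x5 gives a 10x5 matrix, then 10x5 * 5x60
    static long leftFirst() {
        return mulCost(10, 30, 5) + mulCost(10, 5, 60);
    }

    // A1 (A2 A3): 30x5 * 5x60 gives a 30x60 matrix, then 10x30 * 30x60
    static long rightFirst() {
        return mulCost(30, 5, 60) + mulCost(10, 30, 60);
    }

    public static void main(String[] args) {
        System.out.println("(A1 A2) A3: " + leftFirst());  // 4500
        System.out.println("A1 (A2 A3): " + rightFirst()); // 27000
    }
}
```

Both orders produce the same 10 × 60 result, but the first one needs six times fewer scalar multiplications.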

Explanation

  • Assume a chain of matrices $A_1, A_2, \ldots, A_n$, where matrix $A_i$ has dimensions $p_{i-1} \times p_i$. The goal is to find the most efficient way to multiply these matrices together, i.e. the parenthesization that needs the fewest scalar multiplications.

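  • Multiplying a $p \times q$ matrix by a $q \times r$ matrix takes $p \cdot q \cdot r$ scalar multiplications, so the cost of each multiplication in the chain depends on:

    • The number of rows of the left operand, i.e. of the first matrix in its subchain ($p$).
    • The number of columns of the right operand, i.e. of the last matrix in its subchain ($r$).
    • The shared dimension between the two matrices being multiplied ($q$).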
  • dims is typically an array of $n+1$ dimensions, where dims[i-1] is the number of rows of matrix $A_i$ and dims[i] is its number of columns (the pseudocode below calls this array p). Adjacent matrices share a dimension, so one array describes the whole chain.

  • For 3 matrices $A_1$, $A_2$, $A_3$:

    • $A_1$ has dimensions $p_0 \times p_1$
    • $A_2$ has dimensions $p_1 \times p_2$
    • $A_3$ has dimensions $p_2 \times p_3$
  • Let $M[i,j]$ be the matrix that results from evaluating the product $A_i A_{i+1} \cdots A_j$ (indices i to j, inclusive).

    • If $i < j$, we need to split the product between $M[i,k]$ and $M[k+1,j]$ for some $i \le k < j$.
    • That is, we first compute $M[i,k]$ and $M[k+1,j]$ and then multiply them together to produce $M[i,j]$.
    • The cost of this multiplication is the cost of computing $M[i,k]$, plus the cost of computing $M[k+1,j]$, plus the cost of multiplying the two together. Since $M[i,k]$ is $p_{i-1} \times p_k$ and $M[k+1,j]$ is $p_k \times p_j$, the final product takes $p_{i-1} p_k p_j$ scalar multiplications.
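  • Observe (optimal substructure) that if an optimal parenthesization of $M[i,j]$ splits at k, then $M[i,k]$ and $M[k+1,j]$ must themselves be parenthesized optimally; otherwise, replacing either with a cheaper parenthesization would lower the total cost, contradicting the optimality of the split at k.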

  • We define our recursive problem as determining the minimum cost of parenthesizing $M[i,j]$ for $1 \leq i \leq j \leq n$.

    • Let $m[i,j]$ be the minimum number of scalar multiplications needed to compute $M[i,j]$.
    • The answer to the full problem, $M[1,n] = A_1 A_2 \cdots A_n$, is therefore $m[1,n]$.
  • We can define $m[i,j]$ recursively as follows:

    • If $i = j$, $m[i,j] = 0$, since there is only one matrix and no product to compute.
    • If $i < j$, $m[i,j] = \min_{i \le k < j} \left( m[i,k] + m[k+1,j] + p_{i-1} \cdot p_k \cdot p_j \right)$.
    • Here $M[i,j]$ has dimensions $p_{i-1} \times p_j$. Only $j - i$ values of k need to be evaluated, namely $i, i+1, \ldots, j-1$, since $i \le k < j$.
    • To rebuild the solution, we also record each optimal split in a table: let $s[i,j]$ be the value of k that achieves $m[i,j]$. Because the cost of a chain of $j-i+1$ matrices depends only on subchains with fewer than $j-i+1$ matrices, both tables can be filled bottom-up in order of increasing chain length.
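The recurrence for $m[i,j]$ can also be evaluated top-down: recurse exactly as the definition says and cache each result so that every subproblem is solved only once. The sketch below is such a memoized version, a swap-in technique rather than the page's own bottom-up method. It assumes the flat dimension array `p` of length $n+1$; the class and method names (`MatrixChainMemo`, `m`, `minCost`) are hypothetical.

```java
import java.util.Arrays;

// Top-down (memoized) evaluation of
// m[i,j] = 0 if i == j, else min over i <= k < j of m[i,k] + m[k+1,j] + p[i-1]*p[k]*p[j]
class MatrixChainMemo {
    private final int[] p;       // A_i has dimensions p[i-1] x p[i]
    private final long[][] memo; // memo[i][j] caches m[i,j]; -1 means "not computed yet"

    MatrixChainMemo(int[] p) {
        this.p = p;
        int n = p.length - 1; // number of matrices
        memo = new long[n + 1][n + 1];
        for (long[] row : memo) {
            Arrays.fill(row, -1);
        }
    }

    // Minimum number of scalar multiplications for A_i ... A_j (1-based, inclusive)
    long m(int i, int j) {
        if (i == j) {
            return 0; // a single matrix needs no multiplication
        }
        if (memo[i][j] >= 0) {
            return memo[i][j]; // already solved
        }
        long best = Long.MAX_VALUE;
        for (int k = i; k < j; k++) { // the j - i candidate split points
            long q = m(i, k) + m(k + 1, j) + (long) p[i - 1] * p[k] * p[j];
            best = Math.min(best, q);
        }
        memo[i][j] = best;
        return best;
    }

    // Cost of the whole chain, m[1,n]
    static long minCost(int[] p) {
        return new MatrixChainMemo(p).m(1, p.length - 1);
    }

    public static void main(String[] args) {
        // A1: 10x30, A2: 30x5, A3: 5x60
        System.out.println(minCost(new int[] {10, 30, 5, 60})); // 4500
    }
}
```

Both this version and the bottom-up table solve each of the $O(n^2)$ subproblems once with $O(n)$ work each; the top-down form simply follows the recurrence more literally.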
```
// p[i-1] is the number of rows of the first matrix in the current subproblem (A_i).
// p[k] is the number of columns of A_k (the shared dimension at split k).
// p[j] is the number of columns of the last matrix (A_j) in the subproblem.
matrixMul(p)
    // number of matrices (p holds n + 1 dimensions)
    n = p.len - 1
    let m[1..n, 1..n] and s[1..n, 1..n] be new tables

    // a single matrix needs no multiplication
    for i = 1 to n
        m[i, i] = 0

    // chains of increasing length
    for len = 2 to n
        // i is the starting index of the chain
        for i = 1 to n - len + 1
            // j is the last index of the chain starting at i
            j = i + len - 1
            m[i, j] = ∞
            // try every split point between i and j
            for k = i to j - 1
                q = m[i, k] + m[k + 1, j] + p[i - 1] * p[k] * p[j]
                if q < m[i, j]
                    m[i, j] = q
                    s[i, j] = k
    return m and s
```
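The bottom-up procedure can be rendered in Java roughly as follows. This is a sketch under a few assumptions: it takes the flat array `p` with $n = $ `p.length - 1` matrices, sizes the tables $n+1$ so the 1-based indices can be kept, and uses `long` for costs to reduce the risk of overflow. The names `MatrixChainBottomUp` and `Tables` are hypothetical.

```java
// Bottom-up matrix-chain DP: fills m and s by increasing chain length.
// Tables are sized n + 1 so that the 1-based indices of the pseudocode can be used unchanged.
class MatrixChainBottomUp {
    // Holds the two tables the procedure returns
    static final class Tables {
        final long[][] m; // m[i][j]: minimum scalar multiplications for A_i..A_j
        final int[][] s;  // s[i][j]: split k that achieves m[i][j]

        Tables(long[][] m, int[][] s) {
            this.m = m;
            this.s = s;
        }
    }

    static Tables matrixMul(int[] p) {
        int n = p.length - 1;                // number of matrices
        long[][] m = new long[n + 1][n + 1]; // m[i][i] = 0 by default
        int[][] s = new int[n + 1][n + 1];

        for (int len = 2; len <= n; len++) {         // chain length
            for (int i = 1; i <= n - len + 1; i++) { // first matrix of the chain
                int j = i + len - 1;                 // last matrix of the chain
                m[i][j] = Long.MAX_VALUE;
                for (int k = i; k < j; k++) {        // candidate split points
                    long q = m[i][k] + m[k + 1][j] + (long) p[i - 1] * p[k] * p[j];
                    if (q < m[i][j]) {
                        m[i][j] = q;
                        s[i][j] = k;
                    }
                }
            }
        }
        return new Tables(m, s);
    }

    public static void main(String[] args) {
        Tables t = matrixMul(new int[] {10, 30, 5, 60});
        System.out.println(t.m[1][3] + " split at k = " + t.s[1][3]); // 4500 split at k = 2
    }
}
```

Because the strict `<` comparison keeps the first minimum found, ties are broken in favor of the smallest k.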

  • To print the optimal parenthesization, observe that the last multiplication performed (the outermost split) is $A_{1..n} = A_{1..s[1,n]} \cdot A_{s[1,n]+1..n}$.
  • We can then apply the same reasoning recursively to each half:
```
printMul(s, i, j)
    if i == j
        // a single matrix: print its name with its index, e.g. A3
        print "A" + i
    else
        print "("
        printMul(s, i, s[i, j])
        printMul(s, s[i, j] + 1, j)
        print ")"

// print the whole chain
printMul(s, 1, n)
```
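A Java sketch of the same recursion can build the parenthesization as a string instead of printing it piece by piece, which also labels each matrix with its index and makes the result easy to check. It assumes `s` uses the 1-based layout described above (`s[i][j]` is the best split for $A_i \cdots A_j$); the names `Parenthesize` and `parenthesize` are hypothetical, and the split table in `main` is the one the DP produces for 10 × 30, 30 × 5, 5 × 60.

```java
// Rebuilds the optimal parenthesization from the split table s (1-based indices).
class Parenthesize {
    static String parenthesize(int[][] s, int i, int j) {
        if (i == j) {
            return "A" + i; // a single matrix
        }
        int k = s[i][j]; // outermost split: (A_i..A_k)(A_k+1..A_j)
        return "(" + parenthesize(s, i, k) + parenthesize(s, k + 1, j) + ")";
    }

    public static void main(String[] args) {
        // Split table produced by the DP for A1: 10x30, A2: 30x5, A3: 5x60
        int[][] s = new int[4][4];
        s[1][2] = 1;
        s[2][3] = 2;
        s[1][3] = 2;
        System.out.println(parenthesize(s, 1, 3)); // ((A1A2)A3)
    }
}
```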

Java Implementation

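```java
public class MatrixChain {
    // dims[i] = {rows, cols} of matrix A_(i+1), so dims[m - 1] describes A_m (1-based m).
    // Returns the minimum number of scalar multiplications needed to compute A_1 * ... * A_n.
    public static int matrixChain(int[][] dims) {
        int n = dims.length; // Number of matrices
        if (n == 0) {
            return 0;
        }
        // The product is only defined if cols(A_m) == rows(A_(m+1)) for every m
        for (int m = 0; m + 1 < n; m++) {
            if (dims[m][1] != dims[m + 1][0]) {
                throw new IllegalArgumentException("Matrices " + (m + 1) + " and " + (m + 2) + " are incompatible");
            }
        }

        // dp[i][j] = minimum cost of computing A_i..A_j; dp[i][i] = 0 by default
        int[][] dp = new int[n + 1][n + 1];
        int[][] splitPoints = new int[n + 1][n + 1]; // Optimal split k for each (i, j)

        // Fill the table by increasing chain length
        for (int len = 2; len <= n; len++) {
            for (int i = 1; i <= n - len + 1; i++) {
                int j = i + len - 1;
                dp[i][j] = Integer.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    // rows(A_i) * cols(A_k) * cols(A_j), where dims[m - 1] describes A_m
                    int cost = dp[i][k] + dp[k + 1][j] + dims[i - 1][0] * dims[k - 1][1] * dims[j - 1][1];
                    if (cost < dp[i][j]) {
                        dp[i][j] = cost;
                        splitPoints[i][j] = k; // Store split point
                    }
                }
            }
        }

        // Print matrices and split points
        printMatrices(splitPoints, 1, n, dims);

        // The minimum cost is stored in dp[1][n]
        return dp[1][n];
    }

    // Recursive function to print matrices and split points
    private static void printMatrices(int[][] splitPoints, int i, int j, int[][] dims) {
        if (i == j) {
            System.out.println("Matrix " + i + ": [" + dims[i - 1][0] + " x " + dims[i - 1][1] + "]");
        } else {
            int k = splitPoints[i][j];
            System.out.println("Multiplying matrices " + i + " to " + k + " and " + (k + 1) + " to " + j);
            printMatrices(splitPoints, i, k, dims);
            printMatrices(splitPoints, k + 1, j, dims);
        }
    }

    public static void main(String[] args) {
        // A_1: 10 x 30, A_2: 30 x 5, A_3: 5 x 60
        int[][] dims = {{10, 30}, {30, 5}, {5, 60}};
        System.out.println("Minimum cost: " + matrixChain(dims)); // Minimum cost: 4500
    }
}
```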
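The Java class takes `dims` as an array of `{rows, cols}` pairs (one per matrix), while the pseudocode takes the flat array `p`. The hypothetical helper below, which is not part of the original page, converts between the two layouts and rejects chains whose adjacent dimensions do not match, since the product is undefined in that case. The names `DimsConverter`, `toPairs` and `toFlat` are assumptions for illustration.

```java
import java.util.Arrays;

// Hypothetical helpers: convert between the flat dimension array p (length n + 1)
// and an array of {rows, cols} pairs, where pairs[i] describes A_(i+1).
class DimsConverter {
    // {p0, p1, ..., pn} -> {{p0, p1}, {p1, p2}, ..., {p(n-1), pn}}
    static int[][] toPairs(int[] p) {
        if (p.length < 2) {
            throw new IllegalArgumentException("need at least one matrix");
        }
        int[][] pairs = new int[p.length - 1][];
        for (int i = 0; i < pairs.length; i++) {
            pairs[i] = new int[] {p[i], p[i + 1]};
        }
        return pairs;
    }

    // Inverse of toPairs; throws if cols(A_i) != rows(A_(i+1)) for some i
    static int[] toFlat(int[][] pairs) {
        if (pairs.length == 0) {
            throw new IllegalArgumentException("need at least one matrix");
        }
        int[] p = new int[pairs.length + 1];
        p[0] = pairs[0][0];
        for (int i = 0; i < pairs.length; i++) {
            if (pairs[i][0] != p[i]) {
                throw new IllegalArgumentException("matrix " + (i + 1) + " has " + pairs[i][0]
                        + " rows but the previous matrix has " + p[i] + " columns");
            }
            p[i + 1] = pairs[i][1];
        }
        return p;
    }

    public static void main(String[] args) {
        int[][] pairs = toPairs(new int[] {10, 30, 5, 60});
        System.out.println(Arrays.deepToString(pairs));     // [[10, 30], [30, 5], [5, 60]]
        System.out.println(Arrays.toString(toFlat(pairs))); // [10, 30, 5, 60]
    }
}
```

Storing only the flat array makes incompatible chains unrepresentable, which is one reason the pseudocode uses that layout.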