矩阵链相乘

问题:n个矩阵M1M2……Mn链乘法的耗费,取决于n-1个乘法执行的顺序。现在要找到其最小次数的乘法执行顺序。

如,n = 4,有以下可能的顺序:

(M1(M2(M3M4)))

(M1((M2M3)M4))

((M1M2)(M3M4))

(((M1M2)M3)M4)

((M1(M2M3))M4)

如果使用暴力算法枚举,时间复杂度为O(2n)。运用动态规划技术,可以将时间复杂度降为O(n3).

 

 

算法基本思路:

1.得到最小乘法次数

(1)将矩阵序列在其中找出某个k位置,将其分为两部分;

分别求出这两部分的最优解,然后再将其加上这两部分相乘的耗费,就是用当前划分所需要的乘法次数;

依次枚举按k划分,比较并取最小乘法次数。

(2)从上面的思路可以知道构造最优解的方法,但如果自顶向下必然会导致很多重复运算(类比斐波拉契数列),因此采用自底向下的打表方法,逐步求出最优解。

 

2.构造矩阵乘法序列

在上面算法的基础上,每次获取最优解的时候都标记当前k的位置;然后将序列从最终结果处逐步按 k 回溯,并划分(即添加括号)。

 

 

算法具体实现:

输入:r[0...n],     r[0...n-1]表示M0M1……Mn-1的行数,r[n]表示Mn的列数。因为相邻矩阵的行列相等,所以此数组即可包含所有数据。

1.最小乘法次数

构造最优表:c[0...n-1][0...n-1],    c[i][j]表示 MiMi+1……M的乘法次数

枚举k(i < k <= j),C[i][j] = min( c[i][k-1] + c[k][j] + r[i]*r[k]*r[j+1] )

其中:

    c[i][k-1]表示 MiMi+1……Mk-1 的乘法次数

    c[k][j]表示 MkMk+1……M的乘法次数

    r[i] = M的行数

    r[k] = M的行数

    r[j+1] = M的列数

再利用 t[i][j] 保存当前最优的k,不难看出,k 就是那个序列的划分。

最终结果 c[0][n-1] = min( c[0][k-1] + c[k][n-1] + r[0]*r[k]*r[n] ),其中 0 < k <= n

 

2.构造最优解

设 i = 0, j = n-1 即c[i][j] 为最小乘法次数

要输出当前的序列,利用t[i][j]划分序列,输出它们的序列即可;一直划分到仅剩下一个元素即 i = j.(递归)

 

实现代码:

/*
input:r[0..n];            //r[0..n-1] = the number of rows of M0...Mn-1, r[n] = the number of columns of Mn-1
because the number of columns of Mi is equal to the number of rows of Mi+1, so enough information

table:c[0..n-1][0..n-1]        //c[i][j] = the number of multiplications of (Mi * Mi+1 *...* Mj)
    
enumerate k, i< k <=j, c[i][j] = min(c[i][k-1] + c[k][j] + r[i]*r[k]*r[j+1])
because Mi * Mi+1 *...* Mj = (Mi * Mi+1 *...* Mk-1)*(Mk * Mk+1 *...* Mj)
        c[i][k-1] = the number of multiplications of (Mi * Mi+1 *...* Mk-1)
        c[k][j] = the number of multiplications of (Mk * Mk+1 *...* Mj)
        r[i] = the number of rows of Mi
        r[k] = the number of rows of Mk
        r[j+1] = the number of columns of Mj
*/

public class MatrixChain {
    
    public static String getMinMatrixChain(int[] r){
        //return minimum number of multiplications
        int[][] c = new int[r.length-1][r.length-1];
        int[][] t = new int[r.length-1][r.length-1];
        //init diagonal0
        for(int i = 0; i < c.length; i ++){
            c[i][i] = 0;
            t[i][i] = -1;
        }
        //computer diagonal1..n-1
        for(int d = 1; d < c.length; d ++){
            for(int i = 0; i < c.length-d; i ++){
                int j = i + d;
                //compute c[i][j]
                c[i][j] = Integer.MAX_VALUE;
                for(int k = i+1; k <= j; k ++){
                    int temp = c[i][k-1] + c[k][j] + r[i]*r[k]*r[j+1];
                    if(temp < c[i][j]){
                        c[i][j] = temp;
                        t[i][j] = k;
                    }
                }
            }
        }
//        for(int i = 0; i < c.length; i ++){
//            for(int j = 0; j < c.length; j ++){
//                System.out.print(c[i][j] + "\t");
//            }
//            System.out.println();
//        }
        return c[0][c.length-1] + "/" + getSequence(t, 0, c.length-1);
    }

    private static String getSequence(int[][] t, int i, int j){
        //return the matrix sequence of mul
        if(i == j)                //atom
            return "M" + i;
        //t[i][j] = the place of "("
        return "(" + getSequence(t, i, t[i][j]-1) + getSequence(t, t[i][j], j) + ")";
    }

    public static void main(String[] args) {
        int[] r = {4, 5, 3, 6, 4, 5};
        String[] result = getMinMatrixChain(r).split("/");
        System.out.println(result[0] + ": " + result[1]);
    }

}
Java

 

posted @ 2013-11-17 20:16  7hat  阅读(821)  评论(0编辑  收藏  举报