矩阵链相乘
问题:n个矩阵M1M2……Mn链乘法的耗费,取决于n-1个乘法执行的顺序。现在要找到其最小次数的乘法执行顺序。
如,n = 4,有以下可能的顺序:
(M1(M2(M3M4)))
(M1((M2M3)M4))
((M1M2)(M3M4))
(((M1M2)M3)M4)
((M1(M2M3))M4)
如果使用暴力算法枚举,时间复杂度为O(2n)。运用动态规划技术,可以将时间复杂度降为O(n3).
算法基本思路:
1.得到最小乘法次数
(1)将矩阵序列在其中找出某个k位置,将其分为两部分;
分别求出这两部分的最优解,然后再将其加上这两部分相乘的耗费,就是用当前划分所需要的乘法次数;
依次枚举按k划分,比较并取最小乘法次数。
(2)从上面的思路可以知道构造最优解的方法,但如果自顶向下必然会导致很多重复运算(类比斐波拉契数列),因此采用自底向下的打表方法,逐步求出最优解。
2.构造矩阵乘法序列
在上面算法的基础上,每次获取最优解的时候都标记当前k的位置;然后将序列从最终结果处逐步按 k 回溯,并划分(即添加括号)。
算法具体实现:
输入:r[0...n], r[0...n-1]表示M0M1……Mn-1的行数,r[n]表示Mn的列数。因为相邻矩阵的行列相等,所以此数组即可包含所有数据。
1.最小乘法次数
构造最优表:c[0...n-1][0...n-1], c[i][j]表示 MiMi+1……Mj 的乘法次数
枚举k(i < k <= j),C[i][j] = min( c[i][k-1] + c[k][j] + r[i]*r[k]*r[j+1] )
其中:
c[i][k-1]表示 MiMi+1……Mk-1 的乘法次数
c[k][j]表示 MkMk+1……Mj 的乘法次数
r[i] = Mi 的行数
r[k] = Mk 的行数
r[j+1] = Mj 的列数
再利用 t[i][j] 保存当前最优的k,不难看出,k 就是那个序列的划分。
最终结果 c[0][n-1] = min( c[0][k-1] + c[k][n-1] + r[0]*r[k]*r[n] ),其中 0 < k <= n
2.构造最优解
设 i = 0, j = n-1 即c[i][j] 为最小乘法次数
要输出当前的序列,利用t[i][j]划分序列,输出它们的序列即可;一直划分到仅剩下一个元素即 i = j.(递归)
实现代码:
/* input:r[0..n]; //r[0..n-1] = the number of rows of M0...Mn-1, r[n] = the number of columns of Mn-1 because the number of columns of Mi is equal to the number of rows of Mi+1, so enough information table:c[0..n-1][0..n-1] //c[i][j] = the number of multiplications of (Mi * Mi+1 *...* Mj) enumerate k, i< k <=j, c[i][j] = min(c[i][k-1] + c[k][j] + r[i]*r[k]*r[j+1]) because Mi * Mi+1 *...* Mj = (Mi * Mi+1 *...* Mk-1)*(Mk * Mk+1 *...* Mj) c[i][k-1] = the number of multiplications of (Mi * Mi+1 *...* Mk-1) c[k][j] = the number of multiplications of (Mk * Mk+1 *...* Mj) r[i] = the number of rows of Mi r[k] = the number of rows of Mk r[j+1] = the number of columns of Mj */ public class MatrixChain { public static String getMinMatrixChain(int[] r){ //return minimum number of multiplications int[][] c = new int[r.length-1][r.length-1]; int[][] t = new int[r.length-1][r.length-1]; //init diagonal0 for(int i = 0; i < c.length; i ++){ c[i][i] = 0; t[i][i] = -1; } //computer diagonal1..n-1 for(int d = 1; d < c.length; d ++){ for(int i = 0; i < c.length-d; i ++){ int j = i + d; //compute c[i][j] c[i][j] = Integer.MAX_VALUE; for(int k = i+1; k <= j; k ++){ int temp = c[i][k-1] + c[k][j] + r[i]*r[k]*r[j+1]; if(temp < c[i][j]){ c[i][j] = temp; t[i][j] = k; } } } } // for(int i = 0; i < c.length; i ++){ // for(int j = 0; j < c.length; j ++){ // System.out.print(c[i][j] + "\t"); // } // System.out.println(); // } return c[0][c.length-1] + "/" + getSequence(t, 0, c.length-1); } private static String getSequence(int[][] t, int i, int j){ //return the matrix sequence of mul if(i == j) //atom return "M" + i; //t[i][j] = the place of "(" return "(" + getSequence(t, i, t[i][j]-1) + getSequence(t, t[i][j], j) + ")"; } public static void main(String[] args) { int[] r = {4, 5, 3, 6, 4, 5}; String[] result = getMinMatrixChain(r).split("/"); System.out.println(result[0] + ": " + result[1]); } }