最优矩阵链乘
什么是矩阵链乘?
我们都学过矩阵的乘法。矩阵的乘法不满足分配率,但是满足结合律,因此矩阵\((A×B)×C\)和\(A×(B×C)\)的结果是一样的,但是中间的运算量可能是不同的。
比如三个矩阵\(A=2\times3\)、\(B=3\times4\)、\(C=4\times5\),则\((A \times B)\times C\)需要运算\(2\times3\times4+2\times4\times5=64\)次,而\(A\times(B \times C)\)的运算量是\(3\times5\times5+2\times3\times5=90\)次。显然第一次的计算方法比较节约计算量
最优矩阵链乘,就是在一堆矩阵相乘的情况下,找出需要运算次数最少的那个方法。通俗的说,就是给矩阵乘法加括号
如何求得最优矩阵链乘?
我们假设有\(n\)个矩阵相乘:
\(A_1 \times A_2 \times A_3 \times ...... \times A_n\)
我们可以先将这n个矩阵分为两部分\(P\)和\(Q\):
\(P=A_1 \times A_2 \times A_3 \times ...... \times A_k\)
\(Q=A_{k+1} \times A_{k+2} \times A_{k+3} \times ...... \times A_n\)
其中\(1 \leq k<n\)。矩阵链乘的结果就可以用\(P\times Q\)来表示了,且答案为\(k\)从\(1\)到\(n-1\)循环一遍,运算结果最小的那个。状态转移方程如下:
\(f(i,j)=min \{ f(i,k) + f(k+1,j) + P_{i-1}P_{k}P_{j} \}, i\leq k<j\)
其中,\(f(i,j)\)表示第\(i\)个矩阵到第\(j\)个矩阵的运算量,\(k\)为中间乘号的位置,\(P\)保存矩阵的行和列,第\(i\)行矩阵的行为\(P_{i-1}\),列为\(P_i\)。
我们来用图表演示一下这个过程。假设现在有4个矩阵\(A_{1}\)~\(A_{6}\),其输入如下:
\(p=\{3,5,2,1,10\}\)
输入内容表示的含义:
矩阵 | \(A_1\) | \(A_2\) | \(A_3\) | \(A_4\) |
---|---|---|---|---|
行\(\times\)列 | \(3\times5\) | \(5\times2\) | \(2\times1\) | \(1\times10\) |
P的下标 | 0 | 1 | 2 | 3|4 |
P值 | 3 | 5 | 2 | 1|10 |
计算过程
标着的数字是计算的顺序。在花括号里面部分的就是上文所说的\(P\),花括号外面的部分就是\(Q\)
下面开始写代码
首先,根据输入的数组构建一个二维数组,这个二维数组是用来存放计算结果的:
int n=arr.length-1;
//为了让数组的指针能够和上图对应起来,二维数组的大小设置为n+1
int[][] m=int[n+1][n+1];
结合上面的例子来讲解。输入的数组一共有5个元素,表示一共有4个矩阵。所以构造的二维数组就是5*5的。其中m[1][2]
存放的是\(A1 \times A2\)的最终结果;同理,m[1][3]
存放的是\(A1 \times A2 \times A3\)的最终结果————虽然在计算\(A1 \times A2 \times A3\)的过程中有两种可能,但我们只取结果最小的那一个作为最终结果储存在数组中。
接着,我们用一个指针i
来标识二维数组中的行
用指针j
标识二维数组中的列
用t
表示链乘的矩阵的长度:如\(A1 \times A2 \times A3\)的长度为3,\(A1 \times A2\)的长度为2
用k
来标识乘号的位置,比如上图表1行4列中的第二个式子中k=2
n
、i
、t
之间的关系为:
i<=n-t+1;
i
、j
、t
之间的关系为:
j=i+t-1;
k
、j
之间的关系为:
k<=j-1
根据t
的大小,我们可以写一个循环:
int i,j,k,t;
for(t=2;t<=n;t++){
for(i=1;i<=n-t+1;i++){
j=i+t-1;
m[i][j]=Integer.MAX_VALUE
for(k=i;i<=j-1;k++){
int temp=m[i][k]+m[k+1][j]+a[i-1]*a[k]*a[j];
if(temp<m[i][j])
m[i][j]=temp;
}
}
}
System.out.println(m[1][n]);
完整代码:
public class MCP{
public static void MatrixChainProduct(int[] arr){
int n=arr.length-1;
int[][] m=new int[n+1][n+1];
int i,j,k,t;
for(t=2;t<=n;t++) {
for(i=1;i<=n-t+1;i++) {
j=i+t-1;
m[i][j]=Integer.MAX_VALUE;
for(k=i;k<=j-1;k++) {
int temp=m[i][k]+m[k+1][j]+arr[i-1]*arr[k]*arr[j];
if(temp<m[i][j])
m[i][j]=temp;
}
}
}
System.out.println("最少计算次数:"+m[1][n]);
}
public static void main(String[] args){
//测试数据
int[] a={3, 5, 2, 1, 10};
int[] b={2, 7, 3, 6, 10};
int[] c={10, 3, 15, 12, 7, 2};
int[] d={7, 2, 4, 15, 20, 5};
System.out.println("3, 5, 2, 1,10");
MatrixChainProduct(a);
System.out.println("2, 7, 3, 6, 10");
MatrixChainProduct(b);
System.out.println("10, 3, 15, 12, 7, 2");
MatrixChainProduct(c);
System.out.println("7, 2, 4, 15, 20, 5");
MatrixChainProduct(d);
}
}
输出结果:
3, 5, 2, 1,10
最少计算次数:55
2, 7, 3, 6, 10
最少计算次数:198
10, 3, 15, 12, 7, 2
最少计算次数:678
7, 2, 4, 15, 20, 5
最少计算次数:990