最优矩阵链乘

什么是矩阵链乘?

我们都学过矩阵的乘法。矩阵的乘法不满足分配率,但是满足结合律,因此矩阵\((A×B)×C\)\(A×(B×C)\)的结果是一样的,但是中间的运算量可能是不同的。

比如三个矩阵\(A=2\times3\)\(B=3\times4\)\(C=4\times5\),则\((A \times B)\times C\)需要运算\(2\times3\times4+2\times4\times5=64\)次,而\(A\times(B \times C)\)的运算量是\(3\times5\times5+2\times3\times5=90\)次。显然第一次的计算方法比较节约计算量

最优矩阵链乘,就是在一堆矩阵相乘的情况下,找出需要运算次数最少的那个方法。通俗的说,就是给矩阵乘法加括号

如何求得最优矩阵链乘?

我们假设有\(n\)个矩阵相乘:

\(A_1 \times A_2 \times A_3 \times ...... \times A_n\)

我们可以先将这n个矩阵分为两部分\(P\)\(Q\)

\(P=A_1 \times A_2 \times A_3 \times ...... \times A_k\)

\(Q=A_{k+1} \times A_{k+2} \times A_{k+3} \times ...... \times A_n\)

其中\(1 \leq k<n\)。矩阵链乘的结果就可以用\(P\times Q\)来表示了,且答案为\(k\)\(1\)\(n-1\)循环一遍,运算结果最小的那个。状态转移方程如下:

\(f(i,j)=min \{ f(i,k) + f(k+1,j) + P_{i-1}P_{k}P_{j} \}, i\leq k<j\)

其中,\(f(i,j)\)表示第\(i\)个矩阵到第\(j\)个矩阵的运算量,\(k\)为中间乘号的位置,\(P\)保存矩阵的行和列,第\(i\)行矩阵的行为\(P_{i-1}\),列为\(P_i\)

我们来用图表演示一下这个过程。假设现在有4个矩阵\(A_{1}\)~\(A_{6}\),其输入如下:

\(p=\{3,5,2,1,10\}\)

输入内容表示的含义:

矩阵 \(A_1\) \(A_2\) \(A_3\) \(A_4\)
\(\times\) \(3\times5\) \(5\times2\) \(2\times1\) \(1\times10\)
P的下标 0 1 2 3|4
P值 3 5 2 1|10

计算过程

标着的数字是计算的顺序。在花括号里面部分的就是上文所说的\(P\),花括号外面的部分就是\(Q\)

下面开始写代码

首先,根据输入的数组构建一个二维数组,这个二维数组是用来存放计算结果的:

int n=arr.length-1;
//为了让数组的指针能够和上图对应起来,二维数组的大小设置为n+1
int[][] m=int[n+1][n+1];

结合上面的例子来讲解。输入的数组一共有5个元素,表示一共有4个矩阵。所以构造的二维数组就是5*5的。其中m[1][2]存放的是\(A1 \times A2\)最终结果;同理,m[1][3]存放的是\(A1 \times A2 \times A3\)的最终结果————虽然在计算\(A1 \times A2 \times A3\)的过程中有两种可能,但我们只取结果最小的那一个作为最终结果储存在数组中。

接着,我们用一个指针i来标识二维数组中的行

用指针j标识二维数组中的列

t表示链乘的矩阵的长度:如\(A1 \times A2 \times A3\)的长度为3,\(A1 \times A2\)的长度为2

k来标识乘号的位置,比如上图表1行4列中的第二个式子中k=2

nit之间的关系为:

i<=n-t+1;

ijt之间的关系为:

j=i+t-1;

kj之间的关系为:

k<=j-1

根据t的大小,我们可以写一个循环:

int i,j,k,t;

for(t=2;t<=n;t++){
    for(i=1;i<=n-t+1;i++){
        j=i+t-1;
        m[i][j]=Integer.MAX_VALUE
        for(k=i;i<=j-1;k++){
            int temp=m[i][k]+m[k+1][j]+a[i-1]*a[k]*a[j];
            if(temp<m[i][j])
                m[i][j]=temp;
        }
    }
}

System.out.println(m[1][n]);

完整代码:

public class MCP{
    public static void MatrixChainProduct(int[] arr){
    	int n=arr.length-1;
		int[][] m=new int[n+1][n+1];
		int i,j,k,t;
		
		for(t=2;t<=n;t++) {
			for(i=1;i<=n-t+1;i++) {
				j=i+t-1;
				m[i][j]=Integer.MAX_VALUE;
				for(k=i;k<=j-1;k++) {
					int temp=m[i][k]+m[k+1][j]+arr[i-1]*arr[k]*arr[j];
					if(temp<m[i][j])
						m[i][j]=temp;
				}
			}
		}
		System.out.println("最少计算次数:"+m[1][n]);
    }
    
    public static void main(String[] args){
        //测试数据
    	int[] a={3, 5, 2, 1, 10};
    	int[] b={2, 7, 3, 6, 10};
    	int[] c={10, 3, 15, 12, 7, 2};
    	int[] d={7, 2, 4, 15, 20, 5};
        System.out.println("3, 5, 2, 1,10");
        MatrixChainProduct(a);
        System.out.println("2, 7, 3, 6, 10");
        MatrixChainProduct(b);
        System.out.println("10, 3, 15, 12, 7, 2");
        MatrixChainProduct(c);
        System.out.println("7, 2, 4, 15, 20, 5");
        MatrixChainProduct(d);

    }
}

输出结果:

3, 5, 2, 1,10

最少计算次数:55

2, 7, 3, 6, 10

最少计算次数:198

10, 3, 15, 12, 7, 2

最少计算次数:678

7, 2, 4, 15, 20, 5

最少计算次数:990

posted @ 2020-05-01 22:19  maurrinho  阅读(495)  评论(0编辑  收藏  举报