Python算法|矩阵链乘法

概述

题目来源:Python 100天算法挑战
题目解析参考:【动态规划】矩阵链乘法

矩阵乘法是一个满足结合律的运算。显然,对于矩阵A、B、C来说,(AB)C 与 A(BC) 是等价的,我们可以根据自己的心情选择任意的运算顺序,总之,结果都是一样的。
糟糕的是,对计算机来说可不是这么回事,若我们假定矩阵 A=[10,20], B=[20,30], C=[30,40],那么在以下两种运算顺序中,标量相乘的次数是天差地别:
(AB)C = 10*20*30 + 10*30*40 = 18000
A(BC) = 20*30*40 + 10*20*40 = 32000

为了计算表达式,我们可以先用括号明确计算次序,然后利用标准的矩阵相乘算法进行计算。
完全括号化(fully parenthesized):它是单一矩阵,或者是两个完全括号化的矩阵乘积链的积。
例如如果有矩阵链为<A1,A2,A3,A4>,则共有5种完全括号化的矩阵乘积链。

(A1(A2(A3A4)))
(A1((A2A3)A4))
((A1A2)(A3A4))
((A1(A2A3))A4)
((A1(A2A3))A4)

对矩阵链加括号的方式会对乘积运算的代价产生巨大影响。我们先来分析两个矩阵相乘的代价。下面的伪代码的给出了两个矩阵相乘的标准算法,属性rows和columns是矩阵的行数和列数。

MATRIX-MULTIPKLY(A,B)
if A.columns≠B.rows 
	error "incompatible dimensions"
else let C be a new A.rows×B.columns matrix
	for i = 1 to A.rows
	     for j = 1 to B.columns
		  c(ij)=0
		   for k = 1 to A.columns
			 c(ij)=c(ij)+a(ik)*b(kj)
return C

两个矩阵A和B只有相容(compatible),即A的列数等于B的行数时,才能相乘。如果A是p×q的矩阵,B是q×r的矩阵,那么乘积C是p×r的矩阵。计算C所需要时间由第8行的标量乘法的次数决定的,即pqr。
以矩阵链<A1,A2,A3>为例,来说明不同的加括号方式会导致不同的计算代价。假设三个矩阵的规模分别为10×100、100×5和5×50。
如果按照((A1A2)A3)的顺序计算,为计算A1A2(规模10×5),需要做10*100*5=5000次标量乘法,再与A3相乘又需要做10*5*50=2500次标量乘法,共需7500次标量乘法。
如果按照(A1(A2A3))的顺序计算,为计算A2A3(规模100×50),需100*5*50=25000次标量乘法,再与A1相乘又需10*100*50=50000次标量乘法,共需75000次标量乘法。因此第一种顺序计算要比第二种顺序计算快10倍。
矩阵链乘法问题(matrix-chain multiplication problem)可描述如下:给定n个矩阵的链<A1,A2,...,An>,矩阵Ai的规模为p(i-1)×p(i) (1<=i<=n),求完全括号化方案,使得计算乘积A1A2...An所需标量乘法次数最少。
因为括号方案的数量与n呈指数关系,所以通过暴力搜索穷尽所有可能的括号化方案来寻找最优方案是一个糟糕策略。我们可以使用递归关系来找到我们需要的最优解法,首先,我们要用一个函数MCM来得到最小标量相乘次数,那么MCM也可用来定义在所有情况下的最优子段。再使用动态规划和备忘录法即可得到结果。

应用动态规划方法
下面用动态规划方法来求解矩阵链的最优括号方案,我们还是按照之前提出的4个步骤进行:
1.刻画一个最优解的结构特征
2.递归地定义最优解的值
3.计算最优解的值,通常采用自底向上的方法
4.利用计算出的信息构造一个最优解

算法实现

def mult(chain):
    n = len(chain)
    # single matrix chain has zero cost
    aux = {(i, i): (0,) + chain[i] for i in range(n)}
    print(aux)
    # i: length of subchain(子链)
    for i in range(1, n):
        # j: starting index of subchain
        for j in range(0, n - i):
            best = float('inf') #inf is infinite(无穷大)
            # k: splitting point of subchain
            for k in range(j, j + i):
                # multiply subchains at splitting point
                lcost, lname, lrow, lcol = aux[j, k]
                rcost, rname, rrow, rcol = aux[k + 1, j + i]
                cost = lcost + rcost + lrow * lcol * rcol
                var = '(%s%s)' % (lname, rname)
                print(cost, var)
                # pick the best one
                if cost < best:
                    best = cost
                    aux[j, j + i] = cost, var, lrow, rcol
                    print(aux)
    return dict(zip(['cost', 'order', 'rows', 'cols'], aux[0, n - 1]))

结果

{(0, 0): (0, 'A', 10, 20), 
 (1, 1): (0, 'B', 20, 30), 
 (2, 2): (0, 'C', 30, 40)}
 6000 (AB)

{(0, 0): (0, 'A', 10, 20),
 (1, 1): (0, 'B', 20, 30),
 (2, 2): (0, 'C', 30, 40),
 (0, 1): (6000, '(AB)', 10, 30)}
 24000 (BC)

{(0, 0): (0, 'A', 10, 20), 
 (1, 1): (0, 'B', 20, 30), 
 (2, 2): (0, 'C', 30, 40), 
 (0, 1): (6000, '(AB)', 10, 30),
 (1, 2): (24000, '(BC)', 20, 40)}
 32000 (A(BC))

{(0, 0): (0, 'A', 10, 20),
 (1, 1): (0, 'B', 20, 30),
 (2, 2): (0, 'C', 30, 40),
 (0, 1): (6000, '(AB)', 10, 30),
 (1, 2): (24000, '(BC)', 20, 40),
 (0, 2): (32000, '(A(BC))', 10, 40)}
 18000 ((AB)C)

{(0, 0): (0, 'A', 10, 20),
 (1, 1): (0, 'B', 20, 30),
 (2, 2): (0, 'C', 30, 40),
 (0, 1): (6000, '(AB)', 10, 30),
 (1, 2): (24000, '(BC)', 20, 40),
 (0, 2): (18000, '((AB)C)', 10, 40)}

{'cost': 18000, 'order': '((AB)C)', 'rows': 10, 'cols': 40}
posted @ 2020-07-15 10:59  胡椒椒椒(弃用勿联系)  阅读(774)  评论(0编辑  收藏  举报