矩阵乘法的Strassen算法(上)
原始算法
矩阵乘法相信大家都有所了解,即对于任意 A = (aij) 和 B = (bij) 都是n x n矩阵,可以定义矩阵乘法为cij = ∑k=1~n aik x bkj。
1 SQUARE-MATRIX-MULTIPLY(A,B) 2 { 3 n = A.rows 4 let C be a new n x n matrix 5 for i =1 to n 6 for j =1 to n 7 c(ij) = 0 8 for k = 1 to n 9 c(ij) += a(ik) x b(kj) 10 return C 11 }
以上代码为矩阵乘法的伪代码,第5行遍历所有行,第6行遍历所有列,第8行则将cij累加起来,所需要花费的时间显而易见,为O(n3)。
一个简单的分治算法
假设每个矩阵都是n x n矩阵,且n为2的幂,这样便可以使得每一个矩阵都可以划分为四个子矩阵,由此来计算C = A x B。
由此可以得到,若A = [A11,A12,A21,A22],同理B和C也可以分成四个子矩阵,则可以得到以下四个公式:
C11 = A11 x B11 + A12 x B21
C12 = A11 x B12 + A12 x B22
C21 = A21 x B11 + A22 x B21
C22 = A21 x B12 + A22 x B22
通过这样的方式我们可以直接设计一个递归分治算法:
1 SQUARE-MATRIX-MULTIPLY-RECURSIVE(A, B) 2 { 3 n = A.row 4 let C be a new n x n matrix 5 if n == 1 6 c11 = a11 x b11 7 else partition A, B, C 8 C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11, B11) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12, B21) 9 C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11, B12) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12, B22) 10 C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21, B11) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22, B21) 11 C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21, B12) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22, B22) 12 return C 13 }
显然,这里的步骤7被省略了。那么应该如何分解子矩阵呢?如果我们重新创建12个子矩阵进行计算,那么我们将花费O(n2)的时间来创建这些子矩阵。实际上,我们使用下标来对子矩阵进行标记,所需要的时间就是常数项时间即O(1),对递归所消耗的时间就毫无影响。
现在我们来推测一下该递归分治算法的运行时间,令T(n)表示两个n x n矩阵乘积的时间。对于第3~7行,所消耗的时间均为常数项时间。对于8~11行,所消耗的时间则为T(n/2),同时,要计算4次矩阵加法,所消耗的时间为O(n2),因此我们可以得到时间递归公式:
T(n) = 8T(n/2) + O(n2)
该递归公式的解为O(n3),并不优于直接对矩阵进行乘法运算。
下一次我将详细介绍Strassen算法的过程❤