第四章 分治策略 4.2 矩阵乘法的Strassen算法
package chap04_Divide_And_Conquer; import static org.junit.Assert.*; import java.util.Arrays; import org.junit.Test; /** * 矩阵相乘的算法 * * @author xiaojintao * */ public class MatrixOperation { /** * 普通的矩阵相乘算法,c=a*b。其中,a、b都是n*n的方阵 * * @param a * @param b * @return c */ static int[][] matrixMultiplicationByCommonMethod(int[][] a, int[][] b) { int n = a.length; int[][] c = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i][j] = 0; for (int k = 0; k < n; k++) { c[i][j] += a[i][k] * b[k][j]; } } } return c; } /** * strassen 算法求矩阵乘法 n为2的幂 * * @param a * @param b * @return */ static int[][] matrixMultiplicationByStrassen(int[][] a, int[][] b) { int n = a.length; if (n == 1) { int[][] c = new int[1][1]; c[0][0] = a[0][0] * b[0][0]; return c; } int m = n / 2; int[][] a11, a12, a21, a22, b11, b12, b21, b22; int[][] c = new int[n][n]; a11 = new int[m][m]; a12 = new int[m][m]; a21 = new int[m][m]; a22 = new int[m][m]; b11 = new int[m][m]; b12 = new int[m][m]; b21 = new int[m][m]; b22 = new int[m][m]; for (int i = 0; i < m; i++) { for (int j = 0; j < m; j++) { a11[i][j] = a[i][j]; } } for (int i = 0; i < m; i++) { for (int j = 0; j < m; j++) { b11[i][j] = b[i][j]; } } for (int i = 0; i < m; i++) { for (int j = m; j < n; j++) { a12[i][j - m] = a[i][j]; } } for (int i = 0; i < m; i++) { for (int j = m; j < n; j++) { b12[i][j - m] = b[i][j]; } } for (int i = m; i < n; i++) { for (int j = 0; j < m; j++) { a21[i - m][j] = a[i][j]; } } for (int i = m; i < n; i++) { for (int j = 0; j < m; j++) { b21[i - m][j] = b[i][j]; } } for (int i = m; i < n; i++) { for (int j = m; j < n; j++) { a22[i - m][j - m] = a[i][j]; } } for (int i = m; i < n; i++) { for (int j = m; j < n; j++) { b22[i - m][j - m] = b[i][j]; } } int[][] s1 = matrixMinus(b12, b22); int[][] s2 = matrixAdd(a11, a12); int[][] s3 = matrixAdd(a21, a22); int[][] s4 = matrixMinus(b21, b11); int[][] s5 = matrixAdd(a11, a22); int[][] s6 = matrixAdd(b11, b22); int[][] s7 = matrixMinus(a12, a22); int[][] s8 = 
matrixAdd(b21, b22); int[][] s9 = matrixMinus(a11, a21); int[][] s10 = matrixAdd(b11, b12); int[][] p1 = matrixMultiplicationByStrassen(a11, s1); int[][] p2 = matrixMultiplicationByStrassen(s2, b22); int[][] p3 = matrixMultiplicationByStrassen(s3, b11); int[][] p4 = matrixMultiplicationByStrassen(a22, s4); int[][] p5 = matrixMultiplicationByStrassen(s5, s6); int[][] p6 = matrixMultiplicationByStrassen(s7, s8); int[][] p7 = matrixMultiplicationByStrassen(s9, s10); int[][] t1, t2, t3; t1 = matrixAdd(p5, p4); t2 = matrixMinus(t1, p2); int[][] c11 = matrixAdd(t2, p6); int[][] c12 = matrixAdd(p1, p2); int[][] c21 = matrixAdd(p3, p4); t1 = matrixAdd(p5, p1); t2 = matrixMinus(t1, p3); int[][] c22 = matrixMinus(t2, p7); c = matrixConbine(c11, c12, c21, c22); return c; } /** * 矩阵加法 c=a+b * * @param a * @param b * @return */ static int[][] matrixAdd(int[][] a, int[][] b) { int n = a.length; int[][] c = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i][j] = a[i][j] + b[i][j]; } } return c; } /** * 矩阵减法 c=a-b * * @param a * @param b * @return */ static int[][] matrixMinus(int[][] a, int[][] b) { int n = a.length; int[][] c = new int[n][n]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i][j] = a[i][j] - b[i][j]; } } return c; } /** * 将矩阵的四个部分组合 * * @param t11 * @param t12 * @param t21 * @param t22 * @return */ protected static int[][] matrixConbine(int[][] t11, int[][] t12, int[][] t21, int[][] t22) { int n = t11.length; int m = 2 * n; int[][] c = new int[m][m]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i][j] = t11[i][j]; } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i][j + n] = t12[i][j]; } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i + n][j] = t21[i][j]; } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { c[i + n][j + n] = t22[i][j]; } } return c; } @Test public void testName() throws Exception { // int[][] a = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }; // 
int[][] b = { { 1, 3, 5 }, { 2, 4, 6 }, { 9, 8, 7 } }; // int[][] c = commonMatrixMultiplication(a, b); // int[][] c = matrixAdd(a, b); int[][] m = { { 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 9, 10, 11, 12 }, { 13, 14, 15, 16 } }; int[][] n = { { 1, 3, 5, 7 }, { 2, 4, 6, 8 }, { 4, 3, 2, 1 }, { 9, 8, 7, 6 } }; int[][] c = matrixMultiplicationByStrassen(m, n); System.out.println(Arrays.deepToString(c)); int[][] d = matrixMultiplicationByCommonMethod(m, n); System.out.println(Arrays.deepToString(d)); } }
暴力求解的时间复杂度为 $O(n^3)$，Strassen 算法为 $O(n^{\log_2 7}) \approx O(n^{2.81})$。