【算法导论C++代码】Strassen算法

简单方阵矩乘法

SQUARE-MATRIX-MULTIPLY(A,B)

1  n = A.rows

2  let C be a new n*n matrix

3  for i = 1 to n

4      for j = 1 to n

5          cij = 0

6          for k = 1 to n

7              cij = cij + aik·bkj

8  return C

一个简单的分治算法

SQUARE-MATRIX-MULTIPLY-RECURSIVE(A,B)

1 n = A.rows

2 let C be a new n*n matrix

3 if n==1

4 c11=a11·b11

5 else partition A,B,and C as in equations (4.9)

6  C11=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B11)

      +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B21)

7   C12=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B12)

      +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B22)

 

8 C21=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B11)

      +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B21)

9 C22=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B12)

      +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B22)

10 return C

 

矩阵乘法的Strassen算法

SQUARE-MATRIX-STRASSEN-RECURSIVE(A,B)

1 n=A.rows

2 let C be a new n*n matrix

3 if n==1

4    c11=a11·b11

5 else partition A,B,and C as in equations (4.9)

6       S1=B12-B22;

7   S2=A11+A12;

8   S3=A21+A22;

9   S4=B21-B11;

10   S5=A11+A22;

11   S6=B11+B22;

12   S7=A12-A22;

13   S8=B21+B22;

14   S9=A11-A21;

15   S10=B11+B12;

16   P1=SQUARE-MATRIX-STRASSEN-RECURSIVE(A11,S1);

17   P2=SQUARE-MATRIX-STRASSEN-RECURSIVE(S2,B22);

18   P3=SQUARE-MATRIX-STRASSEN-RECURSIVE(S3,B11);

19   P4=SQUARE-MATRIX-STRASSEN-RECURSIVE(A22,S4);

20   P5=SQUARE-MATRIX-STRASSEN-RECURSIVE(S5,S6);

21   P6=SQUARE-MATRIX-STRASSEN-RECURSIVE(S7,S8);

22   P7=SQUARE-MATRIX-STRASSEN-RECURSIVE(S9,S10);

23   C11=P5+P4-P2+P6;

24   C12=P1+P2;

25   C21=P3+P4;

26   C22=P5+P1-P3-P7;

27 return C;

/*C++代码。书上给的分解矩阵的做法是用角标计算而不用建立新的对象,不过我并没有想到可以不用新建对象而进行递归的办法,所以这里还是和书上有些不一样的。另外因为演示,所以新建了个类,不过这个类并不稳定,仅作测试时了解功能就好*/
Matrix.h
// Square matrix of int values stored as an array of row pointers.
// All members are public; the free functions in MAIN.cpp read and write
// iData/iRows directly.
// NOTE(review): the class declares no copy constructor or copy assignment, so
// copies share the same iData pointer, and the destructor defined in
// Matrix.cpp is empty, so storage obtained through CreateSqMa() is never freed.
class SquareMatrix
{
public:
    // Default constructor; allocates no storage.
    SquareMatrix();
    // Stores the given row-pointer array and dimension as-is (no copy is made).
    SquareMatrix(int **data,int rows);
    // Creates a rows x rows matrix by delegating to CreateSqMa(rows).
    SquareMatrix(int rows);
    ~SquareMatrix();
    // Allocates rows x rows elements, sets them all to 0 and records the
    // dimension in iRows. Returns 0.
    int CreateSqMa(int rows);
    // Copies values from the flat array `data` into the already allocated
    // matrix and sets iRows to `rows`. Returns 0.
    int SetData(int rows,int *data);
    // Row pointers: iData[r] points to the elements of row r.
    int    **iData;
    // Number of rows (and columns) of the matrix.
    int iRows;
    // Element-wise sum and difference; the result is a newly allocated matrix.
    friend    SquareMatrix operator+(SquareMatrix A,SquareMatrix B);
    friend    SquareMatrix operator-(SquareMatrix A,SquareMatrix B);
    // Prints the matrix to std::cout, one row per line. Returns 0.
    int SprintSqMa();
};
Matrix.cpp
#include <iostream>
#include "Matrix.h"
// Default constructor: creates an empty matrix (null storage, zero rows).
// Both members are initialised explicitly so that a default-constructed
// matrix has a well-defined state until CreateSqMa() is called; previously
// iData and iRows were left with indeterminate values.
SquareMatrix::SquareMatrix()
    : iData(0), iRows(0)
{
}
// Builds a rows x rows matrix; allocation and zero-filling are delegated to
// CreateSqMa().
SquareMatrix::SquareMatrix(int rows)
{
    CreateSqMa(rows);
}
// Wraps an existing row-pointer array: the pointer and the dimension are
// stored as given, no element is copied.
SquareMatrix::SquareMatrix(int **data,int rows)
    : iData(data), iRows(rows)
{
}

// Destructor: the body is empty, so the arrays allocated by CreateSqMa() are
// never released and every matrix leaks its storage.
// NOTE(review): adding delete[] here on its own would be unsafe. The class has
// no user-defined copy operations and matrices are passed and returned by
// value throughout this file, so copies share one iData pointer and would be
// freed more than once.
SquareMatrix::~SquareMatrix()
{

}
// Allocates storage for a rows x rows matrix with every element set to 0 and
// records the dimension in iRows.
// Any storage the object referred to before the call is not released.
// Always returns 0.
int SquareMatrix::CreateSqMa(int rows)
{
    iRows = rows;
    iData = new int *[rows];
    for (int r = 0; r < rows; ++r)
    {
        // The trailing "()" value-initialises the row, i.e. fills it with zeros.
        iData[r] = new int[rows]();
    }
    return 0;
}
// Fills the matrix from a flat, row-major array.
//
// rows : dimension of the data to load; must be in [0, iRows], i.e. it must
//        fit into the storage that was allocated by CreateSqMa().
// data : array of at least rows*rows values, element (i, j) at
//        data[i*rows + j].
//
// Returns 0 on success. Returns -1 and leaves the matrix untouched when
// `data` is null, no storage has been allocated, or `rows` does not fit the
// allocated storage (previously such calls wrote out of bounds).
//
// Element (i, j) of the input is stored at iData[i][j], matching the
// [row][column] convention used by the printing, addition and multiplication
// code. The previous implementation wrote to iData[j][i], which silently
// transposed the input.
int SquareMatrix::SetData(int rows,int *data)
{
    if (data == 0 || iData == 0 || rows < 0 || rows > iRows)
    {
        return -1;
    }
    for (int i = 0; i < rows; i++)
    {
        for (int j = 0; j < rows; j++)
        {
            iData[i][j]=data[i*rows+j];
        }
    }
    iRows=rows;
    return 0;
}

// Element-wise sum of two square matrices.
//
// The result has the dimension of A. Both operands are expected to have the
// same dimension; if they differ, only the top-left block common to both is
// added and the remaining elements of the result stay 0. The loop is bounded
// by the smaller dimension because the previous bound (B.iRows) indexed past
// the end of A and of the result whenever B was larger than A.
SquareMatrix operator+(SquareMatrix A,SquareMatrix B)
{
    SquareMatrix C(A.iRows);
    const int n = (A.iRows < B.iRows) ? A.iRows : B.iRows;

    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            C.iData[i][j]=A.iData[i][j]+B.iData[i][j];
        }
    }
    return C;
}
// Element-wise difference A - B of two square matrices.
//
// The result has the dimension of A. Both operands are expected to have the
// same dimension; if they differ, only the top-left block common to both is
// subtracted and the remaining elements of the result stay 0. The loop is
// bounded by the smaller dimension because the previous bound (B.iRows)
// indexed past the end of A and of the result whenever B was larger than A.
SquareMatrix operator-(SquareMatrix A,SquareMatrix B)
{
    SquareMatrix C(A.iRows);
    const int n = (A.iRows < B.iRows) ? A.iRows : B.iRows;

    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            C.iData[i][j]=A.iData[i][j]-B.iData[i][j];
        }
    }
    return C;
}
// Writes the matrix to std::cout: every element is followed by one space and
// every row is terminated with std::endl. Nothing is printed for an empty
// matrix. Always returns 0.
int SquareMatrix::SprintSqMa()
{
    for (int row = 0; row < iRows; ++row)
    {
        const int *values = iData[row];
        for (int col = 0; col < iRows; ++col)
        {
            std::cout << values[col] << ' ';
        }
        // Reached only when iRows > 0, so each printed row gets its line break.
        std::cout << std::endl;
    }
    return 0;
}

MAIN.cpp
#include <iostream>
#include "Matrix.h"
using namespace std;
// Forward declarations of the three multiplication routines defined after
// main(): the triple-loop product, the divide-and-conquer product and
// Strassen's method. Each takes both operands by value and returns the
// product as a new matrix.
SquareMatrix SquareMatrixMultiply(SquareMatrix A,SquareMatrix B);
SquareMatrix SquareMatrixMultiplyRecursive(SquareMatrix A,SquareMatrix B);
SquareMatrix Strassen(SquareMatrix A,SquareMatrix B);
int main()
{
    SquareMatrix A,B,C;

    B.CreateSqMa(4);
    C.CreateSqMa(4);
    A.CreateSqMa(4);
    int arr[16]={1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1};
    B.SetData(4,arr);
    A.SetData(4,arr);

    //C=A+B;
    //C=SquareMatrixMultiply(A,B);
    C=SquareMatrixMultiplyRecursive(A,B);
    //C=Strassen(A,B);
    cout<<"算法导论4.2矩阵乘法Strassen算法"<<endl;

    A.SprintSqMa();
    cout<<endl;

    B.SprintSqMa();
    cout<<endl;

    C.SprintSqMa();
    cout<<endl;

    system("pause");
    return 0;
}

// Straightforward triple-loop matrix product (SQUARE-MATRIX-MULTIPLY).
// The dimension is taken from A; B is indexed with the same bounds.
// Returns a new matrix holding A * B.
SquareMatrix SquareMatrixMultiply(SquareMatrix A,SquareMatrix B)
{
    const int dim = A.iRows;
    SquareMatrix C(dim);
    for (int row = 0; row < dim; ++row)
    {
        const int *lhsRow = A.iData[row];
        int *outRow = C.iData[row];
        for (int col = 0; col < dim; ++col)
        {
            // Start from the value already stored in C (0 after construction).
            int acc = outRow[col];
            for (int k = 0; k < dim; ++k)
            {
                acc = acc + lhsRow[k] * B.iData[k][col];
            }
            outRow[col] = acc;
        }
    }
    return C;
}

// Divide-and-conquer matrix product (SQUARE-MATRIX-MULTIPLY-RECURSIVE).
//
// A, B : square matrices; B is expected to have the same dimension as A
//        (the dimension is taken from A.iRows).
// Returns a new matrix holding A * B.
//
// Edge cases:
//  - dimension <= 0 yields an empty matrix (previously a dimension of 0
//    recursed without ever reaching the base case);
//  - an odd dimension greater than 1 cannot be split into four equal
//    quadrants, so that level is handed to SquareMatrixMultiply() instead of
//    silently ignoring the last row and column.
//
// Unlike the textbook version, the quadrants are copied into new matrices
// rather than addressed by index ranges. Those temporaries are never freed
// because SquareMatrix's destructor is empty.
SquareMatrix SquareMatrixMultiplyRecursive(SquareMatrix A,SquareMatrix B)
{
    const int n=A.iRows;
    if(n<=0)
    {
        return SquareMatrix(0);
    }
    if(n==1)
    {
        SquareMatrix C(1);
        C.iData[0][0]=A.iData[0][0]*B.iData[0][0];
        return C;
    }
    if(n%2!=0)
    {
        return SquareMatrixMultiply(A,B);
    }

    const int half=n/2;
    SquareMatrix A11(half),A12(half),
        A21(half),A22(half),
        B11(half),B12(half),
        B21(half),B22(half);

    // Partition A and B into their four half x half quadrants.
    for (int i=0;i<half;i++)
    {
        for(int j=0;j<half;j++)
        {
            A11.iData[i][j]=A.iData[i][j];
            A12.iData[i][j]=A.iData[i][j+half];
            A21.iData[i][j]=A.iData[i+half][j];
            A22.iData[i][j]=A.iData[i+half][j+half];
            B11.iData[i][j]=B.iData[i][j];
            B12.iData[i][j]=B.iData[i][j+half];
            B21.iData[i][j]=B.iData[i+half][j];
            B22.iData[i][j]=B.iData[i+half][j+half];
        }
    }

    // Eight recursive products. Each result quadrant is initialised directly
    // from the sum instead of being pre-allocated and then overwritten.
    SquareMatrix C11=SquareMatrixMultiplyRecursive(A11,B11)
        +SquareMatrixMultiplyRecursive(A12,B21);

    SquareMatrix C12=SquareMatrixMultiplyRecursive(A11,B12)
        +SquareMatrixMultiplyRecursive(A12,B22);

    SquareMatrix C21=SquareMatrixMultiplyRecursive(A21,B11)
        +SquareMatrixMultiplyRecursive(A22,B21);

    SquareMatrix C22=SquareMatrixMultiplyRecursive(A21,B12)
        +SquareMatrixMultiplyRecursive(A22,B22);

    // Assemble the quadrants into the full result.
    SquareMatrix C(n);
    for (int i=0;i<half;i++)
    {
        for(int j=0;j<half;j++)
        {
            C.iData[i][j]=C11.iData[i][j];
            C.iData[i][j+half]=C12.iData[i][j];
            C.iData[i+half][j]=C21.iData[i][j];
            C.iData[i+half][j+half]=C22.iData[i][j];
        }
    }
    return C;
}

// Strassen's matrix product: seven recursive multiplications per level
// instead of eight.
//
// A, B : square matrices; B is expected to have the same dimension as A
//        (the dimension is taken from A.iRows).
// Returns a new matrix holding A * B.
//
// Edge cases:
//  - dimension <= 0 yields an empty matrix (previously a dimension of 0
//    recursed without ever reaching the base case);
//  - an odd dimension greater than 1 cannot be split into four equal
//    quadrants, so that level is handed to SquareMatrixMultiply() instead of
//    silently ignoring the last row and column.
//
// Quadrants and intermediate sums are separate matrices; they are never
// freed because SquareMatrix's destructor is empty.
SquareMatrix Strassen(SquareMatrix A,SquareMatrix B)
{
    const int n=A.iRows;
    if(n<=0)
    {
        return SquareMatrix(0);
    }
    if(n==1)
    {
        SquareMatrix C(1);
        C.iData[0][0]=A.iData[0][0]*B.iData[0][0];
        return C;
    }
    if(n%2!=0)
    {
        return SquareMatrixMultiply(A,B);
    }

    const int half=n/2;
    SquareMatrix A11(half),A12(half),
        A21(half),A22(half),
        B11(half),B12(half),
        B21(half),B22(half);

    // Partition A and B into their four half x half quadrants.
    for (int i=0;i<half;i++)
    {
        for(int j=0;j<half;j++)
        {
            A11.iData[i][j]=A.iData[i][j];
            A12.iData[i][j]=A.iData[i][j+half];
            A21.iData[i][j]=A.iData[i+half][j];
            A22.iData[i][j]=A.iData[i+half][j+half];
            B11.iData[i][j]=B.iData[i][j];
            B12.iData[i][j]=B.iData[i][j+half];
            B21.iData[i][j]=B.iData[i+half][j];
            B22.iData[i][j]=B.iData[i+half][j+half];
        }
    }

    // The ten sums/differences S1..S10. Each is initialised from its
    // expression rather than default-constructed and assigned afterwards.
    SquareMatrix S1=B12-B22;
    SquareMatrix S2=A11+A12;
    SquareMatrix S3=A21+A22;
    SquareMatrix S4=B21-B11;
    SquareMatrix S5=A11+A22;
    SquareMatrix S6=B11+B22;
    SquareMatrix S7=A12-A22;
    SquareMatrix S8=B21+B22;
    SquareMatrix S9=A11-A21;
    SquareMatrix S10=B11+B12;

    // The seven recursive products P1..P7.
    SquareMatrix P1=Strassen(A11,S1);
    SquareMatrix P2=Strassen(S2,B22);
    SquareMatrix P3=Strassen(S3,B11);
    SquareMatrix P4=Strassen(A22,S4);
    SquareMatrix P5=Strassen(S5,S6);
    SquareMatrix P6=Strassen(S7,S8);
    SquareMatrix P7=Strassen(S9,S10);

    // Combine the products into the quadrants of the result.
    SquareMatrix C11=P5+P4-P2+P6;
    SquareMatrix C12=P1+P2;
    SquareMatrix C21=P3+P4;
    SquareMatrix C22=P5+P1-P3-P7;

    // Assemble the quadrants into the full result.
    SquareMatrix C(n);
    for (int i=0;i<half;i++)
    {
        for(int j=0;j<half;j++)
        {
            C.iData[i][j]=C11.iData[i][j];
            C.iData[i][j+half]=C12.iData[i][j];
            C.iData[i+half][j]=C21.iData[i][j];
            C.iData[i+half][j+half]=C22.iData[i][j];
        }
    }
    return C;
}

 

 

posted on 2015-09-02 11:20  毛尹航  阅读(1153)  评论(0)  编辑  收藏  举报