行云

行至水穷处,坐看云起时。

  博客园 :: 首页 :: 博问 :: 闪存 :: 新随笔 :: 联系 :: 订阅 订阅 :: 管理 ::

C=A*B
按C的元素来划分任务:假设线程数为m,矩阵维度为n×n,那么每个线程计算C中 n×n/m 个元素(若不能整除,余下的元素需再分配给部分线程)。

/*
矩阵并行计算
C=A*B  --- C(i,j)等于A的第i行乘以B的第j列
*/
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <windows.h>


/*
    Fill an n*n row-major matrix with pseudo-random values.
*/
void GenerateMatrix(float *m, int n);
/* Print an n*n row-major matrix to stdout. */
void PrintMatrix(float *p, int n);
/* Serial reference multiply: accumulates A*B into C (C += A*B). */
void GeneralMul(float *A, float *B, float *C, int n);
/* Set every element of an n*n matrix to zero. */
void ClearMatrix(float *m, int n);
/*
    Parallel matrix multiply: accumulates A*B into C using thread_num threads.
*/
void ParallelCacul(float *A, float *B, float *C, int n, int thread_num);
/*
    Error between two matrices: sum of absolute element-wise differences.
*/
float diff(float *C1, float *C0, int n);

/*
    Work description handed to each Mul_Fun thread by ParallelCacul.
    A thread computes a run of C's elements in row-major order, starting at
    element (cx, cy).
*/
struct ARG {
    float *A;       // left operand, n*n row-major (only read by Mul_Fun)
    float *B;       // right operand, n*n row-major (only read by Mul_Fun)
    float *C;       // result matrix; the thread accumulates (+=) into its elements
    int cx, cy; // row and column of the first C element this thread computes
    int m;    // number of C elements this thread computes
    int n;          // matrix dimension (all matrices are n*n)
};
/*
    Parse s as a strictly positive decimal int.
    Returns the value, or 0 if s is not a complete decimal number in [1, INT_MAX].
*/
static int ParsePositiveInt(const char *s)
{
    char *end;
    errno = 0;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE || v <= 0 || v > INT_MAX)
        return 0;
    return (int)v;
}

/*
    Usage: prog N thread_num
    Multiplies two random N*N matrices serially and with thread_num threads,
    prints both timings (milliseconds, measured with clock()) and the summed
    absolute difference between the two results.
    Returns 0 on success, 1 on bad arguments.
*/
int main(int argc, char **argv)
{
    if (argc != 3)
    {
        printf("Usage: %s N thread_num\n", argv[0]);
        return 1;
    }
    int n = ParsePositiveInt(argv[1]);
    int thread_num = ParsePositiveInt(argv[2]);
    // Both must be positive (thread_num is a divisor in ParallelCacul), and
    // n*n is used as an int element count throughout, so it must not overflow.
    if (n == 0 || thread_num == 0 || n > INT_MAX / n)
    {
        printf("N and thread_num must be positive integers, and N*N must fit in an int.\n");
        printf("Usage: %s N thread_num\n", argv[0]);
        return 1;
    }

    float *A = new float[n*n];
    float *B = new float[n*n];
    float *C = new float[n*n];
    float *C0 = new float[n*n];

    GenerateMatrix(A, n);
    GenerateMatrix(B, n);

    clock_t start;
    float time_used;

    // Serial reference result (GeneralMul accumulates, so clear first).
    ClearMatrix(C0, n);
    start = clock();
    GeneralMul(A, B, C0, n);
    time_used = static_cast<float>(clock() - start)/CLOCKS_PER_SEC*1000;
    printf("General:   time = %f\n", time_used);

    // Parallel result (also accumulates, so clear first).
    ClearMatrix(C, n);
    start = clock();
    ParallelCacul(A, B, C, n, thread_num);
    time_used = static_cast<float>(clock() - start)/CLOCKS_PER_SEC*1000;
    printf("Block:  time = %f\n", time_used);
    printf("Difference of two result: %f\n", diff(C0, C, n));

    delete [] A;
    delete [] B;
    delete [] C;
    delete [] C0;
    return 0;
}


void ClearMatrix(float *m, int n)
{
    /* Zero all n*n elements of the row-major matrix m; n <= 0 touches nothing. */
    if (n <= 0)
        return;

    const int total = n * n;
    for (int idx = 0; idx < total; ++idx)
        m[idx] = 0.0f;
}
/*
    Serial reference matrix multiply for n*n row-major matrices.
    Accumulates the product into C, i.e. C(i,j) += sum_k A(i,k) * B(k,j);
    callers wanting C = A*B must zero C first (main uses ClearMatrix for this).
*/
void GeneralMul(float *A, float *B, float *C, int n)
{
    for (int row = 0; row < n; ++row)
    {
        const float *a_row = A + row * n;
        float *c_row = C + row * n;
        for (int col = 0; col < n; ++col)
        {
            for (int k = 0; k < n; ++k)
                c_row[col] += a_row[k] * B[k * n + col];
        }
    }
}


/*
    Thread entry point used by ParallelCacul.

    arg points to a struct ARG describing a contiguous row-major run of C that
    starts at element (cx, cy) and is m elements long. For every element in the
    run, C(i,j) += (row i of A) . (column j of B). The run is clipped at the end
    of the matrix, and a run with m <= 0 computes nothing (previously such a
    run never reached m == 0 and swept the rest of the matrix, racing with
    other threads).
    Always returns 0 (the thread's exit code).
*/
DWORD WINAPI Mul_Fun(LPVOID arg)
{
    const struct ARG *p = (const struct ARG *)arg;

    const float *A = p->A;
    const float *B = p->B;
    float *C = p->C;
    const int n = p->n;

    if (n <= 0 || p->m <= 0)
        return 0;

    const long long total = (long long)n * n;
    const long long first = (long long)p->cx * n + p->cy;
    long long last = first + p->m;
    if (last > total)
        last = total;

    for (long long idx = first; idx < last; idx++)
    {
        const int i = (int)(idx / n);
        const int j = (int)(idx % n);
        const float *a_row = A + (long long)i * n;
        float *t = C + idx;
        for (int k = 0; k < n; k++)
            *t += a_row[k] * B[(long long)k * n + j];
    }

    return 0;
}

void ParallelCacul(float *A, float *B, float *C, int n, int thread_num)
{
    int m = n*n/thread_num; //每个线程需要计算的元素个数,不考虑不能整除的情况
    
    struct ARG *args = new struct ARG[thread_num];
    HANDLE *h = new HANDLE[thread_num];
    
    int i;
    for (i = 0; i<thread_num; i++)
    {
        args[i].A = A;
        args[i].B = B;
        args[i].C = C;
        args[i].n = n;
        args[i].m = m;
    }

    for (i=0; i<thread_num; i++)
    {
        args[i].cx = i*m/n;
        args[i].cy = i*m%n;
        h[i] = CreateThread(NULL, 0, Mul_Fun, (LPVOID)(&args[i]), 0, 0 );
    }
    for (i=0; i<thread_num; i++)
    {
        WaitForMultipleObjects(thread_num, &h[i],TRUE,INFINITE);
    }

}

/*
    Fill the n*n row-major buffer p with pseudo-random non-negative values,
    each the ratio rand() / (rand() + 0.55).
    The generator is reseeded with time(NULL) + rand() on every call
    (presumably so two matrices generated in the same second still differ).
*/
void GenerateMatrix(float *p, int n)
{
    srand(time(NULL) + rand());

    const int count = n * n;
    for (int idx = 0; idx < count; ++idx)
        p[idx] = static_cast<float>(rand()) / (static_cast<float>(rand()) + 0.55f);
}

/*
    Error between two n*n row-major matrices: the sum over all elements of
    |C1(i,j) - C0(i,j)|. The result is 0 when the matrices are element-wise equal.
*/
float diff(float *C1, float *C0, int n)
{
    float total = 0.0f;

    for (int row = 0; row < n; ++row)
    {
        const float *lhs = C1 + row * n;
        const float *rhs = C0 + row * n;
        for (int col = 0; col < n; ++col)
        {
            const float d = lhs[col] - rhs[col];
            total += (d < 0) ? -d : d;
        }
    }
    return total;
}

/*
    Print the n*n row-major matrix p to stdout: one row per line, each value
    formatted "%.2f" and followed by a tab, with a blank line after the matrix.
*/
void PrintMatrix(float *p, int n)
{
    for (int row = 0; row < n; ++row)
    {
        const float *cur = p + row * n;
        for (int col = 0; col < n; ++col)
            printf("%.2f\t", cur[col]);
        putchar('\n');
    }
    putchar('\n');
}

 

posted on 2012-05-28 21:38  windflying  阅读(4485)  评论(0)  编辑  收藏  举报