CUDA 简单矩阵乘法修订版(选自《大规模并行处理器编程实战》)
/*
 * Naive square matrix multiplication kernel: Pd = Md * Nd.
 *
 * All three matrices are Width x Width floats stored row-major in device
 * memory. Each thread computes at most one element of Pd.
 *
 * Launch layout: a 2D grid of 2D blocks in which x indexes columns and
 * y indexes rows. Any grid that covers at least Width x Width threads is
 * valid; threads that fall outside the matrix exit without touching memory.
 * A single Width x Width block (blockIdx == 0) also works.
 *
 * Shared memory: none.
 */
__global__ void MatrixMulKernel(float *Md, float *Nd, float *Pd, int Width)
{
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    // The grid is rounded up to whole blocks, so edge blocks may contain
    // threads that have no matrix element to compute.
    if (row >= Width || col >= Width) {
        return;
    }

    // size_t offsets so row * Width cannot overflow int for large matrices.
    const size_t rowBase = (size_t)row * Width;

    float Pvalue = 0.0f;
    for (int k = 0; k < Width; ++k) {
        Pvalue += Md[rowBase + k] * Nd[(size_t)k * Width + col];
    }
    Pd[rowBase + col] = Pvalue;
}

/*
 * Host wrapper: computes P = M * N on the GPU.
 *
 * M, N : host pointers to width x width row-major float matrices (inputs).
 * P    : host pointer to a width x width float buffer (output).
 * width: matrix dimension; nothing is done when width <= 0.
 *
 * Device buffers are allocated, filled, used and released inside this call.
 * CUDA errors are reported through the checkCudaErrors / getLastCudaError
 * helpers (defined outside this file; presumably they abort on failure).
 */
void MatrixMultiplication(float *M, float *N, float *P, int width)
{
    if (width <= 0) {
        return;
    }

    // Do the byte-count arithmetic in size_t: width * width * sizeof(float)
    // does not fit in an int for large matrices.
    const size_t size = (size_t)width * width * sizeof(float);

    float *Md = NULL;
    float *Nd = NULL;
    float *Pd = NULL;

    checkCudaErrors(cudaMalloc((void**) &Md, size));
    checkCudaErrors(cudaMemcpy(Md, M, size, cudaMemcpyHostToDevice));

    checkCudaErrors(cudaMalloc((void**) &Nd, size));
    checkCudaErrors(cudaMemcpy(Nd, N, size, cudaMemcpyHostToDevice));

    checkCudaErrors(cudaMalloc((void**) &Pd, size));

    // The execution configuration is <<<grid dimensions, block dimensions>>>.
    // The listing this code was adapted from passed them the other way round:
    //     dim3 dimBlock(width, width);
    //     dim3 dimGrid(1, 1);
    //     MatrixMulKernel<<<dimBlock, dimGrid>>>(Md, Nd, Pd, width);
    // which launches width x width blocks of a single thread each.
    //
    // A fixed 16x16 block with a ceil-div grid keeps the launch valid for any
    // width; one width x width block would exceed the per-block thread limit
    // as soon as width * width is larger than that limit.
    const int kTileDim = 16;
    const unsigned int tiles = (unsigned int)((width + kTileDim - 1) / kTileDim);
    dim3 dimBlock(kTileDim, kTileDim);
    dim3 dimGrid(tiles, tiles);

    MatrixMulKernel<<<dimGrid, dimBlock>>>(Md, Nd, Pd, width);
    getLastCudaError("Kernel execution failed");
    // Surface in-kernel faults here instead of at the copy below.
    checkCudaErrors(cudaDeviceSynchronize());

    checkCudaErrors(cudaMemcpy(P, Pd, size, cudaMemcpyDeviceToHost));

    checkCudaErrors(cudaFree(Md));
    checkCudaErrors(cudaFree(Nd));
    checkCudaErrors(cudaFree(Pd));
}

/*
 * Demo driver: multiplies a 3x3 matrix of ones by a 3x3 matrix of twos and
 * prints the nine results, one per line (each should be 1 * 2 * 3 = 6).
 */
int main(int argc, char** argv)
{
    shrQAStart(argc, argv);

    // use command-line specified CUDA device, otherwise use device with highest Gflops/s
    findCudaDevice(argc, (const char**)argv);

    const int width = 3;
    const int count = width * width;

    float *M = (float*)malloc(sizeof(float) * count);
    float *N = (float*)malloc(sizeof(float) * count);
    float *P = (float*)malloc(sizeof(float) * count);
    if (M == NULL || N == NULL || P == NULL) {
        fprintf(stderr, "host allocation failed\n");
        free(M);  // free(NULL) is a no-op
        free(N);
        free(P);
        return EXIT_FAILURE;
    }

    for (int i = 0; i < count; ++i) {
        M[i] = 1.0f;
        N[i] = 2.0f;
    }

    MatrixMultiplication(M, N, P, width);

    for (int i = 0; i < count; i++) {
        printf("%f\n", P[i]);
    }
    printf("hello CUDA\n");

    free(M);
    free(N);
    free(P);

    // Keep the console window open until a key is pressed.
    getchar();
    //runTest( argc, argv);
    return 0;
}
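原书的代码竟然是有错误的,害我纠结了好多天,幸好我看了NVIDIA CUDA C Programming Guide,发现了问题所在:核函数的执行配置顺序是 <<<网格维度, 线程块维度>>>,而上面摘录的书中写法把这两个参数写反了,结果启动的是 width×width 个只含一个线程的线程块,而不是一个含 width×width 个线程的线程块。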
哈哈!