重写blas库函数
有时候需要对blas库函数进行重写
我写的一个函数为
#include <iostream>

// Hand-written row-major GEMM demo: C = alpha * A * B + beta * C,
// where A is M x K, B is K x N and C is M x N.
int main() {
    const int M = 4;  // rows of A, rows of C
    const int N = 2;  // columns of B, columns of C
    const int K = 3;  // columns of A, rows of B
    const float alpha = 1;
    const float beta = 0;
    const float A[M * K] = { 1,2,3,4,5,6,7,8,9,8,7,6 };
    const float B[K * N] = { 5,4,3,2,1,0 };

    // Zero-initialised on purpose: the update below reads C[i*N + j], and
    // reading an uninitialised float is undefined behaviour (and could turn
    // the result into NaN even though beta is 0).
    float C[M * N] = {};

    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            // Dot product of row i of A with column j of B.
            float sum = 0;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = alpha * sum + beta * C[i * N + j];
        }
    }

    // Print C one row per line, elements separated by a single space.
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            std::cout << C[i * N + j] << " ";
        }
        std::cout << '\n';
    }
    return 0;
}
但是有问题。
在caffe的tools文件夹下建立keshan.cpp
#include <cblas.h>
#include <iostream>

// Reference run of the same product through cblas_dgemm (row-major, no
// transposition), printing the M x N result one row per line.
int main() {
    const int M = 4;  // rows of A, rows of C
    const int N = 2;  // columns of B, columns of C
    const int K = 3;  // columns of A, rows of B

    const double alpha = 1;
    const double beta = 0;

    const double A[K * M] = { 1,2,3,4,5,6,7,8,9,8,7,6 };
    const double B[K * N] = { 5,4,3,2,1,0 };
    // NOTE(review): C is passed uninitialised; this relies on dgemm not
    // reading C when beta is 0 - confirm against the BLAS documentation.
    double C[M * N];

    // Leading dimensions of the row-major operands.
    const int lda = K;
    const int ldb = N;
    const int ldc = N;

    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, K,
                alpha, A, lda,
                B, ldb,
                beta, C, ldc);

    for (int row = 0; row < M; ++row) {
        const double* out_row = C + row * N;
        for (int col = 0; col < N; ++col) {
            std::cout << out_row[col] << "\t";
        }
        std::cout << std::endl;
    }
    return 0;
}
输出结果为
14 8 41 26 68 44 67 46
和我写的一样,似乎没什么错误,但是这个只是对一种错误的描写。
在caffe中caffe_cpu_gemm定义在math_functions.cpp中
重写caffe_cpu_gemm只需要增加
//重写caffe_cpu_gemm(float),假设没有transpose template<> void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const float alpha, const float* A, const float* B, const float beta, float* C,int flag) { int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; if(TransB != CblasNoTrans){ //转置 float BT[N*K]; for(int i=0;i<N;i++){ for(int j=0;j<K;j++){ BT[j*N+i] = B[i*K+j]; } } //相乘 for(int i=0;i<M;i++){ for(int j=0;j<N;j++){ float sum = 0; for(int k=0;k<K;k++){ sum += A[i*K+k]*BT[k*N+j]; } C[i*N+j] = alpha * sum + beta*C[i*N+j]; } } }else{ //相乘 for(int i=0;i<M;i++){ for(int j=0;j<N;j++){ float sum = 0; for(int k=0;k<K;k++){ sum += A[i*K+k]*B[k*N+j]; } C[i*N+j] = alpha * sum + beta*C[i*N+j]; } } } } //重写caffe_cpu_gemm(double) template<> void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const double alpha, const double* A, const double* B, const double beta, double* C,int flag) { int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; if(TransB != CblasNoTrans){ //转置 double BT[N*K]; for(int i=0;i<N;i++){ for(int j=0;j<K;j++){ BT[j*N+i] = B[i*K+j]; } } //乘法 for(int i=0;i<M;i++){ for(int j=0;j<N;j++){ double sum = 0; for(int k=0;k<K;k++){ sum += A[i*K+k]*BT[k*N+j]; } C[i*N+j] = alpha * sum + beta*C[i*N+j]; } } }else{ //乘法 for(int i=0;i<M;i++){ for(int j=0;j<N;j++){ double sum = 0; for(int k=0;k<K;k++){ sum += A[i*K+k]*B[k*N+j]; } C[i*N+j] = alpha * sum + beta*C[i*N+j]; } } } }这样就对caffe_cpu_gemm函数进行了重载,caffe_cpu_gemm中增加
// Declaration of the caffe_cpu_gemm overload that takes a trailing `flag`
// argument in addition to the usual gemm parameters (transpose selectors,
// dimensions M/N/K, scalars alpha/beta and the A/B/C buffers).
// The float/double specializations given alongside this declaration do not
// read `flag`; it only distinguishes this overload from the existing one.
template <typename Dtype> void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta, Dtype* C,int flag);对base_conv_layer.cpp和inner_product.cpp中的caffe_cpu_gemm都改为带flag的即可。
下面写的矩阵小框架倒是不错,值得参考
// Small fixed-size matrix template with transpose ("transpote"), element-wise
// +/-, comparison, matrix product and scalar product.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

// Prints one element; explicitly specialised below for float and int.
template<typename T>
void TypePrint(T v);

// Row-major M x N matrix of T. Storage is owned by a std::vector, so copies
// are deep and memory is released automatically (the raw `new T[M*N]` used
// before was never freed and was shared between copies).
template<typename T, int M, int N>
class Matrix {
public:
    // Creates a matrix whose elements are value-initialised (zero for
    // arithmetic types).
    Matrix() : data(static_cast<std::size_t>(M) * N) {}

    // Number of rows.
    int getVIndex() const { return M; }
    // Number of columns.
    int getHIndex() const { return N; }

    // Element at row x, column y (no bounds checking).
    T getxy(int x, int y) const { return data[x * N + y]; }
    // Stores f at row x, column y (no bounds checking).
    void setxy(int x, int y, T f) { data[x * N + y] = f; }

    // Copies `size` BYTES of row-major data from datap. The count is clamped
    // to the matrix's own storage so an oversized request cannot overrun it;
    // a null pointer or non-positive size is ignored.
    void setdata(const T* datap, int size) {
        if (datap == nullptr || size <= 0) {
            return;
        }
        const std::size_t capacity = data.size() * sizeof(T);
        const std::size_t bytes = std::min(static_cast<std::size_t>(size), capacity);
        std::memcpy(data.data(), datap, bytes);
    }

    // Returns the N x M transpose of this matrix.
    Matrix<T, N, M> transpote() const {
        Matrix<T, N, M> m;
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                m.setxy(j, i, getxy(i, j));
            }
        }
        return m;
    }

    // Element-wise sum. The result has the same M x N shape as the operands.
    Matrix<T, M, N> operator+(const Matrix<T, M, N>& adv) const {
        Matrix<T, M, N> m;
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                m.setxy(i, j, getxy(i, j) + adv.getxy(i, j));
            }
        }
        return m;
    }

    // Element-wise difference. The result has the same M x N shape.
    Matrix<T, M, N> operator-(const Matrix<T, M, N>& adv) const {
        Matrix<T, M, N> m;
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                m.setxy(i, j, getxy(i, j) - adv.getxy(i, j));
            }
        }
        return m;
    }

    // True when every pair of corresponding elements compares equal.
    bool operator==(const Matrix<T, M, N>& adv) const {
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                if (getxy(i, j) != adv.getxy(i, j)) return false;
            }
        }
        return true;
    }

    // True when at least one pair of corresponding elements differs.
    bool operator!=(const Matrix<T, M, N>& adv) const {
        return !(*this == adv);
    }

    // Prints a leading blank line, then one row per line with every element
    // followed by ",\t".
    void print() const {
        std::printf("\n");
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                TypePrint(getxy(i, j));
                std::printf(",\t");
            }
            std::printf("\n");
        }
    }

private:
    std::vector<T> data;  // M*N elements, row-major
};

// Matrix product: (M x N) * (N x P) -> (M x P).
template<typename T, int M, int N, int P>
Matrix<T, M, P> operator*(const Matrix<T, M, N>& x, const Matrix<T, N, P>& y) {
    Matrix<T, M, P> m;
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < P; j++) {
            T v = 0;
            for (int k = 0; k < N; k++) {
                v += (x.getxy(i, k) * y.getxy(k, j));
            }
            m.setxy(i, j, v);
        }
    }
    return m;
}

// Scalar product: every element of x multiplied by y.
// (Previously this multiplied the elements of the fresh result matrix instead
// of x, so the output did not depend on x at all.)
template<typename T, int M, int N>
Matrix<T, M, N> operator*(const Matrix<T, M, N>& x, T y) {
    Matrix<T, M, N> m;
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            m.setxy(i, j, x.getxy(i, j) * y);
        }
    }
    return m;
}

// Scalar product with the scalar on the left-hand side.
template<typename T, int M, int N>
Matrix<T, M, N> operator*(T y, const Matrix<T, M, N>& x) {
    return x * y;
}

template<>
void TypePrint(float v) {
    std::printf("%f", v);
}

template<>
void TypePrint(int v) {
    std::printf("%d", v);
}

// Not referenced anywhere below; retained from the original listing.
#define type float

// 3 x 2 input, row-major.
int d1[] = {
    1,2,
    2,3,
    3,0
};
// 3 x 3 input, row-major.
int d2[] = {
    2,-3,0,
    0,1,-2,
    -4,5,10
};

int main() {
    Matrix<int, 3, 2> s;
    s.setdata(d1, sizeof(d1));
    Matrix<int, 3, 3> s1;
    s1.setdata(d2, sizeof(d2));
    Matrix<int, 3, 2> s2 = s1 * s;          // (3 x 3) * (3 x 2)
    Matrix<int, 2, 3> s3 = s2.transpote();  // transposed product
    s.print();
    s1.print();
    s2.print();
    s3.print();
    return 0;
}
可以看看cblas的源代码[2],perl写的...
[1] 对cblas_sgemm的说明
[2] linux安装cblas