重写BLAS库函数

有时候需要对blas库函数进行重写

我写的一个函数为

#include <iostream>
using namespace std;
// Naive row-major GEMM: C = alpha * A * B + beta * C, with A (M x K),
// B (K x N) and C (M x N) stored densely row by row; the result is printed.
int main() {
    const int M = 4;  // rows of A and C
    const int N = 2;  // columns of B and C
    const int K = 3;  // columns of A, rows of B
    const float alpha = 1;
    const float beta = 0;
    const float A[M * K] = { 1,2,3,4,5,6,7,8,9,8,7,6 };
    const float B[K * N] = { 5,4,3,2,1,0 };
    // Zero-initialised: the update below reads C whenever beta != 0, and
    // reading an uninitialised C (possibly NaN/Inf) would corrupt the result.
    float C[M * N] = {};
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0;
            for (int k = 0; k < K; k++) {
                sum += A[i * K + k] * B[k * N + j];
            }
            // Like BLAS, do not read C at all when beta == 0.
            C[i * N + j] = (beta == 0) ? alpha * sum
                                       : alpha * sum + beta * C[i * N + j];
        }
    }
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            std::cout << C[i * N + j] << " ";
        }
        std::cout << std::endl;
    }
    return 0;
}

但是有问题（例如：若 C 未初始化，`beta*C[i*N + j]` 会读取未初始化的值；即使 beta 为 0，若这些值恰好是 NaN 或 Inf，结果也会变成 NaN）。

在caffe的tools文件夹下建立keshan.cpp

#include <cblas.h>
#include <iostream>
using namespace std;
// Cross-check of the hand-written loop: computes C = alpha * A * B + beta * C
// with CBLAS double-precision GEMM (row-major, no transposes) and prints C.
int main(){
	const int M = 4;  // rows of A and C
	const int N = 2;  // columns of B and C
	const int K = 3;  // columns of A, rows of B
	const double alpha = 1;
	const double beta = 0;
	const double A[K*M] = { 1,2,3,4,5,6,7,8,9,8,7,6 };
	const double B[K*N] = { 5,4,3,2,1,0 };
	// Left uninitialised: with beta == 0, GEMM does not need C set on input.
	double C[M*N];
	// Row-major leading dimensions = number of stored columns.
	const int lda = K;
	const int ldb = N;
	const int ldc = N;
	cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
	            M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
	for (int row = 0; row < M; ++row) {
		const double* line = C + row * N;
		for (int col = 0; col < N; ++col)
			std::cout << line[col] << "\t";
		std::cout << std::endl;
	}
	return 0;
}

输出结果为

14      8
41      26
68      44
67      46

输出与我写的版本一致，看起来没有错误；但这只是其中一种情况（A、B 都不转置，beta = 0）的演示，并不能说明实现在所有情况下都正确。

在 Caffe 中，caffe_cpu_gemm 定义在 src/caffe/util/math_functions.cpp 中。

重写 caffe_cpu_gemm 只需要在 math_functions.cpp 中增加下面两个特化版本（float 和 double）：

// Plain-loop replacement for caffe_cpu_gemm (float): computes
//   C = alpha * op(A) * op(B) + beta * C
// where op(X) is X when its flag is CblasNoTrans and X^T otherwise.
// op(A) is M x K, op(B) is K x N and C is M x N; all are densely stored
// row-major (lda = K or M, ldb = N or K, ldc = N).
// `flag` is not used by the computation; it only distinguishes this
// overload from the BLAS-backed caffe_cpu_gemm.
template<>
void caffe_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const float alpha, const float* A, const float* B, const float beta,
    float* C,int flag) {
  (void)flag;  // unused, see above
  // Leading dimensions of A and B as they are stored (before op()).
  const int lda = (TransA == CblasNoTrans) ? K : M;
  const int ldb = (TransB == CblasNoTrans) ? N : K;
  // Strides mapping op(A)[i][k] and op(B)[k][j] directly onto the stored
  // arrays. This handles TransA too, and avoids copying B^T into a
  // variable-length stack array, which is non-standard C++ and overflows
  // the stack for large layers.
  const int a_i = (TransA == CblasNoTrans) ? lda : 1;
  const int a_k = (TransA == CblasNoTrans) ? 1 : lda;
  const int b_k = (TransB == CblasNoTrans) ? ldb : 1;
  const int b_j = (TransB == CblasNoTrans) ? 1 : ldb;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float sum = 0;
      for (int k = 0; k < K; ++k) {
        sum += A[i * a_i + k * a_k] * B[k * b_k + j * b_j];
      }
      float* c = &C[i * N + j];
      // As in BLAS, C is not read when beta == 0, so an uninitialised
      // (or NaN-filled) output buffer cannot leak into the result.
      *c = (beta == 0) ? alpha * sum : alpha * sum + beta * (*c);
    }
  }
}
// Plain-loop replacement for caffe_cpu_gemm (double): computes
//   C = alpha * op(A) * op(B) + beta * C
// where op(X) is X when its flag is CblasNoTrans and X^T otherwise.
// op(A) is M x K, op(B) is K x N and C is M x N; all are densely stored
// row-major (lda = K or M, ldb = N or K, ldc = N).
// `flag` is not used by the computation; it only distinguishes this
// overload from the BLAS-backed caffe_cpu_gemm.
template<>
void caffe_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const double alpha, const double* A, const double* B, const double beta,
    double* C,int flag) {
  (void)flag;  // unused, see above
  // Leading dimensions of A and B as they are stored (before op()).
  const int lda = (TransA == CblasNoTrans) ? K : M;
  const int ldb = (TransB == CblasNoTrans) ? N : K;
  // Strides mapping op(A)[i][k] and op(B)[k][j] directly onto the stored
  // arrays. This handles TransA too, and avoids copying B^T into a
  // variable-length stack array, which is non-standard C++ and overflows
  // the stack for large layers.
  const int a_i = (TransA == CblasNoTrans) ? lda : 1;
  const int a_k = (TransA == CblasNoTrans) ? 1 : lda;
  const int b_k = (TransB == CblasNoTrans) ? ldb : 1;
  const int b_j = (TransB == CblasNoTrans) ? 1 : ldb;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      double sum = 0;
      for (int k = 0; k < K; ++k) {
        sum += A[i * a_i + k * a_k] * B[k * b_k + j * b_j];
      }
      double* c = &C[i * N + j];
      // As in BLAS, C is not read when beta == 0, so an uninitialised
      // (or NaN-filled) output buffer cannot leak into the result.
      *c = (beta == 0) ? alpha * sum : alpha * sum + beta * (*c);
    }
  }
}
这样就得到了带 flag 参数的 caffe_cpu_gemm 重载版本（上面是它的 float 和 double 特化）。还需要在头文件 include/caffe/util/math_functions.hpp 中增加对应的模板声明：

// Declaration of the flag-taking caffe_cpu_gemm overload: the primary
// template that the float/double explicit specializations specialize, so it
// must be visible before those specializations are defined.
// Computes C = alpha * op(A) * op(B) + beta * C. In the specializations shown
// in this post, `flag` is not read; it only selects this overload instead of
// the original caffe_cpu_gemm.
template <typename Dtype>
void caffe_cpu_gemm(const CBLAS_TRANSPOSE TransA,
    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
    const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
    Dtype* C,int flag);
最后，把 base_conv_layer.cpp 和 inner_product_layer.cpp 中对 caffe_cpu_gemm 的调用都改为带 flag 参数的版本即可（flag 在上面的实现中不参与计算，传任意整数即可，例如 0）。



下面写的矩阵小框架倒是不错,值得参考

//transpote转置矩阵
#include "Stdio.h"
#include "memory.h"
#include <cstring>
#include <vector>
// Prints one matrix element with printf; Matrix::print() calls it for every
// element. It has no generic definition: only the explicit specializations
// defined later in this file exist, so printing an element type without one
// fails at link time.
template<typename T>
void TypePrint(T v);
// Fixed-size M x N matrix stored row-major. The dimensions are template
// parameters, so shape mismatches in +, -, ==, != and transposition are
// caught at compile time.
//
// Storage is a std::vector, so copies (including the values returned by
// transpote(), operator+ and operator-) are independent deep copies and the
// memory is released automatically; no user-defined destructor or copy
// operations are needed.
template<typename T, int M, int N>
class Matrix
{
public:
	// All M*N elements are value-initialised (0 for arithmetic T).
	Matrix(void) : data(M * N, T()) {}

	// Number of rows.
	int getVIndex() const
	{
		return M;
	}

	// Number of columns.
	int getHIndex() const
	{
		return N;
	}

	// Element at row x, column y (no bounds check).
	T getxy(int x, int y) const
	{
		return data[x * N + y];
	}

	// Sets the element at row x, column y to f (no bounds check).
	void setxy(int x, int y, T f)
	{
		data[x * N + y] = f;
	}

	// Copies `size` BYTES (e.g. sizeof(array)) from datap into the row-major
	// storage. At most M*N elements are copied, so an oversized `size`
	// cannot overrun the matrix; a null pointer or non-positive size is
	// ignored.
	void setdata(const T *datap, int size)
	{
		const size_t capacity = sizeof(T) * M * N;
		if (datap == 0 || size <= 0 || capacity == 0)
			return;
		const size_t requested = static_cast<size_t>(size);
		memcpy(&data[0], datap, requested < capacity ? requested : capacity);
	}

	// Returns the N x M transpose ("transpote" is the historical name).
	Matrix<T, N, M> transpote() const
	{
		Matrix<T, N, M> m;
		for (int i = 0; i < M; i++)
		{
			for (int j = 0; j < N; j++)
			{
				m.setxy(j, i, getxy(i, j));
			}
		}
		return m;
	}

	// Element-wise sum; the result has the same M x N shape.
	Matrix<T, M, N> operator+(const Matrix<T, M, N> &adv) const
	{
		Matrix<T, M, N> m;
		for (int i = 0; i < M; i++)
		{
			for (int j = 0; j < N; j++)
			{
				m.setxy(i, j, getxy(i, j) + adv.getxy(i, j));
			}
		}
		return m;
	}

	// Element-wise difference; the result has the same M x N shape.
	Matrix<T, M, N> operator-(const Matrix<T, M, N> &adv) const
	{
		Matrix<T, M, N> m;
		for (int i = 0; i < M; i++)
		{
			for (int j = 0; j < N; j++)
			{
				m.setxy(i, j, getxy(i, j) - adv.getxy(i, j));
			}
		}
		return m;
	}

	// True when every element compares equal.
	bool operator==(const Matrix<T, M, N> &adv) const
	{
		return data == adv.data;
	}

	// True when any element differs.
	bool operator!=(const Matrix<T, M, N> &adv) const
	{
		return !(*this == adv);
	}

	// Prints the matrix row by row, elements separated by ",\t".
	void print() const
	{
		printf("\n");
		for (int i = 0; i < M; i++)
		{
			for (int j = 0; j < N; j++)
			{
				TypePrint(getxy(i, j));
				printf(",\t");
			}
			printf("\n");
		}
	}

private:

	std::vector<T> data;  // row-major: element (x, y) lives at x * N + y

};

// Matrix product: (M x N) * (N x P) -> (M x P). The shared inner dimension N
// is enforced by the template parameters, so mismatched shapes do not compile.
template<typename T, int M, int N, int P>
Matrix<T, M, P> operator*(Matrix<T, M, N> &x, Matrix<T, N, P> &y)
{
	Matrix<T, M, P> product;
	for (int row = 0; row < M; ++row)
	{
		for (int col = 0; col < P; ++col)
		{
			T acc = 0;
			for (int inner = 0; inner < N; ++inner)
				acc += (x.getxy(row, inner) * y.getxy(inner, col));
			product.setxy(row, col, acc);
		}
	}
	return product;
}

// Matrix * scalar: returns a new M x N matrix whose elements are those of x
// multiplied by y; x itself is not modified.
template<typename T, int M, int N>
Matrix<T, M, N> operator*(Matrix<T, M, N> &x, T y)
{
	Matrix<T, M, N> m;

	for (int i = 0; i < M; i++)
	{
		for (int j = 0; j < N; j++)
		{
			// Read from the operand x, not from the freshly created result m.
			m.setxy(i, j, x.getxy(i, j) * y);
		}
	}

	return m;
}

// Scalar * matrix: commutative counterpart that forwards to the
// Matrix * scalar overload.
template<typename T, int M, int N>
Matrix<T, M, N> operator*(T y, Matrix<T, M, N> &x)
{
	Matrix<T, M, N> &matrix = x;
	const T scalar = y;
	return matrix * scalar;
}




// printf-based element printers used by Matrix::print(). There is no generic
// TypePrint definition, so every element type that is printed needs one of
// these specializations; double is included so Matrix<double>::print() links.
template<>
void TypePrint(float v)
{
	printf("%f", v);
}

template<>
void TypePrint(double v)
{
	printf("%f", v);
}

template<>
void TypePrint(int v)
{
	printf("%d", v);
}

#define type float



int d1[] =
{
	1,2,
	2,3,
	3,0
};

int d2[] =
{
	2,-3,0,
	0,1,-2,
	-4,5,10
};



int main()
{
	Matrix<int, 3, 2> s;

	s.setdata(d1, sizeof(d1));

	Matrix<int, 3, 3> s1;

	s1.setdata(d2, sizeof(d2));

	Matrix<int, 3, 2> s2 = s1*s;

	Matrix<int, 2, 3> s3 = s2.transpote();//转置了

	s.print();
	s1.print();
	s2.print();
	s3.print();
	return 0;
}

可以看看cblas的源代码[2],perl写的...



[1] 对cblas_sgemm的说明 

[2] linux安装cblas

posted @ 2017-04-25 17:37  开往春天的拖拉机  阅读(122)  评论(0编辑  收藏  举报