caffe源码解析:卷积乘法中用到的im2col及col2im

这两个函数其实完成的功能比较简单,im2col就是把矩阵按卷积乘法所需,变换成列向量,col2im是一个逆过程

从下面这张图你一眼就能看明白im2col的操作(caffe中卷积计算都是Matrix_Kernel * Matrix_Col),因为都列出来太长了,我只列出了前4个,注意这是四周围完全没有填充0的情况,

 

col2im是一个反过来的过程,那么你可能会好奇,这两个操作能完全可逆吗?

事实上,结构是可逆的,结果不是,下面这个图很好地说明了展开的计算过程(图片比较大,可下载到电脑上看),

下面是一个可单独运行的测试源码,你可以随便编译跑一跑

#include <iostream>
using namespace std;

inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {
	return static_cast<unsigned>(a) < static_cast<unsigned>(b);
}

template <typename Dtype>
void caffe_set(const int N, const Dtype alpha, Dtype* Y) {
	if (alpha == 0) {
		memset(Y, 0, sizeof(Dtype) * N);  // NOLINT(caffe/alt_fn)
		return;
	}
	for (int i = 0; i < N; ++i) {
		Y[i] = alpha;
	}
}

template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
	const int height, const int width, const int kernel_h, const int kernel_w,
	const int pad_h, const int pad_w,
	const int stride_h, const int stride_w,
	const int dilation_h, const int dilation_w,
	Dtype* data_col) {
	const int output_h = (height + 2 * pad_h -
		(dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
	const int output_w = (width + 2 * pad_w -
		(dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
	const int channel_size = height * width;
	for (int channel = channels; channel--; data_im += channel_size) {
		for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
			for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
				int input_row = -pad_h + kernel_row * dilation_h;
				for (int output_rows = output_h; output_rows; output_rows--) {
					if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
						for (int output_cols = output_w; output_cols; output_cols--) {
							*(data_col++) = 0;
						}
					}
					else {
						int input_col = -pad_w + kernel_col * dilation_w;
						for (int output_col = output_w; output_col; output_col--) {
							if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
								*(data_col++) = data_im[input_row * width + input_col];
							}
							else {
								*(data_col++) = 0;
							}
							input_col += stride_w;
						}
					}
					input_row += stride_h;
				}
			}
		}
	}
}


template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
	const int height, const int width, const int kernel_h, const int kernel_w,
	const int pad_h, const int pad_w,
	const int stride_h, const int stride_w,
	const int dilation_h, const int dilation_w,
	Dtype* data_im) {
	caffe_set(height * width * channels, Dtype(0), data_im);
	const int output_h = (height + 2 * pad_h -
		(dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
	const int output_w = (width + 2 * pad_w -
		(dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
	const int channel_size = height * width;
	for (int channel = channels; channel--; data_im += channel_size) {
		for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
			for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
				int input_row = -pad_h + kernel_row * dilation_h;
				for (int output_rows = output_h; output_rows; output_rows--) {
					if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
						data_col += output_w;
					}
					else {
						int input_col = -pad_w + kernel_col * dilation_w;
						for (int output_col = output_w; output_col; output_col--) {
							if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
								data_im[input_row * width + input_col] += *data_col;
							}
							data_col++;
							input_col += stride_w;
						}
					}
					input_row += stride_h;
				}
			}
		}
	}
}

// 如果想运行6x6的矩阵,请取消下面的注释,并把5X5那段注释掉
int dataim[] = {
	1,2,3,4,5,6,
	5,6,7,8,9,10,
	6,5,4,3,2,1,
	10,9,8,7,6,5,
	4,3,2,1,5,6,
	3,2,1,6,5,4,
};

int datacol[1000];
int outim[50];

int main()
{
	im2col_cpu(dataim, 1, 6, 6, 3, 3, 0, 0, 1, 1, 1, 1, datacol);
	col2im_cpu(datacol, 1, 6, 6, 3, 3, 0, 0, 1, 1, 1, 1, outim);
	return 0;
}

// 如果想运行5x5的矩阵,请取消下面的注释, 并把上面那段注释掉
/* 
int dataim[] = {
	1,2,3,4,5,
	6,7,8,9,10,
	5,4,3,2,1,
	10,9,8,7,6,
	4,3,2,1,5,
};

int datacol[1000];
int outim[50];

int main()
{
	im2col_cpu(dataim, 1, 5, 5, 3, 3, 0, 0, 1, 1, 1, 1, datacol);
	col2im_cpu(datacol, 1, 5, 5, 3, 3, 0, 0, 1, 1, 1, 1, outim);
	return 0;
}*/

按上面源码的操作,先运行im2col,再运行col2im,结果就很有意思了,相当于每个元素都乘了一个放大系数,只是不同的位置的放大系数是不一样的,看下面的图

仔细看那个放大系数矩阵,非常有规律,有木有?

AI视像算法学习群:824991413

posted @ 2018-09-26 10:24  SpaceVision  阅读(115)  评论(0编辑  收藏  举报