Caffe源码-DataTransformer类
DataTransformer类简介
DataTransformer类主要用于图像预处理操作。layer中可设置TransformationParameter
类型的消息,对输入图像进行减均值、随机镜像、裁剪(训练时随机裁剪,测试时中心裁剪)以及数值缩放(像素值乘以scale系数,并不改变图像尺寸)。DataTransformer类中主要包含重载函数Transform(),
可以对各种类型的图像数据进行预处理,并将结果存入Blob类型的数据中。类中还包含了以下成员变量:
TransformationParameter param_; // preprocessing parameters (scale / mirror / crop_size / mean_file / mean_value / force_color / force_gray)
shared_ptr<Caffe::RNG> rng_; // random number generator used by Rand() for random mirroring and random cropping
Phase phase_; // network phase: TRAIN (random crop) or TEST (center crop)
Blob<Dtype> data_mean_; // mean data loaded from mean_file
vector<Dtype> mean_values_; // mean values given through the repeated mean_value field
其中TransformationParameter消息中包含的内容如下。
// Message that stores parameters used to apply transformation to the data layer's data
message TransformationParameter {
// For data pre-processing, we can do simple scaling and subtracting the data mean,
// if provided. Note that the mean subtraction is always carried out before scaling.
optional float scale = 1 [default = 1]; // multiplicative factor applied to every value, after mean subtraction
// Specify if we want to randomly mirror data.
optional bool mirror = 2 [default = false]; // if true, each image is flipped horizontally with probability 1/2
// Specify if we would like to randomly crop an image.
optional uint32 crop_size = 3 [default = 0]; // side of the square crop; 0 disables cropping (random crop in TRAIN, center crop in TEST)
// mean_file and mean_value cannot be specified at the same time
optional string mean_file = 4; // path to a binary-serialized BlobProto holding the mean image
// if specified can be repeated once (would subtract it from all the channels)
// or can be repeated the same number of times as channels
// (would subtract them from the corresponding channel)
// mean_file and mean_value are mutually exclusive.
repeated float mean_value = 5; // either exactly 1 value or one value per image channel
// Force the decoded image to have 3 color channels.
optional bool force_color = 6 [default = false]; // only meaningful for encoded data: decode as 3-channel color
// Force the decoded image to have 1 color channel.
optional bool force_gray = 7 [default = false]; // only meaningful for encoded data: decode as 1-channel grayscale
}
data_transformer.cpp源码
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, Phase phase)
    : param_(param), phase_(phase) {
  // Mean data comes from exactly one of two mutually exclusive sources:
  // a binary BlobProto file (mean_file) or a list of numbers (mean_value).
  if (param_.has_mean_file()) {
    CHECK_EQ(param_.mean_value_size(), 0) << "Cannot specify mean_file and mean_value at the same time";
    const string& mean_path = param.mean_file();
    // Only the root solver logs, to avoid duplicate messages.
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_path;
    }
    BlobProto mean_proto;
    ReadProtoFromBinaryFileOrDie(mean_path.c_str(), &mean_proto);
    data_mean_.FromProto(mean_proto);
  }

  const int value_count = param_.mean_value_size();
  if (value_count > 0) {
    CHECK(param_.has_mean_file() == false) <<
      "Cannot specify mean_file and mean_value at the same time";
    // Keep a local copy of every configured mean value, in order.
    for (int i = 0; i < value_count; ++i) {
      mean_values_.push_back(param_.mean_value(i));
    }
  }
}
// Preprocess the single image stored in a Datum (mean subtraction, crop,
// horizontal mirror, value scaling) and write the result to transformed_data.
//
// transformed_data must hold datum.channels() * H * W values, where H = W =
// crop_size when cropping is enabled, otherwise the datum's height / width.
// The output layout is planar (c, h, w).
//
// A single configured mean_value is broadcast to every channel at lookup time;
// mean_values_ itself is not modified, so later calls with a different channel
// count are still accepted.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum, Dtype* transformed_data) {
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();
  const int crop_size = param_.crop_size();  // 0 means "no crop"
  const Dtype scale = param_.scale();
  // Random mirror: only when enabled, and then with probability 1/2.
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  // Pixels come from the uint8 `data` bytes when present, otherwise from float_data.
  const bool has_uint8 = data.size() > 0;
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean image is indexed with full (uncropped) image coordinates.
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }

  // True when one mean_value must be applied to all channels.
  bool single_mean_value = false;
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    single_mean_value = (mean_values_.size() == 1);
  }

  // Output size and crop offsets; without cropping the output is the full image.
  int height = datum_height;
  int width = datum_width;
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    height = crop_size;
    width = crop_size;
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      // Deterministic center crop outside of training.
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  for (int c = 0; c < datum_channels; ++c) {
    // Per-channel mean value, looked up once per channel.
    const Dtype channel_mean =
        has_mean_values ? mean_values_[single_mean_value ? 0 : c] : Dtype(0);
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Source pixel (c, h_off + h, w_off + w) of the original image.
        const int data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        // Destination (c, h, w), or (c, h, width - 1 - w) when mirroring along the width.
        const int top_w = do_mirror ? (width - 1 - w) : w;
        const int top_index = (c * height + h) * width + top_w;
        Dtype datum_element;
        if (has_uint8) {
          datum_element = static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {
          datum_element = datum.float_data(data_index);
        }
        // Mean subtraction always happens before scaling.
        if (has_mean_file) {
          transformed_data[top_index] = (datum_element - mean[data_index]) * scale;
        } else if (has_mean_values) {
          transformed_data[top_index] = (datum_element - channel_mean) * scale;
        } else {
          transformed_data[top_index] = datum_element * scale;
        }
      }
    }
  }
}
// Preprocess one Datum into transformed_blob.
// Encoded datums are decoded with OpenCV and handed to the cv::Mat overload.
// Raw datums have their shape validated against transformed_blob and are then
// handed to the Dtype* overload.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum, Blob<Dtype>* transformed_blob) {
  const bool wants_forced_decode = param_.force_color() || param_.force_gray();
  if (datum.encoded()) {
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray())) << "cannot set both force_color and force_gray";
    // force_color -> 3-channel decode, force_gray -> 1-channel decode,
    // neither -> keep the encoded image's native format.
    const cv::Mat cv_img = wants_forced_decode ?
        DecodeDatumToCVMat(datum, param_.force_color()) :
        DecodeDatumToCVMatNative(datum);
    return Transform(cv_img, transformed_blob);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  } else if (wants_forced_decode) {
    // Forcing a color mode has no effect on raw (non-encoded) data; only warn.
    LOG(ERROR) << "force_color and force_gray only for encoded datum";
  }

  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Check dimensions: the blob must match the datum's channels and be exactly
  // crop_size x crop_size (cropping) or the datum's own size (no cropping).
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();
  CHECK_EQ(channels, datum_channels);
  CHECK_LE(height, datum_height);
  CHECK_LE(width, datum_width);
  CHECK_GE(num, 1);
  if (!crop_size) {
    CHECK_EQ(datum_height, height);
    CHECK_EQ(datum_width, width);
  } else {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
  }

  Transform(datum, transformed_blob->mutable_cpu_data());
}
// Preprocess every Datum in datum_vector; datum i is written into slot i
// (num index i) of transformed_blob. transformed_blob may have more slots than
// there are datums; this function does not write the extra slots.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int datum_num = datum_vector.size();
  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  CHECK_GT(datum_num, 0) << "There is no datum to add";
  CHECK_LE(datum_num, num) << "The size of datum_vector must be no greater than transformed_blob->num()";

  // A 1 x C x H x W view that is re-pointed at each slot of transformed_blob.
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int i = 0; i < datum_num; ++i) {
    Dtype* slot = transformed_blob->mutable_cpu_data() + transformed_blob->offset(i);
    uni_blob.set_cpu_data(slot);
    Transform(datum_vector[i], &uni_blob);
  }
}
// Preprocess every cv::Mat in mat_vector; image i is written into slot i of
// transformed_blob. Unlike the Datum overload, the number of images must be
// exactly transformed_blob->num().
#ifdef USE_OPENCV
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int mat_num = mat_vector.size();
  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  CHECK_GT(mat_num, 0) << "There is no MAT to add";
  CHECK_EQ(mat_num, num) << "The size of mat_vector must be equals to transformed_blob->num()";

  // A 1 x C x H x W view that is re-pointed at each slot of transformed_blob.
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int i = 0; i < mat_num; ++i) {
    Dtype* slot = transformed_blob->mutable_cpu_data() + transformed_blob->offset(i);
    uni_blob.set_cpu_data(slot);
    Transform(mat_vector[i], &uni_blob);
  }
}
// Preprocess one 8-bit cv::Mat (mean subtraction, crop, horizontal mirror,
// value scaling) and write the result into the first C*H*W values of
// transformed_blob.
// OpenCV stores pixels interleaved as (h, w, c); the output is planar (c, h, w).
// A single configured mean_value is broadcast to every channel at lookup time;
// mean_values_ itself is not modified, so later calls with a different channel
// count are still accepted.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();  // 0 means "no crop"
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // Check dimensions.
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();
  CHECK_EQ(channels, img_channels);
  CHECK_LE(height, img_height);
  CHECK_LE(width, img_width);
  CHECK_GE(num, 1);

  // The pixel loop below reads each row as raw unsigned bytes.
  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

  const Dtype scale = param_.scale();
  // Random mirror: only when enabled, and then with probability 1/2.
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean image is indexed with full (uncropped) image coordinates.
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }

  // True when one mean_value must be applied to all channels.
  bool single_mean_value = false;
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
     "Specify either 1 mean_value or as many as channels: " << img_channels;
    single_mean_value = (mean_values_.size() == 1);
  }

  int h_off = 0;
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_img;  // the full image unless a crop is taken below
  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(img_height - crop_size + 1);
      w_off = Rand(img_width - crop_size + 1);
    } else {
      // Deterministic center crop outside of training.
      h_off = (img_height - crop_size) / 2;
      w_off = (img_width - crop_size) / 2;
    }
    cv::Rect roi(w_off, h_off, crop_size, crop_size);
    cv_cropped_img = cv_img(roi);
  } else {
    CHECK_EQ(img_height, height);
    CHECK_EQ(img_width, width);
  }
  CHECK(cv_cropped_img.data);

  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  for (int h = 0; h < height; ++h) {
    // Row h of the cropped image, laid out as w-major, c-minor bytes.
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      // Destination column: mirrored along the width when do_mirror is set.
      const int top_w = do_mirror ? (width - 1 - w) : w;
      for (int c = 0; c < img_channels; ++c) {
        const int top_index = (c * height + h) * width + top_w;
        const Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
        // Mean subtraction always happens before scaling.
        if (has_mean_file) {
          // Cropped pixel (h, w, c) corresponds to mean (c, h_off + h, w_off + w).
          const int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
          transformed_data[top_index] = (pixel - mean[mean_index]) * scale;
        } else if (has_mean_values) {
          const Dtype channel_mean = mean_values_[single_mean_value ? 0 : c];
          transformed_data[top_index] = (pixel - channel_mean) * scale;
        } else {
          transformed_data[top_index] = pixel * scale;
        }
      }
    }
  }
}
#endif  // USE_OPENCV
// Preprocess a whole batch held in input_blob and write it into
// transformed_blob.
// If transformed_blob is empty, it is first reshaped to
// (input_num, input_channels, crop_size or input_height, crop_size or input_width).
// NOTE: mean subtraction is done in place, so input_blob's contents are modified.
// The mirror decision and the crop offsets are drawn once and shared by every
// image in the batch. Scaling is applied to all count() values of transformed_blob.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();
  const int input_num = input_blob->num();
  const int input_channels = input_blob->channels();
  const int input_height = input_blob->height();
  const int input_width = input_blob->width();

  if (transformed_blob->count() == 0) {
    // Initialize transformed_blob with the right shape.
    const int out_height = crop_size ? crop_size : input_height;
    const int out_width = crop_size ? crop_size : input_width;
    transformed_blob->Reshape(input_num, input_channels, out_height, out_width);
  }

  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int size = transformed_blob->count();

  CHECK_LE(input_num, num);
  CHECK_EQ(input_channels, channels);
  CHECK_GE(input_height, height);
  CHECK_GE(input_width, width);

  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  int h_off = 0;
  int w_off = 0;
  if (!crop_size) {
    CHECK_EQ(input_height, height);
    CHECK_EQ(input_width, width);
  } else {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training; otherwise take the center.
    const bool random_crop = (phase_ == TRAIN);
    h_off = random_crop ? Rand(input_height - crop_size + 1)
                        : (input_height - crop_size) / 2;
    w_off = random_crop ? Rand(input_width - crop_size + 1)
                        : (input_width - crop_size) / 2;
  }

  Dtype* input_data = input_blob->mutable_cpu_data();

  // In-place mean subtraction on the input, image by image.
  if (has_mean_file) {
    CHECK_EQ(input_channels, data_mean_.channels());
    CHECK_EQ(input_height, data_mean_.height());
    CHECK_EQ(input_width, data_mean_.width());
    for (int n = 0; n < input_num; ++n) {
      Dtype* image = input_data + input_blob->offset(n);
      caffe_sub(data_mean_.count(), image, data_mean_.cpu_data(), image);
    }
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<
     "Specify either 1 mean_value or as many as channels: " << input_channels;
    if (mean_values_.size() == 1) {
      // One value for everything: a single pass over the whole input blob.
      caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);
    } else {
      // One value per channel: subtract it from each (n, c) plane.
      const int plane = input_height * input_width;
      for (int n = 0; n < input_num; ++n) {
        for (int c = 0; c < input_channels; ++c) {
          caffe_add_scalar(plane, -(mean_values_[c]), input_data + input_blob->offset(n, c));
        }
      }
    }
  }

  // Copy the (possibly cropped / mirrored) window into the output.
  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  for (int n = 0; n < input_num; ++n) {
    for (int c = 0; c < channels; ++c) {
      const int plane_index = n * channels + c;
      for (int h = 0; h < height; ++h) {
        // Output row (n, c, h) and input row (n, c, h_off + h) starting at column w_off.
        Dtype* dst_row = transformed_data + (plane_index * height + h) * width;
        const Dtype* src_row =
            input_data + (plane_index * input_height + h_off + h) * input_width + w_off;
        for (int w = 0; w < width; ++w) {
          dst_row[do_mirror ? (width - 1 - w) : w] = src_row[w];
        }
      }
    }
  }

  if (scale != Dtype(1)) {
    DLOG(INFO) << "Scale: " << scale;
    caffe_scal(size, scale, transformed_data);
  }
}
// Infer the (1, C, H, W) shape a single Datum will have after preprocessing.
// Encoded datums are decoded with OpenCV first, and the decoded image's shape is used.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    const bool forced = param_.force_color() || param_.force_gray();
    // force_color -> color decode, force_gray -> gray decode, otherwise native.
    const cv::Mat cv_img = forced ? DecodeDatumToCVMat(datum, param_.force_color())
                                  : DecodeDatumToCVMatNative(datum);
    return InferBlobShape(cv_img);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }

  // Raw (non-encoded) datum: the shape comes straight from its fields.
  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Check dimensions.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  // Build BlobShape: one image, cropped to crop_size x crop_size when cropping.
  vector<int> shape;
  shape.push_back(1);
  shape.push_back(datum_channels);
  shape.push_back(crop_size ? crop_size : datum_height);
  shape.push_back(crop_size ? crop_size : datum_width);
  return shape;
}
// Infer the batch shape for datum_vector: the shape of the first datum, with
// num replaced by the number of datums. Only the first element is inspected.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const vector<Datum> & datum_vector) {
  const int num = datum_vector.size();
  CHECK_GT(num, 0) << "There is no datum to in the vector";
  vector<int> batch_shape = InferBlobShape(datum_vector.front());
  batch_shape[0] = num;
  return batch_shape;
}
#ifdef USE_OPENCV
// Infer the (1, C, H, W) shape a cv::Mat will have after preprocessing.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // Check dimensions.
  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  // Build BlobShape: num is 1; H/W are the crop size when cropping, else the image size.
  const bool cropping = (crop_size != 0);
  vector<int> shape(4, 1);
  shape[1] = img_channels;
  shape[2] = cropping ? crop_size : img_height;
  shape[3] = cropping ? crop_size : img_width;
  return shape;
}
// Infer the batch shape for mat_vector: the shape of the first image, with
// num replaced by the number of images. Only the first element is inspected.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<cv::Mat> & mat_vector) {
  const int num = mat_vector.size();
  CHECK_GT(num, 0) << "There is no cv_img to in the vector";
  vector<int> batch_shape = InferBlobShape(mat_vector.front());
  batch_shape[0] = num;
  return batch_shape;
}
#endif  // USE_OPENCV
// (Re)initialize rng_. A generator is needed only when random mirroring is
// enabled, or when random cropping happens (crop_size set in the TRAIN phase).
// Otherwise rng_ is released, and Rand() would fail its CHECK(rng_) if called.
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  const bool random_crop = (phase_ == TRAIN && param_.crop_size());
  if (!param_.mirror() && !random_crop) {
    rng_.reset();
    return;
  }
  // Seed a fresh generator from caffe_rng_rand().
  rng_.reset(new Caffe::RNG(caffe_rng_rand()));
}
// Return a pseudo-random integer in [0, n).
// Requires rng_ to have been created by InitRand() and n > 0.
// The value is taken modulo n, so it may be slightly biased when n does not
// evenly divide the generator's output range.
template <typename Dtype>
int DataTransformer<Dtype>::Rand(int n) {
  CHECK(rng_);
  CHECK_GT(n, 0);
  caffe::rng_t* const generator = static_cast<caffe::rng_t*>(rng_->generator());
  return (*generator)() % n;
}
小结
- 注意opencv中图像是以(height, width, channel)形式存放的,与caffe中的(num, channel, height,width)形式不同。
- caffe::RNG类中封装了boost库和CUDA的CURAND库的随机数函数,实现了跨平台编译。CURAND库的函数可参考官方提供的文档。
参考
https://docs.nvidia.com/cuda/pdf/CURAND_Library.pdf
Caffe的源码笔者是第一次阅读,一边阅读一边记录,对代码的理解和分析可能会存在错误或遗漏,希望各位读者批评指正,谢谢支持!