libtorch c++ 自定义数据类型并使用

上述几节主要介绍了如何利用MNIST数据集搭建多层神经网络并完成模型的训练,用到的数据都是torch::data::dataset自带的数据集,这节介绍如何根据实际情况创建自己的数据集。

(1)自定义类型的设计方法

实际上,自定义数据类型很简单,只需要继承torch::data::datasets::Dataset<self, SingleExample>,同时重写get(size_t index)以获取指定元素和样本总数size()即可。

Dataset类继承的定义在base.h中,它继承自BatchDataset,它支持随机方式获取元素,也支持批量的方式获取元素。仅需要重写两个函数:

第一个是get(size_t index),用来获取指定编号的样本

 /// Returns the example at the given index.
  virtual ExampleType get(size_t index) = 0;

第二个是size(),用来获取总的样本数

/// Returns the size of the dataset, or an empty optional if it is unsized.
  virtual optional<size_t> size() const = 0;

对于返回的元素Example是模板类型,它是个数据项data和标签项target的组合,而data和target都是张量Tensor,数据项data用于给传递给神经网络正向传播,数据标签项target用于和正向传播的结果output一起结算损失大小。

换句话说,自定义的类,只要继承了Dataset,同时重写了get函数,其他部分可以任意设计。这里介绍两种可能的情况,

(1)比如自定义类型中存储了数据项坐在的文件夹和标签项所在的文件夹,两个文件夹中文件数相同,比如都是10000个,文件中存储了地质模型和地震正演模拟结果,具有一一对应关系,自定义类型按照列表方式分别存储两个文件件中文件名,通过索引的方式可以获取文件名,然后get(size_t index)可以根据指定的文件名读取文件内容并返回张量格式的数据。

(2)比如自定义类型中存储了文件,文件中存储了所有的样本数据,比如列存储方式的测井数据,具有m行,n列数据,每行数据中前n-1列为数据项,第n列为标签项,读取内容后以两个列表形式存储数据,通过指定索引号,可以通过get(size_t index)方式获取张量格式的有n-1个数据组成的数据项和1个数据组成的标签项。

(3)再比如,自定义类型中存储了一个地震数据体,同时一个文本文件存储了其他方式解释的沉积相或产量等,分别是沉积相类型或产量及其空间坐标xyz,这些数据按行存储,通过指定行号,的get方法可以从地震数据体重获取点周围的地震数据组成数据项和沉积相、产量自身作为标签项。

总之,自定义类型的格式各种各样,不一一列举,

(2)自定义类型的通用格式

假设自定义类型名称为MyDataset,构造函数中分别制定两种数据所在的文件位置,形成两个文件名组成的字符串列表,或者直接传递两个文件名字符串列表给构造函数。

如果单个样本数据量不大,可以在自定义的类型中设置两个张量作为数据项和标签项,其中的数据项为states_, 标签项为labels_, 两者都是Tensor类型,把该批数据都存储在两个张量中。

如果单个样本数据量很大,内存不足,数据类型自身不存储整个批次的数据,只存储每个样本的文件名称列表,然后只在get(size_t index)中创建单个样本的数据,并通过make_data_loader的方式叠置成整个批次的数据。

下面的代码展示后者,即在数据集内部只存储文件名,实际在data_loader 时再真正读取数据,其中read_source和read_target函数的具体内容不在展示。

class CustomDataset :public torch::data::Dataset<CustomDataset> {
private:
	//declare 2 vectors for sources and targets
	//std::vector<torch::Tensor> sources, targets;
	std::vector<std::string>sourceFiles, targetFiles;
	int rows;
	int cols;
public:
	//constructor
	CustomDataset(std::vector<std::string> sources_list, std::vector<std::string>targets_list, int rows, int cols) {
		if (sources_list.size() != targets_list.size()) {
			std::cout << "sources_list size must be equal as target_list size" << std::endl;
			return ;
		}
		//sources = process_sources(sources_list, rows, cols);
		//targets = process_targets(targets_list, rows, cols);
		this->sourceFiles = sources_list;
		this->targetFiles = targets_list;
		this->rows = rows;
		this->cols = cols;

	};

	//override get() function to return tensor at location index
	torch::data::Example<>get(size_t index)override {
		/*torch::Tensor sample_source = sources.at(index);
		torch::Tensor sample_target = targets.at(index);*/
		torch::Tensor sample_source = read_source(sourceFiles[index], rows, cols);
		torch::Tensor sample_target = read_target(targetFiles[index], rows, cols);
		std::cout << index << std::endl;
		std::cout << sample_source.max() << std::endl;
		std::cout << sample_target.max() << std::endl;
		/*torch::Tensor sample_source = read_source(sourceFiles[index]);
		torch::Tensor sample_target = read_target(targetFiles[index]);*/
		return { sample_source.clone(), sample_target.clone() };
	};

	//return the length of the data
	torch::optional<size_t>size()const override {
		//return targets.size();
		return sourceFiles.size();
	};
};

下面是主函数中对自定义数据的调用和打印输出以作测试

int main()
{

	std::vector<std::string> home_root, sourceList, targetList;
	home_root.push_back("floder1");
	home_root.push_back("floader2");

	for (int i = 0; i < 5000; i++) {
		std::string dataFile = home_root[0] + "\\case_" + std::to_string(i) + ".txt";
		std::string targetFile = home_root[1] + "\\case_" + std::to_string(i) + ".txt";
		if (std::filesystem::exists(dataFile) && std::filesystem::exists(targetFile)) {
			sourceList.push_back(dataFile);
			targetList.push_back(targetFile);			
		}
		else{
			std::cout << dataFile << " exist status:" << std::filesystem::exists(dataFile) << std::endl;
			std::cout << targetFile << " exist status:" << std::filesystem::exists(targetFile) << std::endl;
			continue;
		}		
	}
	int rows = 1000;
	int cols = 1000;
	auto dataset = CustomDataset(sourceList, targetList, rows, cols).map(torch::data::transforms::Stack<>());
	int batchSize = 10;
	auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), batchSize);

	for (auto& batch : *dataLoader) {

		auto data = batch.data;
		auto target = batch.target;
		std::cout << data.sizes() << std::endl;
		std::cout << data.max() << std::endl;
		std::cout << data << std::endl;
	}

	return EXIT_SUCCESS;
}

打印结果如下:

......... 
-0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000
  0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000 -0.0000 -0.0000
 -0.0264 -0.0292 -0.0313 -0.0323 -0.0318 -0.0296 -0.0254 -0.0190 -0.0106
  0.0074  0.0024 -0.0021 -0.0057 -0.0086 -0.0105 -0.0116 -0.0120 -0.0118
  0.0000  0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
 -0.0000 -0.0000  0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000
  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000
 -0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000
 -0.0000  0.0000 -0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
  0.0000 -0.0000 -0.0000 -0.0000  0.0000 -0.0000 -0.0000  0.0000 -0.0000
 -0.0000  0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000
  0.0000 -0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000
 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
  0.0000 -0.0000  0.0000 -0.0000  0.0000 -0.0000 -0.0000  0.0000 -0.0000
  0.0150  0.0152  0.0141  0.0118  0.0086  0.0047  0.0005 -0.0036 -0.0072
  0.0367  0.0215  0.0039 -0.0155 -0.0355 -0.0551 -0.0731 -0.0883 -0.0998
 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000
  0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000 -0.0000
 -0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000 -0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000
 -0.0000 -0.0000 -0.0000 -0.0000 -0.0000  0.0000  0.0000 -0.0000 -0.0000
  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000
 -0.0000  0.0000 -0.0000  0.0000  0.0000 -0.0000  0.0000  0.0000  0.0000
 -0.0000 -0.0000  0.0000  0.0000  0.0000  0.0000  0.0000 -0.0000  0.0000
....

 

posted @ 2022-08-21 10:13  Oliver2022  阅读(218)  评论(0编辑  收藏  举报