Pytorch Dataset入门
Dataset入门
Pytorch Dataset code:
Pytorch Dataset tutorial:
理论:
PyTorch中的Dataset
是一个抽象类,用来表示数据集的接口,所有其他数据集都需要继承这个类,并且覆写以下三个方法:
-
__init__:初始化数据集的一些配置,例如加载所有的数据标签。
-
__len__:以便
len(dataset)
可以返回数据集的大小,例如n。如果n小于数据集长度,则只会取前n个的数据。 -
__getitem__:输入是数据的索引,以便可以使用
dataset[i]
来获取第i个样本,数据增强一般会在这里做。
代码:
下面是一个自定义的Dataset样例(不可执行):
总结:
值得注意的是,Dataset
只负责数据的加载和预处理,对于如何训练数据(例如:是否进行shuffle,是否进行并行加速等)这部分的逻辑是由DataLoader
实现的。通常情况下,我们会将Dataset
和DataLoader
一起使用。
另外,PyTorch还提供了一些常用的数据集,如:ImageFolder
,CIFAR10
,MNIST
等,这些数据集都是继承Dataset
类,同时在init
方法中进行数据的下载,以及在getitem
方法中进行数据的加载和预处理。
Dataset是单线程读取数据,每次只能读取一个样本,不能一次性读取一个mini-batch的数据。
Dataset的主要特性包含:
-
抽象接口:PyTorch通过定义一个抽象
Dataset
类,让用户可以使用统一的方式来加载各种不同的数据,提供了很好的扩展性。 -
懒加载:实际的数据载入并不发生在构造数据集实例时,而是发生在用到这些数据时,这样可以提高内存利用率,并且可以实现对大规模数据的处理。
-
预处理:
Dataset
的一个重要应用就是数据预处理,你可以在getitem
函数中进行任何你的数据预处理过程。
嗨,欢迎大家关注我的公众号《CV之路》,一起讨论问题,一起学习进步~