【深度学习】PyTorch Dataset类的使用与实例分析
Dataset类
介绍
当我们得到一个数据集时,Dataset类可以帮我们提取我们需要的数据,我们用子类继承Dataset类,我们先给每个数据一个编号(idx),在后面的神经网络中,初始化Dataset子类实例后,就可以通过这个编号去实例对象中读取相应的数据,会自动调用__getitem__方法,同时子类对象也会获取相应真实的Label(人为去复写即可)
Dataset类的作用:提供一种方式去获取数据及其对应的真实Label
在Dataset类的子类中,应该有以下函数以实现某些功能:
- 获取每一个数据及其对应的Label
- 统计数据集中的数据数量
关于2,神经网络经常需要对一个数据迭代多次,只有知道当前有多少个数据,进行训练时才知道要训练多少次,才能把整个数据集迭代完
Dataset官方文档解读
首先看一下Dataset的官方文档解释
导入Dataset类:
from torch.utils.data import Dataset
我们可以通过在Jupyter中查看官方文档
from torch.utils.data import Dataset
help(Dataset)
输出:
Help on class Dataset in module torch.utils.data.dataset:
class Dataset(typing.Generic)
| An abstract class representing a :class:`Dataset`.
|
| All datasets that represent a map from keys to data samples should subclass
| it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
| data sample for a given key. Subclasses could also optionally overwrite
| :meth:`__len__`, which is expected to return the size of the dataset by many
| :class:`~torch.utils.data.Sampler` implementations and the default options
| of :class:`~torch.utils.data.DataLoader`.
|
| .. note::
| :class:`~torch.utils.data.DataLoader` by default constructs a index
| sampler that yields integral indices. To make it work with a map-style
| dataset with non-integral indices/keys, a custom sampler must be provided.
|
| Method resolution order:
| Dataset
| typing.Generic
| builtins.object
|
| Methods defined here:
|
| __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]'
|
| __getattr__(self, attribute_name)
|
| __getitem__(self, index) -> +T_co
|
| ----------------------------------------------------------------------
| Class methods defined here:
|
| register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
|
| register_function(function_name, function) from typing.GenericMeta
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __abstractmethods__ = frozenset()
|
| __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
|
| __args__ = None
|
| __extra__ = None
|
| __next_in_mro__ = <class 'object'>
| The most base type
|
| __orig_bases__ = (typing.Generic[+T_co],)
|
| __origin__ = None
|
| __parameters__ = (+T_co,)
|
| __tree_hash__ = -9223371872509358054
|
| functions = {'concat': functools.partial(<function Dataset.register_da...
|
| ----------------------------------------------------------------------
| Static methods inherited from typing.Generic:
|
| __new__(cls, *args, **kwds)
| Create and return a new object. See help(type) for accurate signature.
还有一种方式获取官方文档信息:
Dataset??
输出:
Init signature: Dataset(*args, **kwds)
Source:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
functions: Dict[str, Callable] = {}
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
def __getattr__(self, attribute_name):
if attribute_name in Dataset.functions:
function = functools.partial(Dataset.functions[attribute_name], self)
return function
else:
raise AttributeError
@classmethod
def register_function(cls, function_name, function):
cls.functions[function_name] = function
@classmethod
def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
if function_name in cls.functions:
raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
result_pipe = cls(source_dp, *args, **kwargs)
if isinstance(result_pipe, Dataset):
if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
if function_name not in UNTRACABLE_DATAFRAME_PIPES:
result_pipe = result_pipe.trace_as_dataframe()
return result_pipe
function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
cls.functions[function_name] = function
File: d:\environment\anaconda3\envs\py-torch\lib\site-packages\torch\utils\data\dataset.py
Type: GenericMeta
Subclasses: Dataset, IterableDataset, Dataset, TensorDataset, ConcatDataset, Subset, Dataset, Subset, Dataset, IterableDataset[+T_co], ...
其中我们可以看到:
"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
"""
以上内容显示:
该类是一个抽象类,所有的数据集想要在数据与标签之间建立映射,都需要继承这个类,所有的子类都需要重写__getitem__
方法,该方法根据索引值获取每一个数据并且获取其对应的Label,子类也可以重写__len__
方法,返回数据集的size大小
实例:GetData类
准备工作
首先我们创建一个类,类名为GetData,这个类要继承Dataset类
class GetData(Dataset):
一般在类中首先需要写的是__init__
方法,此方法用于对象实例化,通常用来提供类中需要使用的变量,可以先不写
class GetData(Dataset):
def __init__(self):
pass
我们可以先写__getitem__
方法:
class GetData(Dataset):
def __init__(self):
pass
def __getitem__(self, idx): # 默认是item,但常改为idx,是index的缩写
pass
其中,idx是index的简称,就是一个编号,以便以后数据集获取后,我们使用索引编号访问每个数据
在实现GetData类之前,我们首先需要解决的问题就是如何读取一个图像数据,通常我们使用PIL来读取
PIL获取图像数据
我们使用PIL来读取数据,它提供一个Image模块,可以让我们提取图像数据,我们先导入这个模块
from PIL import Image
我们可以在Python Console中看看如何使用 Image
在Python Console中,输入代码:
from PIL import Image
将数据集放入项目文件夹,我们需要获取图片的绝对路径,选中具体的图片,右键选择Copy Path,然后选择 Absolute path(快捷键:Ctrl + Shift + C)
img_path = "D:\\DeepLearning\\dataset\\train\\ants\\0013035.jpg"
在Windows下,路径分割需要是
\\
,来表示转译也可以在字符串前面加
r
防转译
使用Image的open方法读取图片:
img = Image.open(img_path)
可以在Python控制台看到读取出来的 img,是一个JpegImageFile类的对象
在图中,可以看到这个对象的一些属性,比如size
我们查看这个属性的内容,输入以下代码:
img.size
输出:
(768, 512)
我们可以看到此图的宽是768,高是512,__len__
表示的是这个size元组的长度,有两个值,所以为 2
show方法显示图片:
img.show()
获取图片的文件名
从数据集路径中,获取所有文件的名字,存储到一个列表中
一个简单的例子(在Python Console中):
我们需要借助os模块
import os
dir_path = "dataset/train/ants_image"
img_path_list = os.listdir(dir_path)
listdir方法会将路径下的所有文件名(包括后缀名)组成一个列表
我们可以使用索引去访问列表中的每个文件名
img_path_list[0]
Out[14]: '0013035.jpg'
构建数据集路径
我们需要搭建数据集的路径表示,一个根目录路径和一个具体的子目录路径,以作为不同数据集的区分
一个简单的案例,在Python Console中输入:
root_dir = "dataset/train"
child_dir = "ants_image"
我们使用os.path.join
方法,将两个路径拼接起来,就得到了ants子数据集的相对路径
path = os.path.join(root_dir, child_dir)
path的值此时是:
path={str}'dataset/train\\ants_image'
我们有了这个数据集的路径后,就可以使用之前所讲的listdir方法,获取这个路径中所有文件的文件名,存储到一个列表中
img_path_list = os.listdir(path)
idx = 0
img_path_list[idx]
Out[21]: '0013035.jpg'
可以看到结果与我们之前的小案例是一样的
有了具体的名字,我们还可以将这个文件名与路径进行组合,然后使用PIL获取具体的图像img对象
img_name = img_path_list[idx]
img_item_path = os.path.join(root_dir, child_dir, img_name)
img = Image.open(img_item_path)
在掌握了如何组装路径、获取路径中的文件名以及获取具体图像对象后,我们可以完善我们的__init__
与__getitem__
方法了
完善__init__方法
在init中为啥使用self:一个函数中的变量是不能拿到另外一个函数中使用的,self可以当做类中的全局变量
class GetData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
很简单,就是接收实例化时传入的参数:获取根目录路径、子目录路径
然后将两个路径进行组合,就得到了目标数据集的路径
我们将这个路径作为参数传入listdir函数,从而让img_path_list中存储该目录下所有文件名(包含后缀名)
此时通过索引就可以轻松获取每个文件名
接下来,我们要使用这些初始化的信息去获取其中的每一个图片的JpegImageFile对象
完善__getitem__方法
我们在初始化中,已经通过组装数据集路径,进而通过listdir方法获取了数据集中每个文件的文件名,存入了一个列表中。
在__getitem__方法中,默认会有一个 item 参数,常命名为 idx,这个参数是一个索引编号,用于对我们初始化中得到的文件名列表进行索引访问,我们就得到了具体的文件名,然后与根目录、子目录再次组装,得到具体数据的相对路径,我们可以通过这个路径获取到索引编号对应的数据对象本身。
这样巧妙的让索引与数据集中的具体数据对应了起来
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 从文件名列表中获取了文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 组装路径,获得了图片具体的路径
获取了具体的图像路径后,我们需要使用PIL读取这个图像
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
此处img是一个JpegImageFile对象,label是一个字符串
自此,这个函数我们就实现完成了
以后使用这个类进行实例化时,传入的参数是根目录路径,以及对应的label名,我们就可以得到一个GetData对象。
有了这个GetData对象后,我们可以直接使用索引来获取具体的图像对象(类:JpegImageFile),因为__getitem__方法已经帮我们实现了,我们只需要使用索引即可调用__getitem__方法,会返回我们根据索引提取到的对应数据的图像对象以及其label
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GetData(root_dir, bees_label_dir)
img1, label1 = ants_dataset[0] # 返回一个元组,返回值是__getitem__方法的返回值
img2, label2 = bees_dataset[0]
完善__len__方法
__len__实现很简单
主要功能是获取数据集的长度,由于我们在初始化中已经获取了所有文件名的列表,所以只需要知道这个列表的长度,就知道了有多少个文件,也就是知道了有多少个具体的数据
def __len__(self):
return len(self.img_path_list)
组合数据集
我们还可以将两个数据集对象进行组合,组合成一个大的数据集对象
train_dataset = ants_dataset + bees_dataset
我们看看这三个数据集对象的大小(在python Console中):
len1 = len(ants_dataset)
len2 = len(bees_dataset)
len3 = len(train_dataset)
输出:
124
121
245
我们可以看到刚好 $$124 + 121 = 245$$
而对这个组合的数据集的访问也很有意思,也同样是使用索引,0 ~ 123 都是ants数据集的内容,124 - 244 都是bees数据集的内容
img1, label1 = train_dataset[123]
img1.show()
img2, label2 = train_dataset[124]
img2.show()
完整代码
from torch.utils.data import Dataset
from PIL import Image
import os
class GetData(Dataset):
# 初始化为整个class提供全局变量,为后续方法提供一些量
def __init__(self, root_dir, label_dir):
# self
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path_list[idx] # 只获取了文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每个图片的位置
# 读取图片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = GetData(root_dir, ants_label_dir)
bees_dataset = GeyData(root_dir, bees_label_dir)
img, lable = ants_dataset[0] # 返回一个元组,返回值就是__getitem__的返回值
# 获取整个训练集,就是对两个数据集进行了拼接
train_dataset = ants_dataset + bees_dataset
len1 = len(ants_dataset) # 124
len2 = len(bees_dataset) # 121
len = len(train_dataset) # 245
img1, label1 = train_dataset[123] # 获取的是蚂蚁的最后一个
img2, label2 = train_dataset[124] # 获取的是蜜蜂第一个