如何自定义数据集
pytorch读取图片,主要是通过Dataset类。
Dataset类源代码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
这个类中最核心的就是getitem函数,上面介绍中写的是这个函数提供一个合理范围内的index。我们在自己定义数据集的时候,在这个类中,我们一般是定义这个函数的功能是接受一个index,然后返回图片数据和标签。所以在这个函数中,需要包含打开图片的函数和获取图片lable的语句
getitem函数接受的是一个index,这个index通常指的是一个list中index,这个list中的每个元素就是对应的每个图片的文件路径和标签。
所以在读取自己数据的时候基本流程就是这样的:
首先制作图片存储路径和标签信息的txt
然后将这个信息转化为list
通过这个list中的index,使用getitem函数,我们获取对应的图片数据和标签信息
现在问题是如何制作这个一个list。这个东西我们一般是外部制作就好,保存为一个txt格式就好
然后我们制作一个Dataset子类
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.strip()
words = line.split()
imgs.append((words[0], int(words[1]))) # words[0]是路径 words[1]是类别数
self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.imgs)
注意看我自己定义的类,在初始化函数中,我通过对txt文件的读取,得到了一个list,也就是self.imgs
然后在__getitem__ 函数中,通过index,我们得到文件路径和lable,然后使用open函数,将图像文件打开并转化为RGB数据,同时进行一些相应的转化
这个部分建立好了,其实自定义数据集基本就好了,因为接下来的操作就交给了DataLoder,代码基本不需要变化。
我现在有一个思考,就是说上面我说的图像数据,如果是文本数据呢?我如何进行自定义数据呢?