Dataset 和 DataLoader 详解
Dataset 是 PyTorch 中用来表示数据集的一个抽象类,我们的数据集可以用这个类来表示,至少需要覆写下面两个方法:
1)__len__:一般用来返回数据集大小。
2)__getitem__:实现这个方法后,可以通过下标的方式 dataset[i] 的来取得第 $i$ 个数据。
DataLoader 本质上就是一个 iterable(内部定义了 __iter__ 方法),__iter__ 被定义成生成器,使用 yield 来返回数据,
并利用多进程来加速 batch data 的处理,DataLoader 组装好数据后返回的是 Tensor 类型的数据。
注意:DataLoader 是间接通过 Dataset 来获得数据的,然后进行组装成一个 batch 返回,因为采用了生成器,所以每次只会组装
一个 batch 返回,不会一次性组装好全部的 batch,所以 DataLoader 节省的是 batch 的内存,并不是指数据集的内存,数据集可
以一开始就全部加载到内存里,也可以分批加载,这取决于 Dataset 中 __init__ 函数的实现。
举个例子:
import torch import numpy as np from torch.utils.data import Dataset from torch.utils.data import DataLoader class DiabetesDataset(Dataset): def __init__(self, filepath): # 因为数据集比较小,所以全部加载到内存里了 data = np.loadtxt(filepath, delimiter=',', dtype=np.float32) self.len = data.shape[0] self.x_data = torch.from_numpy(data[:,:-1]) self.y_data = torch.from_numpy(data[:,[-1]]) def __getitem__(self, index): return self.x_data[index], self.y_data[index] def __len__(self): return self.len dataset = DiabetesDataset('diabetes.csv.gz') train_loader = DataLoader(dataset=dataset, # 传递数据集 batch_size=32, # 小批量的数据大小,每次加载一batch数据 shuffle=True, # 打乱数据之间的顺序 num_workers=2) # 使用多少个子进程来加载数据,默认为0, 代表使用主线程加载batch数据 for epoch in range(100): # 训练 100 轮 for i, data in enumerate(train_loader, 0): # 每次惰性返回一个 batch 数据 iuputs, label = data ...