Pytorch:torch.utils.data.DataLoader与迭代器转换
torch.utils.data.DataLoader与迭代器转换
在做实验时,我们常常会使用用开源的数据集进行测试。而Pytorch中内置了许多数据集,这些数据集我们常常使用DataLoader
类进行加载。
如下面这个我们使用DataLoader
类加载torch.vision
中的FashionMNIST
数据集。
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.CIFAR10(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.CIFAR10(
root="data",
train=False,
download=True,
transform=ToTensor()
)
我们接下来定义Dataloader
对象用于加载这两个数据集:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
那么这个train_dataloader
究竟是什么类型呢?
print(type(train_dataloader)) # <class 'torch.utils.data.dataloader.DataLoader'>
我们可以将先其转换为迭代器类型。
print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>
然后再使用next(iter(train_dataloader))
从迭代器里取数据,如下所示:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]
label = train_labels[0]
plt.imshow(torch.permute(img, (1, 2, 0)))
plt.show()
print(f"Label: {label}")
可以看到我们成功获取了数据集中第一张图片的信息,控制台打印:
Feature batch shape: torch.Size([64, 3, 32, 32])
Labels batch shape: torch.Size([64])
Label: 1
图片可视化显示如下:
PS: 事实上我们也可以直接索引
datasets
对象以访问第1个样本的图片数据和标签数据:
print(training_data[0][0].shape) # torch.Size([3, 32, 32])
print(training_data[0][1]) # 6
这里
training_data[0]
是第一个样本对应的(图片, 标签)
元组。training_data[0][0]
是第一个样本的图片数据,training_data[0][1]
则是第一个样本的标签数据。如上所示,通过直接索引得到的单张图片通道也在第一维。这种顺序也就是经典的NCHW(N、C、H、W分别为批量、通道、图片高度、图片宽度维度)。Caffe的通道顺序也是NCHW,但Tensorlfow1.*和Tensorlfow2.*的顺序都为HWC顺序吗?那是什么时候变成的CHW顺序呢?原来,这是因为我们在最开始时设置了transform=ToTensor()
,而ToTensor()
函数会在获取图片时将其顺序变为CHW(关于ToTensor()
函数更多的作用,可参见我的博客《Pytorch:以单通道(灰度图)加载图片》)。此外,我们也访问
datasets
对象的data
和target
属性来分别获得第1个样本的图片数据和标签数据,如下列命令所示:
print(training_data.data[0].shape) # (32, 32, 3)
print(training_data.targets[0]) # 6
注意,此时
datasets
对象的data
属性获得的单张图片中通道在最后一维! 如果再将这样的图片数据用Dataloader
对象进行迭代,则最后获得的单张图片中通道仍会保持在最后一维:
training_data = training_data.data
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
train_features = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}") # torch.Size([64, 32, 32, 3])
诶,我们不是设置了
ToTensor()
函数吗?原来,ToTensor
函数是在对datasets
对象调用__getitem__
方法时触发调用的(这里使用了惰性求值(lazy evaluation)的思想),__getitem__
方法大致如下所示:
def __getitem__(self, idx):
x, y = self.dataset[idx]
if self.transform:
x = self.transform(x)
return x, y
而我们使用
training_data.data[0]
获取数据,是在对training_data.data
进行索引,而没有对datasets
对象本身进行索引的操作,就不会去调用datasets
对象的__getitem__
方法,自然就不会进行图片维度顺序的转换了。最后,这里再说过个题外话,这里Pytorch默认的三个通道像素顺序为RGB,事实上PIL库、Tensorflow1.*/Tensorflow2.*和我们日常图片存储的通道像素顺序都是RGB,但并非所有软件都是如此。例如OpenCV的通道像素顺序就为BGR。
上面提到的这些点在做实验时都需要额外注意。
接下来我们言归正传,接着来看DataLoader
迭代器。有读者可能就会产生疑问,很多时候我们并没有将DataLoader
类型强制转换成迭代器类型呀,大多数时候我们会写如下代码:
for train_features, train_labels in train_dataloader:
print(train_features.shape) # torch.Size([64, 3, 32, 32])
print(train_features[0].shape) # torch.Size([3, 32, 32])
img = train_features[0]
label = train_labels[0]
plt.imshow(torch.permute(img, (1, 2, 0)))
plt.show()
print(f"Label: {label}")
可以看到,该代码也能够正常迭代训练数据,前三个样本的控制台打印输出为:
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 6
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 7
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 9
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
那么为什么我们这里没有显式将Dataloader
转换为迭代器类型呢,其实是Python语言for循环的一种机制,一旦我们用for ... in ...
句式来迭代一个对象,那么Python解释器就会偷偷地自动帮我们创建好迭代器,也就是说
for train_features, train_labels in train_dataloader:
实际上等同于
for train_features, train_labels in iter(train_dataloader):
更进一步,这实际上等同于
train_iterator = iter(train_dataloader)
try:
while True:
train_features, train_labels = next(train_iterator)
except StopIteration:
pass
推而广之,我们在用Python迭代直接迭代列表时:
for x in [1, 2, 3, 4]:
其实Python解释器已经为我们隐式转换为迭代器了:
list_iterator = iter([1, 2, 3, 4])
try:
while True:
x = next(list_iterator)
except StopIteration:
pass
参考
- [1] https://pytorch.org/
- [2] Martelli A, Ravenscroft A, Ascher D. Python cookbook[M]. " O'Reilly Media, Inc.", 2005.