Pytorch:torch.utils.data.DataLoader与迭代器转换

torch.utils.data.DataLoader与迭代器转换

在做实验时,我们常常会使用用开源的数据集进行测试。而Pytorch中内置了许多数据集,这些数据集我们常常使用DataLoader类进行加载。
如下面这个我们使用DataLoader类加载torch.vision中的FashionMNIST数据集。

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

我们接下来定义Dataloader对象用于加载这两个数据集:

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

那么这个train_dataloader究竟是什么类型呢?

print(type(train_dataloader))  # <class 'torch.utils.data.dataloader.DataLoader'>

我们可以将先其转换为迭代器类型。

print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>

然后再使用next(iter(train_dataloader))从迭代器里取数据,如下所示:

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]
label = train_labels[0]
plt.imshow(torch.permute(img, (1, 2, 0)))
plt.show()
print(f"Label: {label}")

可以看到我们成功获取了数据集中第一张图片的信息,控制台打印:

Feature batch shape: torch.Size([64, 3, 32, 32])
Labels batch shape: torch.Size([64])
Label: 1

图片可视化显示如下:
NLP多任务学习

PS: 事实上我们也可以直接索引datasets对象以访问第1个样本的图片数据和标签数据:

print(training_data[0][0].shape) # torch.Size([3, 32, 32])
print(training_data[0][1]) # 6

这里training_data[0]是第一个样本对应的(图片, 标签)元组。training_data[0][0]是第一个样本的图片数据,training_data[0][1]则是第一个样本的标签数据。如上所示,通过直接索引得到的单张图片通道也在第一维。这种顺序也就是经典的NCHW(N、C、H、W分别为批量、通道、图片高度、图片宽度维度)。Caffe的通道顺序也是NCHW,但Tensorlfow1.*和Tensorlfow2.*的顺序都为HWC顺序吗?那是什么时候变成的CHW顺序呢?原来,这是因为我们在最开始时设置了transform=ToTensor(),而 ToTensor()函数会在获取图片时将其顺序变为CHW(关于ToTensor()函数更多的作用,可参见我的博客《Pytorch:以单通道(灰度图)加载图片》)。

此外,我们也访问datasets对象的datatarget属性来分别获得第1个样本的图片数据和标签数据,如下列命令所示:

print(training_data.data[0].shape) # (32, 32, 3)
print(training_data.targets[0]) # 6

注意,此时datasets对象的data属性获得的单张图片中通道在最后一维! 如果再将这样的图片数据用Dataloader对象进行迭代,则最后获得的单张图片中通道仍会保持在最后一维:


training_data = training_data.data

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

train_features = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}") # torch.Size([64, 32, 32, 3])

诶,我们不是设置了ToTensor()函数吗?原来,ToTensor函数是在对datasets对象调用__getitem__方法时触发调用的(这里使用了惰性求值(lazy evaluation)的思想),__getitem__方法大致如下所示:

def __getitem__(self, idx):

    x, y = self.dataset[idx]
    
    if self.transform:
        x = self.transform(x)
    
    return x, y   

而我们使用training_data.data[0]获取数据,是在对training_data.data进行索引,而没有对datasets对象本身进行索引的操作,就不会去调用datasets对象的__getitem__方法,自然就不会进行图片维度顺序的转换了。

最后,这里再说过个题外话,这里Pytorch默认的三个通道像素顺序为RGB,事实上PIL库、Tensorflow1.*/Tensorflow2.*和我们日常图片存储的通道像素顺序都是RGB,但并非所有软件都是如此。例如OpenCV的通道像素顺序就为BGR。

上面提到的这些点在做实验时都需要额外注意。

接下来我们言归正传,接着来看DataLoader迭代器。有读者可能就会产生疑问,很多时候我们并没有将DataLoader类型强制转换成迭代器类型呀,大多数时候我们会写如下代码:

for train_features, train_labels in train_dataloader: 
    print(train_features.shape) # torch.Size([64, 3, 32, 32])
    print(train_features[0].shape) # torch.Size([3, 32, 32])
    
    img = train_features[0]
    label = train_labels[0]
    plt.imshow(torch.permute(img, (1, 2, 0)))
    plt.show()
    print(f"Label: {label}")

可以看到,该代码也能够正常迭代训练数据,前三个样本的控制台打印输出为:

torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 6
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 7
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])
Label: 9
torch.Size([64, 3, 32, 32])
torch.Size([3, 32, 32])

那么为什么我们这里没有显式将Dataloader转换为迭代器类型呢,其实是Python语言for循环的一种机制,一旦我们用for ... in ...句式来迭代一个对象,那么Python解释器就会偷偷地自动帮我们创建好迭代器,也就是说

for train_features, train_labels in train_dataloader:

实际上等同于

for train_features, train_labels in iter(train_dataloader):

更进一步,这实际上等同于

train_iterator = iter(train_dataloader)
try:
    while True:
        train_features, train_labels = next(train_iterator)
except StopIteration:
    pass

推而广之,我们在用Python迭代直接迭代列表时:

for x in [1, 2, 3, 4]:

其实Python解释器已经为我们隐式转换为迭代器了:

list_iterator = iter([1, 2, 3, 4])
try:
    while True:
        x = next(list_iterator)
except StopIteration:
    pass

参考

  • [1] https://pytorch.org/
  • [2] Martelli A, Ravenscroft A, Ascher D. Python cookbook[M]. " O'Reilly Media, Inc.", 2005.
posted @ 2021-12-06 17:59  orion-orion  阅读(1254)  评论(1编辑  收藏  举报