pytorch 从Dataset类中获取数据
转自:https://www.jianshu.com/p/4818a1a4b5bd
1.介绍
Dataset类是为torch.utils.data.DataLoader做准备,支持两种类型的访问
* map-style datasets #__getitem__()
* iterable-style datasets #__iter__()
(1) print("trainDataset 的类型:", type(trainDataset)) >>> trainDataset 的类型: <class 'torchvision.datasets.mnist.MNIST'> (2) print("trainDataset 的长度:", len(trainDataset)) >>> trainDataset 的长度: 60000 (3) print("trainDataset[0] 的类型:", type(trainDataset[0])) print("trainDataset[0] 的长度:", len(trainDataset[0])) >>> trainDataset[0] 的类型: <class 'tuple'> trainDataset[0] 的长度: 2 (4) print("trainDataset[0][0] 的类型:", type(trainDataset[0][0])) print("trainDataset[0][0] 的形状:", trainDataset[0][0].shape) >>> trainDataset[0][0] 的类型: <class 'torch.Tensor'> trainDataset[0][0] 的形状: torch.Size([1, 28, 28]) (5) print("trainDataset[0][1] 的类型:", type(trainDataset[0][1])) print("trainDataset[0][1] :", trainDataset[0][1]) >>> trainDataset[0][1] 的类型: <class 'int'> trainDataset[0][1] : 5
从上述代码可以看到,能够通过一些方法去访问。