pytorch(二)数据准备工作ETL

1.Extract: 从网络等下载图像数据集
2.Transform: 图片---->tensor
3.Loader: tensor装进数据流管道,以便获取到流出batch长度数据。
()


1.torch.utils.data.datasets ---(Extract, Transform)

抽象类:具有必须待实现(重写)的方法的Python类. 因此我们可以通过扩展这个抽象类的功能,创建子类来构造自定义数据集。

需要重写-override的函数:

  • len:实现数据集长度功能
  • getitem:实现对数据的位置索引,可以根据索引来访问数据元素
#这里直接继承了MNIST类,MNIST类也是继承Dataset类实现。
class FashionMNIST(MNIST):(
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]

)

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
"""
#设置保存路径
#选择训练集,默认测试集
#是否下载,下载前方法会检查目录下是否已经,不会重复下载
#通过torchvison.transforms对数据做变换
#开代理会很快

"""

2.torch.utils.data.dataloader

train_loader = torch.utils.data.DataLoader(train_set, batch_size = 1000,shuffle = True,num_workers = 0)
#num_workers默认为0表示用主进程来装载数据
# 注意:Pytorch multiprocessing does not work on Windows!!!因此在windows系统上此处必须设置为0,即只有一个进程存在。
#数据成可训练的pipeline最后只需要给三个入参:数据集,批次,是否打乱。
#batch_size:  data size of  per batch
"""
batch_size小对训练的影响:
1.很小比如为1,训练震荡严重,不易收敛
2.增大,下山路线,开始变正确
3.继续增大,已经足够准确,不再变化
4.但是随着batch_size增大,相同的epoch次数下,迭代次数变小,因此需要注意在增大batch_size的同时,增加epoch。不能增大了batch导致迭代次数明显减少,会导致最优化效果变差。
"""

3.利用dataset和dataloader探索数据

train_set.targets#每个的标签

train_set.targets.bincount()
#统计每个标签类别的数目
#用途:检查是否有严重的类别不平衡问题,class imbalanace

sample = next(iter(train_set))
#iter和next都是python内置函数(既可以用在train_set,也可以用在封装好的dataloader)
# iter返回迭代器对象, next获取一个迭代器元素

len(sample)
print(type(sample))
#从torchvision获取的数据集每个sample样式:(tentor(img),tensor(label))

image.shape
# torch.Size([1, 28, 28]) 

plt.imshow(image.squeeze(), cmap="gray") # 因为plt.imshow对于单通道显示格式是 H W

display_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
images, labels = next(iter(display_loader)  # len(next(iter(...))) = 2

print('types:', type(images), type(labels))
# types: <class 'torch.Tensor'> <class 'torch.Tensor'>

print('shapes:', images.shape, labels.shape)
# shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])
#两种可视化数据集的方式
#1.torchvision.utils.make_grid()

grid = torchvision.utils.make_grid(images, nrow=10)
plt.figure(figsize=(15,15))
plt.imshow(np.transpose(grid, (1,2,0))) #plt.imshow(grid.permute(1,2,0))

#2.利用DataLoader显示

how_many_to_plot = 20

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)

plt.figure(figsize=(50,50))
for i, img_batch in enumerate(train_loader, start=1):
#enumerate(sequence,start=1):枚举出的就是(i,img_batch)形状,如(1,a),(2,b), (3,c)
    image, label = img_batch
    plt.subplot(10,10,i)
    plt.imshow(image.reshape(28,28), cmap='gray')
    plt.axis('off')
    plt.title(train_set.classes[label.item()], fontsize=28)
    if (i >= how_many_to_plot): break
plt.show()
posted @ 2020-06-09 10:03  Parallax  阅读(311)  评论(0编辑  收藏  举报