pytorch(二)数据准备工作ETL
1.Extract: 从网络等下载图像数据集
2.Transform: 图片---->tensor
3.Loader: tensor装进数据流管道,以便获取到流出batch长度数据。()
1.torch.utils.data.datasets ---(Extract, Transform)
抽象类:具有必须待实现(重写)的方法的Python类. 因此我们可以通过扩展这个抽象类的功能,创建子类来构造自定义数据集。
需要重写-override的函数:
- len:实现数据集长度功能
- getitem:实现对数据的位置索引,可以根据索引来访问数据元素
#这里直接继承了MNIST类,MNIST类也是继承Dataset类实现。
class FashionMNIST(MNIST):(
urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
)
train_set = torchvision.datasets.FashionMNIST(
root='./data'
,train=True
,download=True
,transform=transforms.Compose([
transforms.ToTensor()
])
)
"""
#设置保存路径
#选择训练集,默认测试集
#是否下载,下载前方法会检查目录下是否已经,不会重复下载
#通过torchvison.transforms对数据做变换
#开代理会很快
"""
2.torch.utils.data.dataloader
train_loader = torch.utils.data.DataLoader(train_set, batch_size = 1000,shuffle = True,num_workers = 0)
#num_workers默认为0表示用主进程来装载数据
# 注意:Pytorch multiprocessing does not work on Windows!!!因此在windows系统上此处必须设置为0,即只有一个进程存在。
#数据成可训练的pipeline最后只需要给三个入参:数据集,批次,是否打乱。
#batch_size: data size of per batch
"""
batch_size小对训练的影响:
1.很小比如为1,训练震荡严重,不易收敛
2.增大,下山路线,开始变正确
3.继续增大,已经足够准确,不再变化
4.但是随着batch_size增大,相同的epoch次数下,迭代次数变小,因此需要注意在增大batch_size的同时,增加epoch。不能增大了batch导致迭代次数明显减少,会导致最优化效果变差。
"""
3.利用dataset和dataloader探索数据
train_set.targets#每个的标签
train_set.targets.bincount()
#统计每个标签类别的数目
#用途:检查是否有严重的类别不平衡问题,class imbalanace
sample = next(iter(train_set))
#iter和next都是python内置函数(既可以用在train_set,也可以用在封装好的dataloader)
# iter返回迭代器对象, next获取一个迭代器元素
len(sample)
print(type(sample))
#从torchvision获取的数据集每个sample样式:(tentor(img),tensor(label))
image.shape
# torch.Size([1, 28, 28])
plt.imshow(image.squeeze(), cmap="gray") # 因为plt.imshow对于单通道显示格式是 H W
display_loader = torch.utils.data.DataLoader(train_set, batch_size=10)
images, labels = next(iter(display_loader) # len(next(iter(...))) = 2
print('types:', type(images), type(labels))
# types: <class 'torch.Tensor'> <class 'torch.Tensor'>
print('shapes:', images.shape, labels.shape)
# shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])
#两种可视化数据集的方式
#1.torchvision.utils.make_grid()
grid = torchvision.utils.make_grid(images, nrow=10)
plt.figure(figsize=(15,15))
plt.imshow(np.transpose(grid, (1,2,0))) #plt.imshow(grid.permute(1,2,0))
#2.利用DataLoader显示
how_many_to_plot = 20
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)
plt.figure(figsize=(50,50))
for i, img_batch in enumerate(train_loader, start=1):
#enumerate(sequence,start=1):枚举出的就是(i,img_batch)形状,如(1,a),(2,b), (3,c)
image, label = img_batch
plt.subplot(10,10,i)
plt.imshow(image.reshape(28,28), cmap='gray')
plt.axis('off')
plt.title(train_set.classes[label.item()], fontsize=28)
if (i >= how_many_to_plot): break
plt.show()