MNIST 数据集
机器学习的入门就是MNIST。
MNIST 数据集来自美国国家标准与技术研究所,是NIST(National Institute of Standards and Technology)的缩小版,训练集(training set)由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局(the Census Bureau)的工作人员,测试集(test set)也是同样比例的手写数字数据。
获取MNIST
MNIST 数据集可在http://yann.lecun.com/exdb/mnist/获取,图片是以字节的形式进行存储,它包含了四个部分:
Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
此数据集中,训练样本:共60000个,其中55000个用于训练,另外5000个用于验证。测试样本:共10000个,验证数据比例相同。
from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())
数据加载
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
dataloader = DataLoader(dataset=mnist_train, batch_size=2, shuffle=True, num_workers=2)
for (images, labels) in dataloader:
print(labels)
image = make_grid(images).permute(1, 2, 0).numpy()
plt.imshow(image)
plt.show()
exit()
其中参数含义:
- dataset:提前定义的dataset的实例
- batch_size:传入数据的batch的大小,常用128,256等等
- shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
num_workers
:加载数据的线程数
transforms
由于 DataLoader 这个加载器只能加载 tensors, numpy arrays, numbers, dicts or lists
但是 found <class 'PIL.Image.Image'>,所以就很尴尬,我们需要将图片转换一下
transforms 用于图形变换,在使用时我们还可以使用 transforms.Compose
将一系列的transforms操作链接起来。
torchvision.transforms.Compose([ ts,ts,ts... ])
ts为transforms操作
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
大多数情况下我们不会只transforms 一下,所以可以用如下方案
from torchvision import transforms
transforms.Compose(
[ #文档 https://pytorch.org/vision/stable/transforms.html
transforms.ToPILImage(), # 转成PIL图片
# transforms.Resize(size), # 缩放
transforms.ToTensor(), # 变张量
transforms.Normalize(mean=(0.1307, ), std=(0.3081, )) ]
)
介绍一个概念:
transforms 处理过后,会把通道移到最前边。比如 MNIST h*w*c
为:28281
tensor处理完,通道数会提前,并且做了轴交换,变为了 c*h*w
为:12828
至于为什么要这么设计?听传言是做矩阵加减乘除以及卷积等运算是需要调用cuda和cudnn的函数的,而这些接口都设成成 chw 格式了