Python 加载mnist、cifar数据
import tensorflow.examples.tutorials.mnist.input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
1、加载mnist数据
执行完成后,会在当前目录下新建一个文件夹MNIST_data, 下载的数据将放入这个文件夹内。下载的四个文件为:
下载下来的数据集被分三个子集:5.5W行的训练数据集(mnist.train
),5千行的验证数据集(mnist.validation)和1W行的测试数据集(mnist.test
)。因为每张图片为28x28的黑白图片,所以每行为784维的向量。
print (mnist.train.images.shape) print (mnist.train.labels.shape) print (mnist.validation.images.shape) print (mnist.validation.labels.shape) print (mnist.test.images.shape) print (mnist.test.labels.shape)
(55000, 784)
(55000, 10)
(5000, 784)
(5000, 10)
(10000, 784)
(10000, 10)
在训练过程中可以按批次获取
from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集 mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集 X_mb, _ = mnist.train.next_batch(128) print(X_mb.shape)
Extracting ../../MNIST_data\train-images-idx3-ubyte.gz
Extracting ../../MNIST_data\train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data\t10k-labels-idx1-ubyte.gz
(128, 784)
2、加载cifar数据
import torch import torchvision.datasets as dsets import torchvision.transforms as transforms transform = transforms.Compose([ transforms.Resize(32), transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) def load_data_CIFAR10(): train_dataset = dsets.CIFAR10(root='./data/', train=True,download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) return train_loader train_loader = load_data_CIFAR10()
Using downloaded and verified file: ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data/
cifar-10 训练集和测试集分别有50000和10000张图片,RGB3通道,尺寸32×32,
一个样本由3037个字节组成,其中第一个字节是label,剩余3036(32*32*3)个字节是image,每个文件由连续的10000个样本组成,打开文件,发现是一堆二进制数据
https://www.cnblogs.com/denny402/p/5852689.html