Python 加载mnist、cifar数据


import
tensorflow.examples.tutorials.mnist.input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

1、加载mnist数据

执行完成后,会在当前目录下新建一个文件夹MNIST_data, 下载的数据将放入这个文件夹内。下载的四个文件为:

下载下来的数据集被分三个子集:5.5W行的训练数据集(mnist.train),5千行的验证数据集(mnist.validation)和1W行的测试数据集(mnist.test)。因为每张图片为28x28的黑白图片,所以每行为784维的向量。

print (mnist.train.images.shape)
print (mnist.train.labels.shape)
print (mnist.validation.images.shape)
print (mnist.validation.labels.shape)
print (mnist.test.images.shape)
print (mnist.test.labels.shape)

(55000, 784)
(55000, 10)
(5000, 784)
(5000, 10)
(10000, 784)


(10000, 10)

在训练过程中可以按批次获取

from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集
X_mb, _ = mnist.train.next_batch(128)
print(X_mb.shape)

Extracting ../../MNIST_data\train-images-idx3-ubyte.gz
Extracting ../../MNIST_data\train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data\t10k-labels-idx1-ubyte.gz
(128, 784)

 2、加载cifar数据

import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
def load_data_CIFAR10():
    train_dataset = dsets.CIFAR10(root='./data/', train=True,download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    return train_loader
train_loader = load_data_CIFAR10()

Using downloaded and verified file: ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data/

cifar-10  训练集和测试集分别有50000和10000张图片,RGB3通道,尺寸32×32, 

一个样本由3037个字节组成,其中第一个字节是label,剩余3036(32*32*3)个字节是image,每个文件由连续的10000个样本组成,打开文件,发现是一堆二进制数据

 

 

https://www.cnblogs.com/denny402/p/5852689.html

posted @ 2020-02-23 14:31  小娜子成长记  阅读(655)  评论(0编辑  收藏  举报