自己动手读取MNIST数据集并存入四个np.array

从下载http://yann.lecun.com/exdb/mnist/四个.gz压缩包:
2020-01-10 19-56-43 的屏幕截图
他们分别是训练用数据、训练用标签、测试用数据、测试用标签。

然后将他们放入一个名为dataPath的文件夹中,我放入的是/home/zzz/intern/data:
2020-01-10 19-58-51 的屏幕截图

然后是读取数据的代码,readData()函数返回的就是四个np.array

import gzip
import numpy as np
def read_idx3(filename):
    with gzip.open(filename, 'rb') as fo:
        buf = fo.read()
        index = 0
        header = np.frombuffer(buf, '>i', 4, index)
        index += header.size * header.itemsize
        data = np.frombuffer(buf, '>B', header[1]*header[2]*header[3], index).reshape(header[1],-1)
        return data

def read_idx1(filename):
    with gzip.open(filename, 'rb') as fo:
        buf = fo.read()
        index = 0
        header = np.frombuffer(buf, '>i', 2, index)
        index += header.size * header.itemsize
        data = np.frombuffer(buf, '>B', header[1], index)
        return data

def readData(dataPath):
    X_train = read_idx3(dataPath + '/train-images-idx3-ubyte.gz')  # 训练数据集的样本特征
    y_train = read_idx1(dataPath + '/train-labels-idx1-ubyte.gz')  # 训练数据集的标签
    X_test = read_idx3(dataPath + '/t10k-images-idx3-ubyte.gz')  # 测试数据集的样本特征
    y_test = read_idx1(dataPath + '/t10k-labels-idx1-ubyte.gz')  # 测试数据集的标签
    return X_train, y_train, X_test, y_test

可以输出一下他们的维度:

if __name__=="__main__":
    dataPath = "/home/zzz/intern/data"
    X_train, y_train, X_test, y_test = readData(dataPath)
    print(X_train.shape, y_train.shape)
    print(X_test.shape, y_test.shape)

如果结果如下图所示即为正确:
2020-01-10 20-01-20 的屏幕截图

posted @ 2020-01-10 20:02  明卿册  阅读(590)  评论(0编辑  收藏  举报