机器学习-利用pickle加载cifar文件

 

首先这里有百度云的数据集供大家下载:(官网太慢了)

链接:https://pan.baidu.com/s/1G0MxZIGSK_DyZTcuNbxraQ
提取码:ui51
复制这段内容后打开百度网盘手机App,操作更方便哦

然后奉献代码

def load_CIFAR10(ROOT):
    """ 载入cifar全部数据 """
    xs = []
    ys = []
    for b in range(1, 2):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)         #将所有batch整合起来
        ys.append(Y)
    Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

找到cifar文件夹下面的二进制文件:

然后对每次的文件进行批处理:

def load_CIFAR_batch(filename):
    """ 直接读入cifar数据集的一个batch """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y

测试:

import numpy as np

# 载入CIFAR-10数据集
cifar10_dir = 'data\cifar10\cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

# 看看数据集中的一些样本:每个类别展示一些
print('训练数据的形状: ', X_train.shape)
print('训练集标签的形状: ', y_train.shape)
print('测试数据的形状: ', X_test.shape)
print('测试数据的形状: ', y_test.shape)
import pickle as p
import os


def load_CIFAR_batch(filename):
    """ 载入cifar数据集的一个batch """
    with open(filename, 'rb') as f:
        datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y


def load_CIFAR10(ROOT):
    """ 载入cifar全部数据 """
    xs = []
    ys = []
    for b in range(1, 2):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)         #将所有batch整合起来
        ys.append(Y)
    Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

if __name__ == '__main__':
    import numpy as np


    # 载入CIFAR-10数据集
    cifar10_dir = 'data\cifar10\cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # 看看数据集中的一些样本:每个类别展示一些
    print('Training data shape: ', X_train.shape)
    print('Training labels shape: ', y_train.shape)
    print('Test data shape: ', X_test.shape)
    print('Test labels shape: ', y_test.shape)

 

posted @ 2019-07-22 15:33  Timcode  阅读(638)  评论(0编辑  收藏  举报