Fashion MNIST的下载与导入

在动手写深度学习的TensorFlow实现版本中,需要用到数据集Fashion MNIST,如果直接用TensorFlow导入数据集:

from tensorflow.keras.datasets import fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

就会报错,下载数据集时会显示服务器连接超时,可能因为服务器在国内被墙了。

下面是如何手动下载数据集并导入的步骤:

1.下载数据集

去GitHub上该数据集的主页下载:https://github.com/zalandoresearch/fashion-mnist

 

 下载完成后解压放在./data/fashion/文件夹下

 

 接下导入数据集:

 

import mnist_reader

x_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train')
x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')

注意这里面的mnist_reader是GitHub上该项目里面的一个文件,不要以为是某个库

 

 可以直接clone整个项目,再把这个文件放在和上文data相同的文件夹下

 

 

 不想下载这个项目呢,这里给出这个文件的具体代码,在导入数据集时把这个函数加入到你的代码中也可以:

def load_mnist(path, kind='train'):
    import os
    import gzip
    import numpy as np

    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte.gz'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte.gz'
                               % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)

    return images, labels

最后可以测试一下是否导入成功:

 

 最后如果你还是导入不成功,或者GitHub上数据集你就是下载不下来,可以私信我。

posted @ 2020-11-20 16:08  荒唐了年少  阅读(7659)  评论(1编辑  收藏  举报