Fashion MNIST的下载与导入
在动手写深度学习的TensorFlow实现版本中,需要用到数据集Fashion MNIST,如果直接用TensorFlow导入数据集:
from tensorflow.keras.datasets import fashion_mnist (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
就会报错,下载数据集时会显示服务器连接超时,可能因为服务器在国内被墙了。
下面是如何手动下载数据集并导入的步骤:
1.下载数据集
去GitHub上该数据集的主页下载:https://github.com/zalandoresearch/fashion-mnist
下载完成后解压放在./data/fashion/文件夹下
接下导入数据集:
import mnist_reader x_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train') x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')
注意这里面的mnist_reader是GitHub上该项目里面的一个文件,不要以为是某个库
可以直接clone整个项目,再把这个文件放在和上文data相同的文件夹下
不想下载这个项目呢,这里给出这个文件的具体代码,在导入数据集时把这个函数加入到你的代码中也可以:
def load_mnist(path, kind='train'): import os import gzip import numpy as np """Load MNIST data from `path`""" labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind) images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind) with gzip.open(labels_path, 'rb') as lbpath: labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) with gzip.open(images_path, 'rb') as imgpath: images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(len(labels), 784) return images, labels
最后可以测试一下是否导入成功:
最后如果你还是导入不成功,或者GitHub上数据集你就是下载不下来,可以私信我。
转载请注明出处