TensorFlow数据集项目

TensorFlow数据集(TFDS)项目使下载通用数据集变得非常容易,从小型数据集(如MNIST或Fashion MNIST)到大型数据集(如ImageNet)。该列表包括图像数据集、文本数据集(包括翻译数据集)以及音频和视频数据集。

TFDS没有和TensorFlow捆绑在一起,因此需要安装tensorflow_datasets库。然后调用tfds.load()函数,它会下载想要的数据,并将该数据作为数据集的目录返回(通常一个用于训练,另一个用于测试,但这取决于选择的数据集)例如,下载MNIST:

import tensorflow_datasets as tfds
dataset=tfds.load(name='mnist')
mnist_train,mnist_test=dataset['train'],dataset['test']
mnist_train=mnist_train.shuffle(10000).batch(32).prefetch(1)
for item in mnist_train:
    images=item['image']
    labels=item['label']

load()函数对下载的每个数据碎片进行乱序(仅针对训练集)

训练集中的项目都是包含特征和标签的字典。但是Keras希望每个项目都是一个包含两个元素(同样是特征和标签)的元组。可以使用map()方法转换数据集:

mnist_train=mnist_train.shuffle(10000).batch(32)
mnist_train=mnist_train.map(lambda item:(item['image'],item['label']))
mnist_train=mnist_train.prefetch(1)

通过设置as_supervied=True来使load()函数执行此操作会更简单。也可以根据需要指定批处理大小。然后直接将数据传递给tf.keras模型

dataset=tfds.load(name='mnist',batch_size=32,as_surpervised=True)
mnist_train=dataset['train'].prefetch(1)
model=keras.models.Sequential([...])
model.compile(loss='sparse_categorical_crossentropy',optimizer='sgd')
model.fit(mnist_train,epochs=5)
posted @ 2021-10-31 15:17  里列昂遗失的记事本  阅读(107)  评论(0编辑  收藏  举报