tensorflow学习014——tf.data运用实例

3.2tf.data运用实例

使用tf.data作为输入,改写之前写过的MNIST代码

点击查看代码
import tensorflow as tf
#下载数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
#对图片数据进行归一化
train_images = train_images / 255
test_images = test_images / 255

ds_train_images = tf.data.Dataset.from_tensor_slices(train_images)
ds_train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
#zip到一起,为了后面的shuffle,否则image与label的会对应错误
ds_train = tf.data.Dataset.zip((ds_train_images,ds_train_labels))

ds_train  = ds_train.shuffle(10000).repeat().batch(4)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28,28)),
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation= 'softmax')
])
model.compile(optimizer = 'adam',
              loss= 'sparse_categorical_crossentropy',
              metrics = ['accuracy'])
ds_test = tf.data.Dataset.from_tensor_slices((test_images,test_labels))
ds_test = ds_test.batch(4)
steps_per_epoch = train_images.shape[0] / 4 #表明每轮训练多少步,这是因为上面对dataser进行了repeat()所以需要指定每一轮训练多少步
model.fit(ds_train,epochs=10,steps_per_epoch=steps_per_epoch,validation_data=ds_test) 

posted @ 2021-11-20 15:29  白菜茄子  阅读(48)  评论(0编辑  收藏  举报