第四讲 网络八股拓展——用数据增强后的 MNIST 数据集训练

# 注意：训练时间较长，建议在 Google Colab 上运行。

 

 1 import tensorflow as tf
 2 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 3 
 4 mnist = tf.keras.datasets.mnist
 5 (x_train, y_train), (x_test, y_test) = mnist.load_data()
 6 x_train, x_test = x_train/255.0, x_test/255.0
 7 x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
 8 
 9 
10 image_gen_train = ImageDataGenerator(
11     rescale = 1/ 1.,
12     rotation_range = 45,
13     width_shift_range = .15, 
14     height_shift_range = .15,
15     horizontal_flip = False,
16     zoom_range = 0.5
17 )
18 
19 image_gen_train.fit(x_train)
20 
21 
22 model = tf.keras.models.Sequential([
23       tf.keras.layers.Flatten(),
24       tf.keras.layers.Dense(128, activation='relu'),
25       tf.keras.layers.Dense(10, activation='softmax')
26 ])
27 
28 model.compile(optimizer='adam',
29               loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
30               metrics = ['sparse_categorical_accuracy'])
31 
32 model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5, validation_data=(x_test, y_test), validation_freq=2)
33 model.summary()

 

posted @ 2020-05-06 21:19  WWBlog  阅读(403)  评论(0)  编辑  收藏  举报