# 第四讲 网络八股拓展——用数据增强训练 MNIST 数据集
# 注意：训练时间较长，建议在 Google Colab 上运行。
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load MNIST and scale pixel values from [0, 255] to [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# ImageDataGenerator.flow() requires rank-4 input (samples, height, width,
# channels), so add an explicit single-channel axis. Reshape the test set the
# same way so training and validation inputs share one shape; otherwise the
# model is built for (None, 28, 28, 1) during training and then fed
# (None, 28, 28) during validation.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# Random, on-the-fly augmentation applied to every training batch.
# NOTE: ImageDataGenerator is marked deprecated in recent TensorFlow releases
# (Keras preprocessing layers are the recommended replacement); it is kept
# here to match the course template.
image_gen_train = ImageDataGenerator(
    rescale=1 / 1.,          # no-op: pixels were already divided by 255 above
    rotation_range=45,       # random rotation within +/-45 degrees
    width_shift_range=.15,   # random horizontal shift, as a fraction of width
    height_shift_range=.15,  # random vertical shift, as a fraction of height
    horizontal_flip=False,   # mirrored digits are not valid digits
    zoom_range=0.5,          # random zoom factor drawn from [0.5, 1.5]
)

# fit() only computes statistics for featurewise_center,
# featurewise_std_normalization or zca_whitening; none are enabled above, so
# this call does not change the generated batches. Kept to mirror the template.
image_gen_train.fit(x_train)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# The output layer already applies softmax, so the loss receives
# probabilities rather than logits (from_logits=False).
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Train on augmented batches; validate on the un-augmented test set
# every 2nd epoch (validation_freq=2).
model.fit(image_gen_train.flow(x_train, y_train, batch_size=32), epochs=5,
          validation_data=(x_test, y_test), validation_freq=2)

# No input shape was declared, so the model is only built on the first batch;
# summary() therefore has to come after fit().
model.summary()