第四讲 网络八股拓展--图像增强
# Display the original images and the augmented images.
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Number of images shown in one row. The slices and the subplot grid below
# must agree, so the value is defined once here.
NUM_IMAGES = 12

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Add an explicit trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)

# Augmentation settings: rescale pixel values by 1/255, rotate, shift
# horizontally/vertically and zoom; horizontal flipping is disabled.
image_gen_train = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=45,
    width_shift_range=.15,
    height_shift_range=.15,
    horizontal_flip=False,
    zoom_range=0.5,
)

# NOTE(review): per the Keras docs, fit() computes dataset statistics that are
# only required for featurewise centering/normalization or ZCA whitening, none
# of which is enabled above. Kept to preserve the original flow -- confirm
# whether it is still needed.
image_gen_train.fit(x_train)

print("xtrain", x_train.shape)
# Drop only the trailing channel axis so each image is 2-D for imshow.
# Restricting squeeze to axis=-1 keeps the batch axis even if NUM_IMAGES is 1.
x_train_subset1 = np.squeeze(x_train[:NUM_IMAGES], axis=-1)
print("xtrain_subset1", x_train_subset1.shape)
print("xtrain", x_train.shape)
# Same images with the channel axis kept (show NUM_IMAGES images at a time).
x_train_subset2 = x_train[:NUM_IMAGES]
print("xtrain_subset2", x_train_subset2.shape)

fig = plt.figure(figsize=(20, 2))
plt.set_cmap('gray')
# Show the original images side by side in a single row.
for i, image in enumerate(x_train_subset1):
    ax = fig.add_subplot(1, NUM_IMAGES, i + 1)  # subplot indices are 1-based
    ax.imshow(image)
fig.suptitle('Subset of Original Training Images', fontsize=20)
plt.show()