第四讲 网络八股拓展--图像增强

 1 # 显示原始图像和增强后的图像
 2 import tensorflow as tf
 3 from matplotlib import pyplot as plt
 4 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 5 import numpy as np
 6 
 7 
 8 
 9 mnist = tf.keras.datasets.mnist
10 (x_train, y_train), (x_test, y_test) = mnist.load_data()
11 x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
12 
13 
14 image_gen_train = ImageDataGenerator(
15     rescale = 1. / 255,
16     rotation_range = 45,
17     width_shift_range = .15,
18     height_shift_range = .15,
19     horizontal_flip = False,
20     zoom_range = 0.5
21 )
22 
23 image_gen_train.fit(x_train)
24 print("xtrain", x_train.shape)
25 x_train_subset1 = np.squeeze(x_train[:12])
26 print("xtrain_subset1", x_train_subset1.shape)
27 print("xtrain", x_train.shape)
28 x_train_subset2 = x_train[:12] # 一次显示12张图片
29 print("xtrain_subset2", x_train_subset2.shape)
30 
31 
32 
33 fig = plt.figure(figsize=(20,2))
34 plt.set_cmap('gray')
35 #显示原始图片
36 for i in range(0, len(x_train_subset1)):
37   ax = fig.add_subplot(1, 12, i+1)
38   ax.imshow(x_train_subset1[i])
39 fig.suptitle('Subset of Original Training Images', fontsize=20)
40 plt.show()

 

posted @ 2020-05-05 20:39  WWBlog  阅读(303)  评论(0)  编辑  收藏  举报