第三讲 神经网络八股--mnist数据集分类
"""MNIST handwritten-digit classification with tf.keras.

The script has three parts, originally written as three separate snippets:

1. Inspect the raw dataset: show the first training image, print its pixel
   values and label, and print the shape of every split.
2. Train a Flatten -> Dense(128, relu) -> Dense(10, softmax) classifier built
   with ``tf.keras.models.Sequential``.
3. Train the same architecture written as a ``tf.keras.Model`` subclass.

The dataset is loaded once and shared by all three parts, instead of being
loaded again for each snippet.
"""
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Flatten


def load_mnist():
    """Return ``(x_train, y_train), (x_test, y_test)`` from the Keras MNIST loader.

    Pixel values are returned unscaled, exactly as ``load_data()`` provides them.
    """
    return tf.keras.datasets.mnist.load_data()


def show_dataset_info(x_train, y_train, x_test, y_test):
    """Visualise the first training sample and print the dataset's shapes.

    Intended to be called on the raw (unscaled) arrays.
    """
    # Visualise the first element of the training features as a grayscale image.
    plt.imshow(x_train[0], cmap='gray')
    plt.show()

    # Print the first element of the training features and its label.
    print("x_train[0]:\n", x_train[0])
    print("y_train[0]:\n", y_train[0])

    # Print the shapes of the training/test features and labels.
    print("x_train.shape:\n", x_train.shape)
    print("y_train.shape:\n", y_train.shape)
    print("x_test.shape:\n", x_test.shape)
    print("y_test.shape:\n", y_test.shape)


def build_sequential_model():
    """Return a Flatten -> Dense(128, relu) -> Dense(10, softmax) Sequential model."""
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])


class MnistModel(Model):
    """The same Flatten -> Dense(128, relu) -> Dense(10, softmax) network as a Model subclass."""

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        """Forward pass: flatten the input, then apply the two dense layers."""
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)


def compile_and_fit(model, x_train, y_train, x_test, y_test, *, batch_size=32, epochs=5):
    """Compile ``model``, train it while validating on the test split, and print its summary.

    Args:
        model: an uncompiled Keras model whose last layer applies softmax.
        x_train, y_train: training features and integer class labels.
        x_test, y_test: validation features and integer class labels.
        batch_size: mini-batch size passed to ``fit``.
        epochs: number of training epochs.

    Returns:
        The same ``model`` object, now compiled and trained.
    """
    model.compile(
        optimizer='adam',
        # The last layer already applies softmax, so the loss receives
        # probabilities rather than logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        # "sparse" variants take integer class indices rather than one-hot labels.
        metrics=['sparse_categorical_accuracy'],
    )
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
              validation_data=(x_test, y_test), validation_freq=1)
    # Neither model declares an input shape, so its weights only exist after
    # fit() has seen data; summary() therefore has to come after training.
    model.summary()
    return model


def main():
    """Run the three parts in order: inspect data, train Sequential, train subclass."""
    (x_train, y_train), (x_test, y_test) = load_mnist()

    # Part 1 inspects the raw, unscaled pixels, so it runs before normalisation.
    show_dataset_info(x_train, y_train, x_test, y_test)

    # Scale pixel values by 1/255 for training.
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Part 2: Sequential API.
    compile_and_fit(build_sequential_model(), x_train, y_train, x_test, y_test)

    # Part 3: the same network via Model subclassing.
    compile_and_fit(MnistModel(), x_train, y_train, x_test, y_test)


if __name__ == "__main__":
    main()