使用八股搭建手写数据集神经网络
写在前面
今天是初五,好好的玩了几天后还是回归到了学习的正轨上。今天主要学习了神经网络的搭建八股,使用这种模型搭建了一个训练手写数据集的神经网络
搭建网络八股
六步法:
import
train,test
model = tf.keras.models.Sequential
model.compile
model.fit
model.summary
总的来说,首先导包,然后指定出训练集和测试集。使用tensorflow提供的API搭建好每层神经网络结构,进行compile,指定优化器损失函数和衡量标准。使用fit函数来训练神经网络,最后使用summary来输出训练结果。
训练手写数据集
先来看代码:
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
代码不长,我们是严格按照六步法来搭建神经网络,可以看到十分简单。核心部分就是指定神经网络结构。
总结
总的来说,使用这种方法搭建神经网络还是十分简单的,但其中的原理一定要好好理解清楚。