Tensorflow05-简单的全连接神经网络案例

第一个神经元网络就使用最简单的全连接神经网络。

 

使用tensorflow里的 fashion_mnist 服饰数据集 来完成此次的入门案例,建议使用 jupyter 分步执行,每步都理解掌握。

数据集介绍:大概60000张图片,分成了衣服帽子鞋子等等10个类别。每张图片是由 28*28 个像素组成的,每个像素取值 0 ~ 255。

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt


# 加载数据集
fashion_mnist = keras.datasets.fashion_mnist
# 得到训练/测试 数据,训练/测试 标签
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# 查看数据形状
train_images.shape, train_labels.shape

plt.imshow(test_images[0])   # 画图用  imshow !

# 创建神经元模型
model = keras.Sequential()
# 第一层使用Flatten
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))

# 查看神经网络结构
model.summary()

#配置训练方法,optimizer(优化器)为 经常使用的Adam,损失函数使用sparse_categorical_crossentropy,注意还有不带sparse的,则表示数据为 独热编码形式的。
model.compile(optimizer=tf.optimizers.Adam(), loss=tf.losses.sparse_categorical_crossentropy, metrics=['accuracy'])

# 为防止过拟合定义该类
class myCallback(tf.keras.callbacks.Callback): # 继承自 Callback
    def on_epoch_end(self, epoch, logs={}): # 重写该方法
        if(logs.get('loss') < 0.4): # 如果 loss < 0.4, 认为发生过拟合
            print("\nLoss is low so cancelling training")
            self.model.stop_training = True # 停止训练

callbacks = myCallback()

# 归一化
train_images = train_images/255
test_images_scaled = test_images/255

# 训练数据得到 history 对象,最后一个参数表示自动中止训练,类的定义在上方
history = model.fit(train_images, train_labels, epochs=5, callbacks=[callbacks])

# 利用 测试数据/测试标签 评估模型
model.evaluate(test_images_scaled, test_labels)

# 预测数据,并提取第一个(0)的预测结果
model.predict(test_images_scaled)[0]

 对该案例代码中的一些解释:

首先这个数据集的每个元素是二维的,即这个数据集存放着若干张图片,每个图片是一个像素 28*28 的二维矩阵存储。

所以我们的模型第一层使用 Flatten,作用是将二维输入数据转换成一维的。也就是输入层。

Dense 表示全连接网络,至于参数 激活函数 activation 在上篇博客中有详细解释。

第二个 Dense 是输出层,一共有 10 个类别,所以输出的神经元个数为 10。这层也叫输出层。

介于输入输出层之间为 隐含层,这里的隐含层只有一个,也是 Dense,这里神经元数量128,可以自己更改,以得到更好的训练结果。

配置模型的编译 compile ,优化器为 Adam(),损失函数为 sparse_categorical_crossentropy

自定义的 Callback 的继承类,防止过拟合。

fit 训练数据

evaluate 利用测试集评估模型

predict 预测数据

posted @ 2021-02-06 18:06  大雪初晴丶  阅读(425)  评论(0编辑  收藏  举报