Tensorflow2.0笔记22——回顾

Tensorflow2.0笔记

本博客为Tensorflow2.0学习笔记,感谢北京大学微电子学院曹建老师

1.1 tf.keras 搭建神经网络八股——六步法

  1. import——导入所需的各种库和包

  2. x_train, y_train——导入数据集、自制数据集、数据增强

3)model=tf.keras.models.Sequential /class MyModel(Model) model=MyModel——定义模型

4)model.compile——配置模型

  1. model.fit——训练模型、断点续训

  2. model.summary——参数提取、acc/loss 可视化、前向推理实现应用

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

posted @ 2021-02-03 21:01  Mr_WildFire  阅读(44)  评论(0编辑  收藏  举报