TensorFlow入门文章

文档说明

1. 介绍

TensorFlow 是一个开源的机器学习框架,由谷歌开发和维护。它广泛应用于各类机器学习任务,包括但不限于图像分类、自然语言处理和时间序列预测。本文将介绍如何使用TensorFlow创建一个简单的神经网络进行图像分类。

2. 安装

在开始之前,请确保你已经安装了TensorFlow。你可以通过以下命令安装:

pip install tensorflow

3. 数据准备

我们将使用MNIST数据集,这是一个包含手写数字的经典数据集。以下是加载数据集的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist

# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

4. 模型构建

构建一个简单的神经网络模型:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

5. 编译和训练模型

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

6. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'\nTest accuracy: {test_acc}')

代码测试用例

为了验证代码的正确性,可以编写一些简单的测试用例:

def test_model_accuracy():
    # 确保测试集的准确率不低于 90%
    assert test_acc >= 0.90, f"Test accuracy too low: {test_acc}"

def test_model_output():
    # 确保模型输出的shape正确
    predictions = model.predict(x_test)
    assert predictions.shape == (10000, 10), f"Prediction shape incorrect: {predictions.shape}"

test_model_accuracy()
test_model_output()

UML图

posted @   mychat  阅读(17)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示