TensorFlow入门文章

文档说明

1. 介绍

TensorFlow 是一个开源的机器学习框架,由谷歌开发和维护。它广泛应用于各类机器学习任务,包括但不限于图像分类、自然语言处理和时间序列预测。本文将介绍如何使用TensorFlow创建一个简单的神经网络进行图像分类。

2. 安装

在开始之前,请确保你已经安装了TensorFlow。你可以通过以下命令安装:

pip install tensorflow

3. 数据准备

我们将使用MNIST数据集,这是一个包含手写数字的经典数据集。以下是加载数据集的代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist

# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

4. 模型构建

构建一个简单的神经网络模型:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

5. 编译和训练模型

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

6. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'\nTest accuracy: {test_acc}')

代码测试用例

为了验证代码的正确性,可以编写一些简单的测试用例:

def test_model_accuracy():
    # 确保测试集的准确率不低于 90%
    assert test_acc >= 0.90, f"Test accuracy too low: {test_acc}"

def test_model_output():
    # 确保模型输出的shape正确
    predictions = model.predict(x_test)
    assert predictions.shape == (10000, 10), f"Prediction shape incorrect: {predictions.shape}"

test_model_accuracy()
test_model_output()

UML图

posted @   mychat  阅读(16)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)
点击右上角即可分享
微信分享提示