TensorFlow入门文章
文档说明
1. 介绍
TensorFlow 是一个开源的机器学习框架,由谷歌开发和维护。它广泛应用于各类机器学习任务,包括但不限于图像分类、自然语言处理和时间序列预测。本文将介绍如何使用TensorFlow创建一个简单的神经网络进行图像分类。
2. 安装
在开始之前,请确保你已经安装了TensorFlow。你可以通过以下命令安装:
pip install tensorflow
3. 数据准备
我们将使用MNIST数据集,这是一个包含手写数字的经典数据集。以下是加载数据集的代码:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 归一化
x_train, x_test = x_train / 255.0, x_test / 255.0
4. 模型构建
构建一个简单的神经网络模型:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
5. 编译和训练模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
6. 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'\nTest accuracy: {test_acc}')
代码测试用例
为了验证代码的正确性,可以编写一些简单的测试用例:
def test_model_accuracy():
# 确保测试集的准确率不低于 90%
assert test_acc >= 0.90, f"Test accuracy too low: {test_acc}"
def test_model_output():
# 确保模型输出的shape正确
predictions = model.predict(x_test)
assert predictions.shape == (10000, 10), f"Prediction shape incorrect: {predictions.shape}"
test_model_accuracy()
test_model_output()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)