TensorFlow识别CiFar10物品分类
In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
In [2]:
(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train_all.shape
Out[2]:
(50000, 32, 32, 3)
In [3]:
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform( x_train.astype(np.float32).reshape(-1,1) ).reshape(-1, 32, 32, 3)
x_valid_scaled = scaler.transform( x_valid.astype(np.float32).reshape(-1,1) ).reshape(-1, 32, 32, 3)
x_test_scaled = scaler.transform( x_test.astype(np.float32).reshape(-1,1) ).reshape(-1, 32, 32, 3)
In [4]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))
model.add(tf.keras.layers.Dense(512, activation='selu'))
model.add(tf.keras.layers.AlphaDropout(0.2))
model.add(tf.keras.layers.Dense(256, activation='selu'))
model.add(tf.keras.layers.Dense(128, activation='selu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
# 配置网络
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
callbacks = [
tf.keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)
]
# 训练
history = model.fit(x_train_scaled, y_train, epochs=10, validation_data=(x_valid_scaled, y_valid), callbacks=callbacks)
Epoch 1/10 1407/1407 [==============================] - 13s 9ms/step - loss: 1.9129 - acc: 0.3619 - val_loss: 1.6277 - val_acc: 0.4324 Epoch 2/10 1407/1407 [==============================] - 13s 9ms/step - loss: 1.5629 - acc: 0.4414 - val_loss: 1.5839 - val_acc: 0.4422 Epoch 3/10 1407/1407 [==============================] - 13s 9ms/step - loss: 1.4790 - acc: 0.4725 - val_loss: 1.5413 - val_acc: 0.4626 Epoch 4/10 1407/1407 [==============================] - 13s 9ms/step - loss: 1.4196 - acc: 0.4940 - val_loss: 1.4940 - val_acc: 0.4888 Epoch 5/10 1407/1407 [==============================] - 13s 9ms/step - loss: 1.3624 - acc: 0.5118 - val_loss: 1.4644 - val_acc: 0.4958 Epoch 6/10 1407/1407 [==============================] - 13s 9ms/step - loss: 1.3151 - acc: 0.5298 - val_loss: 1.5182 - val_acc: 0.4866 Epoch 7/10 1407/1407 [==============================] - 14s 10ms/step - loss: 1.2662 - acc: 0.5484 - val_loss: 1.5000 - val_acc: 0.5164 Epoch 8/10 1407/1407 [==============================] - 14s 10ms/step - loss: 1.2238 - acc: 0.5640 - val_loss: 1.5809 - val_acc: 0.4986 Epoch 9/10 1407/1407 [==============================] - 14s 10ms/step - loss: 1.1800 - acc: 0.5755 - val_loss: 1.5581 - val_acc: 0.5086 Epoch 10/10 1407/1407 [==============================] - 14s 10ms/step - loss: 1.1452 - acc: 0.5930 - val_loss: 1.4877 - val_acc: 0.5300
In [5]:
pd.DataFrame(history.history)
loss | acc | val_loss | val_acc | |
---|---|---|---|---|
0 | 1.912864 | 0.361933 | 1.627748 | 0.4324 |
1 | 1.562863 | 0.441444 | 1.583859 | 0.4422 |
2 | 1.479033 | 0.472489 | 1.541267 | 0.4626 |
3 | 1.419567 | 0.494044 | 1.494043 | 0.4888 |
4 | 1.362388 | 0.511800 | 1.464406 | 0.4958 |
5 | 1.315090 | 0.529844 | 1.518175 | 0.4866 |
6 | 1.266191 | 0.548422 | 1.499979 | 0.5164 |
7 | 1.223794 | 0.564000 | 1.580877 | 0.4986 |
8 | 1.179963 | 0.575489 | 1.558058 | 0.5086 |
9 | 1.145220 | 0.593000 | 1.487734 | 0.5300 |
In [10]:
pd.DataFrame(history.history).plot(figsize=(8,4))
plt.grid()
plt.gca().set_ylim(0,2)
plt.show()

In [7]:
model.evaluate(x_test_scaled, y_test)
313/313 [==============================] - 1s 2ms/step - loss: 1.4554 - acc: 0.5235
Out[7]:
[1.4553571939468384, 0.5235000252723694]
In [8]:
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten (Flatten) (None, 3072) 0 dense (Dense) (None, 512) 1573376 alpha_dropout (AlphaDropout) (None, 512) 0 dense_1 (Dense) (None, 256) 131328 dense_2 (Dense) (None, 128) 32896 dense_3 (Dense) (None, 10) 1290 ================================================================= Total params: 1738890 (6.63 MB) Trainable params: 1738890 (6.63 MB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步