TensorFlow迁移学习Resnet50预测10-monkey-species
In [15]:
from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
from scipy import ndimage
import matplotlib.pyplot as plt
In [2]:
resnet50 = keras.applications.ResNet50(include_top=False, pooling='avg')
In [3]:
classes = 10
model = keras.models.Sequential()
model.add(resnet50)
model.add(keras.layers.Dense(classes, activation='softmax'))
In [4]:
model.layers
Out[4]:
[<keras.src.engine.functional.Functional at 0x25ed1b6fb80>, <keras.src.layers.core.dense.Dense at 0x25ed1b06220>]
In [5]:
# 把除最后一层之外的参数都冻结
model.layers[0].trainable = False
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= resnet50 (Functional) (None, 2048) 23587712 dense (Dense) (None, 10) 20490 ================================================================= Total params: 23608202 (90.06 MB) Trainable params: 20490 (80.04 KB) Non-trainable params: 23587712 (89.98 MB) _________________________________________________________________
In [6]:
# 文件下载地址 https://www.kaggle.com/datasets/slothkong/10-monkey-species
train_dir = './10-monkey-species/training/training'
valid_dir = './10-monkey-species/validation/validation'
label_file = './10-monkey-species/monkey_labels.txt'
df = pd.read_csv(label_file, header=0)
df
Label | Latin Name | Common Name | Train Images | Validation Images | |
---|---|---|---|---|---|
0 | n0 | alouatta_palliata\t | mantled_howler | 131 | 26 |
1 | n1 | erythrocebus_patas\t | patas_monkey | 139 | 28 |
2 | n2 | cacajao_calvus\t | bald_uakari | 137 | 27 |
3 | n3 | macaca_fuscata\t | japanese_macaque | 152 | 30 |
4 | n4 | cebuella_pygmea\t | pygmy_marmoset | 131 | 26 |
5 | n5 | cebus_capucinus\t | white_headed_capuchin | 141 | 28 |
6 | n6 | mico_argentatus\t | silvery_marmoset | 132 | 26 |
7 | n7 | saimiri_sciureus\t | common_squirrel_monkey | 142 | 28 |
8 | n8 | aotus_nigriceps\t | black_headed_night_monkey | 133 | 27 |
9 | n9 | trachypithecus_johnii | nilgiri_langur | 132 | 26 |
In [7]:
# 图片数据生成器
height = 224
width = 224
channels = 3
batch_size = 32
classes = 10
train_datagen = keras.preprocessing.image.ImageDataGenerator(
preprocessing_function = keras.applications.resnet50.preprocess_input, # 使用原版resnet50自身的数据处理
rotation_range = 40, # 随机旋转 0~40°之间
width_shift_range = 0.2, # 随机水平移动
height_shift_range = 0.2, # 随机垂直移动
shear_range = 0.2, # 随机裁剪比例
zoom_range = 0.2, # 随机缩放比例
horizontal_flip = True, # 随机水平翻转
vertical_flip = True, # 随机垂直翻转
fill_mode = 'nearest', # 填充模式
)
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(height, width),
batch_size=batch_size, shuffle=True, class_mode='categorical')
valid_datagen = keras.preprocessing.image.ImageDataGenerator(
preprocessing_function = keras.applications.resnet50.preprocess_input
)
valid_generator = valid_datagen.flow_from_directory(valid_dir, target_size=(height, width),
batch_size=batch_size, shuffle=False, class_mode='categorical')
Found 1098 images belonging to 10 classes. Found 272 images belonging to 10 classes.
In [8]:
# 训练
history = model.fit(train_generator, steps_per_epoch=train_generator.samples//batch_size, epochs=10,
validation_data=valid_generator, validation_steps=valid_generator.samples//batch_size)
Epoch 1/10 34/34 [==============================] - 41s 1s/step - loss: 1.4802 - acc: 0.5394 - val_loss: 0.4833 - val_acc: 0.9062 Epoch 2/10 34/34 [==============================] - 41s 1s/step - loss: 0.4419 - acc: 0.8856 - val_loss: 0.2153 - val_acc: 0.9531 Epoch 3/10 34/34 [==============================] - 41s 1s/step - loss: 0.2977 - acc: 0.9184 - val_loss: 0.1501 - val_acc: 0.9727 Epoch 4/10 34/34 [==============================] - 41s 1s/step - loss: 0.2131 - acc: 0.9531 - val_loss: 0.1265 - val_acc: 0.9805 Epoch 5/10 34/34 [==============================] - 42s 1s/step - loss: 0.1880 - acc: 0.9531 - val_loss: 0.0946 - val_acc: 0.9922 Epoch 6/10 34/34 [==============================] - 41s 1s/step - loss: 0.1616 - acc: 0.9662 - val_loss: 0.0941 - val_acc: 0.9883 Epoch 7/10 34/34 [==============================] - 41s 1s/step - loss: 0.1419 - acc: 0.9662 - val_loss: 0.1031 - val_acc: 0.9805 Epoch 8/10 34/34 [==============================] - 42s 1s/step - loss: 0.1120 - acc: 0.9709 - val_loss: 0.0717 - val_acc: 0.9922 Epoch 9/10 34/34 [==============================] - 39s 1s/step - loss: 0.1212 - acc: 0.9728 - val_loss: 0.0757 - val_acc: 0.9844 Epoch 10/10 34/34 [==============================] - 39s 1s/step - loss: 0.1045 - acc: 0.9700 - val_loss: 0.0707 - val_acc: 0.9844
In [14]:
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid()
plt.gca().set_ylim(0, 1.5)
plt.show()

In [43]:
model.evaluate(valid_generator) # 准确率非常高
9/9 [==============================] - 7s 721ms/step - loss: 0.0705 - acc: 0.9853
Out[43]:
[0.07051555812358856, 0.9852941036224365]
In [53]:
monkey = plt.imread('./10-monkey-species/n9010.jpg')
plt.imshow(monkey)
monkey.shape
Out[53]:
(311, 472, 3)

In [54]:
x = width / monkey.shape[0]
y = height / monkey.shape[1]
monkey_zoomed = ndimage.zoom(monkey, (x, y, 1))
monkey_zoomed.shape
Out[54]:
(224, 224, 3)
In [55]:
monkey_zoomed = keras.applications.resnet50.preprocess_input(monkey_zoomed)
monkey_zoomed = monkey_zoomed.reshape(1, width, height, 3) # 这里1就是每批次样本数
predict=model.predict(monkey_zoomed)predict
1/1 [==============================] - 0s 64ms/step
Out[55]:
array([[2.0728735e-03, 2.4864163e-02, 8.3858357e-04, 7.8938156e-03, 2.2254630e-04, 1.1891891e-03, 1.4768102e-03, 1.1359336e-03, 1.0765824e-03, 9.5922953e-01]], dtype=float32)
In [57]:
predict.argmax(axis=1) # 预测正确
Out[57]:
array([9], dtype=int64)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步