TensorFlow卷积神经网络识别10-monkey-species
In [1]:
from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
In [2]:
# 文件下载地址 https://www.kaggle.com/datasets/slothkong/10-monkey-species
train_dir = './10-monkey-species/training/training'
valid_dir = './10-monkey-species/validation/validation'
label_file = './10-monkey-species/monkey_labels.txt'
In [3]:
df = pd.read_csv(label_file, header=0)
df
Label | Latin Name | Common Name | Train Images | Validation Images | |
---|---|---|---|---|---|
0 | n0 | alouatta_palliata\t | mantled_howler | 131 | 26 |
1 | n1 | erythrocebus_patas\t | patas_monkey | 139 | 28 |
2 | n2 | cacajao_calvus\t | bald_uakari | 137 | 27 |
3 | n3 | macaca_fuscata\t | japanese_macaque | 152 | 30 |
4 | n4 | cebuella_pygmea\t | pygmy_marmoset | 131 | 26 |
5 | n5 | cebus_capucinus\t | white_headed_capuchin | 141 | 28 |
6 | n6 | mico_argentatus\t | silvery_marmoset | 132 | 26 |
7 | n7 | saimiri_sciureus\t | common_squirrel_monkey | 142 | 28 |
8 | n8 | aotus_nigriceps\t | black_headed_night_monkey | 133 | 27 |
9 | n9 | trachypithecus_johnii | nilgiri_langur | 132 | 26 |
In [17]:
# 图片数据生成器
height = 128
width = 128
channels = 3
batch_size = 32
num_classes = 10
train_datagen = keras.preprocessing.image.ImageDataGenerator(
rescale = 1. / 255, # 归一化&浮点数
rotation_range = 40, # 随机旋转 0~40°之间
width_shift_range = 0.2, # 随机水平移动
height_shift_range = 0.2, # 随机垂直移动
shear_range = 0.2, # 随机裁剪比例
zoom_range = 0.2, # 随机缩放比例
horizontal_flip = True, # 随机水平翻转
vertical_flip = True, # 随机垂直翻转
fill_mode = 'nearest', # 填充模式
)
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(height, width),
batch_size=batch_size, shuffle=True, class_mode='categorical')
valid_datagen = keras.preprocessing.image.ImageDataGenerator(
rescale = 1. / 255, # 归一化&浮点数
)
valid_generator = valid_datagen.flow_from_directory(valid_dir, target_size=(height, width),
batch_size=batch_size, shuffle=False, class_mode='categorical')
Found 1098 images belonging to 10 classes. Found 272 images belonging to 10 classes.
In [5]:
x, y = train_generator.next() # 可以看出 y是经过one hot编码的
print(x.shape, y.shape)
(32, 128, 128, 3) (32, 10)
In [6]:
model = tf.keras.models.Sequential()
# 卷积 # input_shape 输入参数为 (height, width, channels)
model.add(tf.keras.layers.Conv2D(filters = 32, kernel_size = 3, padding = 'same', activation = 'selu', input_shape = (height, width, 3)))
model.add(tf.keras.layers.Conv2D(filters = 32, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.MaxPool2D()) # 池化
model.add(tf.keras.layers.Conv2D(filters = 64, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.Conv2D(filters = 64, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.MaxPool2D()) # 池化
model.add(tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.MaxPool2D()) # 池化
model.add(tf.keras.layers.Conv2D(filters = 256, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.Conv2D(filters = 256, kernel_size = 3, padding = 'same', activation = 'selu')) # 卷积
model.add(tf.keras.layers.MaxPool2D()) # 池化
# 输入输出shape: 具体而言,是将一个维度大于或等于3的高维矩阵,“压扁”为一个二维矩阵。即保留第一个维度(如:batch的个数),
# 然后将剩下维度的值相乘为“压扁”矩阵的第二个维度。如输入是(None, 32,32,3),则输出是(None, 3072)
model.add(tf.keras.layers.Flatten()) # Flatten层用来将输入“压平”,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。
model.add(tf.keras.layers.Dense(2048, activation='selu'))
model.add(tf.keras.layers.AlphaDropout(0.3))
model.add(tf.keras.layers.Dense(1024, activation='selu'))
model.add(tf.keras.layers.AlphaDropout(0.25))
model.add(tf.keras.layers.Dense(512, activation='selu'))
model.add(tf.keras.layers.Dense(128, activation='selu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
# 配置网络
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
In [7]:
# 训练
train_num = train_generator.samples
valid_num = valid_generator.samples
history = model.fit(train_generator, steps_per_epoch=train_num//batch_size, epochs=10,
validation_data=valid_generator, validation_steps=valid_num//batch_size)
Epoch 1/10 34/34 [==============================] - 42s 1s/step - loss: 18.0979 - acc: 0.1285 - val_loss: 4.8664 - val_acc: 0.1562 Epoch 2/10 34/34 [==============================] - 33s 957ms/step - loss: 2.6509 - acc: 0.1754 - val_loss: 2.7836 - val_acc: 0.2188 Epoch 3/10 34/34 [==============================] - 31s 921ms/step - loss: 2.2543 - acc: 0.2073 - val_loss: 3.4844 - val_acc: 0.1172 Epoch 4/10 34/34 [==============================] - 31s 900ms/step - loss: 2.0929 - acc: 0.2514 - val_loss: 2.6293 - val_acc: 0.2617 Epoch 5/10 34/34 [==============================] - 31s 917ms/step - loss: 2.0205 - acc: 0.2711 - val_loss: 2.4111 - val_acc: 0.2461 Epoch 6/10 34/34 [==============================] - 31s 919ms/step - loss: 1.9451 - acc: 0.3002 - val_loss: 2.5277 - val_acc: 0.2852 Epoch 7/10 34/34 [==============================] - 32s 924ms/step - loss: 1.8503 - acc: 0.3283 - val_loss: 2.5161 - val_acc: 0.3125 Epoch 8/10 34/34 [==============================] - 30s 882ms/step - loss: 1.8364 - acc: 0.3386 - val_loss: 2.5492 - val_acc: 0.2656 Epoch 9/10 34/34 [==============================] - 30s 867ms/step - loss: 1.7762 - acc: 0.3940 - val_loss: 1.8829 - val_acc: 0.3750 Epoch 10/10 34/34 [==============================] - 30s 873ms/step - loss: 1.7863 - acc: 0.3687 - val_loss: 2.1468 - val_acc: 0.3359
In [16]:
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid()
plt.gca().set_ylim(0, 10)
plt.show()

In [13]:
model.evaluate(valid_generator) # 可以看出效果很差,还是需要用ResNet50这样的复杂网络才可以
9/9 [==============================] - 3s 297ms/step - loss: 2.0539 - acc: 0.3640
Out[13]:
[2.053945541381836, 0.3639705777168274]
分类:
计算机视觉
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?