AI 识别图片(当前为初始版本,后续还会继续优化)
1. 数据搜集
数据爬取
使用爬虫,将百度图片搜索到的六类图片(冰淇淋、甜甜圈、米饭、披萨、汉堡包、蛋挞)爬取到本地。
2. 训练模型
2.1划分数据
将所有数据分为三个部分:
- 训练集:拿这部分的数据进行模型的训练,就是拿这里的图片给 AI 程序学习。
- 交叉验证集:拿这部分的数据进行模型训练结果的验证,这部分给出了当前模型的效果;如果效果不好,可以在之后更换模型或者算法。
- 测试集:交叉验证集评估出了不错的模型后,我们将用此模型在测试集上观察实际效果。
点击查看代码
import os, shutil
def mkdirs(path):
    """Create directory *path* (including parents) if it does not exist.

    Uses ``exist_ok=True`` instead of an exists-then-create check: the
    original check had a TOCTOU race (the directory could appear between
    the ``os.path.exists`` test and ``os.makedirs``).
    """
    os.makedirs(path, exist_ok=True)
def split_data(src_dir, dst_dir, train_prop, val_prop, test_prop):
    """Split a Keras-style image folder into train/val/test subsets.

    ``src_dir`` must contain one sub-directory per class (the layout that
    Keras' ``flow_from_directory`` consumes).  Each class's files are
    copied into ``dst_dir/{train,val,test}/<class>`` according to the
    given proportions.  ``train_prop + val_prop + test_prop`` should sum
    to 1.0; any rounding remainder falls into the test set.

    Fixes over the original: listings are sorted so the split is
    deterministic (``os.listdir`` order is arbitrary and
    filesystem-dependent), and the three duplicated copy loops are
    collapsed into one helper.
    """
    # Sort for a reproducible class order and file split.
    src_classes = sorted(os.listdir(src_dir))

    # Pre-create every dst_dir/<subset>/<class> directory.
    for subset in ('train', 'val', 'test'):
        for src_class in src_classes:
            os.makedirs(os.path.join(dst_dir, subset, src_class), exist_ok=True)

    def _copy_subset(subset, src_class, images):
        # Copy the given images of one class into one subset folder.
        print("Copying class:{} to {} set!".format(src_class, subset))
        for image in images:
            shutil.copyfile(os.path.join(src_dir, src_class, image),
                            os.path.join(dst_dir, subset, src_class, image))

    for src_class in src_classes:
        class_images = sorted(os.listdir(os.path.join(src_dir, src_class)))
        num = len(class_images)
        train_end = int(num * train_prop)
        val_end = int(num * (train_prop + val_prop))
        train_class_images = class_images[:train_end]
        val_class_images = class_images[train_end:val_end]
        test_class_images = class_images[val_end:]
        _copy_subset('train', src_class, train_class_images)
        _copy_subset('val', src_class, val_class_images)
        _copy_subset('test', src_class, test_class_images)
        print("train_num:{}, val_num:{}, test_num:{}".format(
            len(train_class_images), len(val_class_images), len(test_class_images)))
    print('done!')
if __name__ == '__main__':
    # 50% train / 25% val / 25% test split of the scraped food images.
    source_folder = "foodData"
    output_folder = "foodData_cnn_split"
    split_data(source_folder, output_folder, 0.5, 0.25, 0.25)
2.2 keras CNN 构建及训练
点击查看代码
from keras import layers
from keras import models
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow import optimizers
import tensorflow as tf
import os
# --- Model: a small from-scratch CNN ---------------------------------------
# Four Conv/MaxPool stages, then Flatten -> Dropout -> Dense head ending in a
# 6-way softmax (one unit per food class).
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu',input_shape=(224, 224, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.1))  # light regularization before the dense head
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(6, activation='softmax'))
# categorical_crossentropy matches class_mode='categorical' (one-hot labels).
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# --- Data preprocessing with data augmentation -----------------------------
train_dir = './foodData_cnn_split/train'
validation_dir = './foodData_cnn_split/val'
# Augmentation is applied to the training set only; validation images are
# just rescaled to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=30,  # 30 images are drawn (and augmented) per batch
    # save_to_dir='./train_gen' would save the augmented images to disk
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=30,
    class_mode='categorical')
# Directory for optionally dumping augmented samples (see save_to_dir above).
if not os.path.exists('./train_gen'):
    os.mkdir('./train_gen')
print(train_generator.class_indices)  # print the class-name -> index mapping
# --- Training --------------------------------------------------------------
history = model.fit(
    train_generator,
    # Per the author: 1800 training images total (6 classes x 300), at 30 per
    # batch -> 60 steps cover one full pass.
    steps_per_epoch=300*6//30,
    epochs=3,
    validation_data=validation_generator,
    # Presumably 150 validation images per class (6 x 150 / 30 = 30 steps) —
    # TODO confirm against the actual split sizes.
    validation_steps=150*6//30)
# --- Save the trained model ------------------------------------------------
model.save('six_class_cnn.h5')
# --- Plot training/validation accuracy and loss curves ---------------------
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
训练 80 个 epoch 后模型的准确率也不是很高,需要继续增加训练轮数或者改进算法。
3. 图片识别
最终结果
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· 实操Deepseek接入个人知识库
· CSnakes vs Python.NET:高效嵌入与灵活互通的跨语言方案对比
· Plotly.NET 一个为 .NET 打造的强大开源交互式图表库
· 【.NET】调用本地 Deepseek 模型