AI识别图片(为初始版本,还要继续优化)

1. 数据搜集

数据爬取

使用爬虫,将百度图片搜索到的六类图片(冰淇淋、甜甜圈、米饭、披萨、汉堡包、蛋挞)爬取到本地。
image

2. 训练模型

2.1划分数据

将所有数据分为三个部分:

  • 训练集:拿这部分的数据进行模型的训练,就是拿这里的图片给 AI程序学习。

  • 交叉验证集:拿这部分的数据进行模型训练结果的验证,这部分给出了当前模型的效果,如果效果不好,可以在之后更换模型或者算法!

  • 测试集:交叉验证集评估出了不错的模型,我们将用此模型在测试集上观察实际效果。

点击查看代码
import os, shutil

def mkdirs(path):
    """Create directory *path* (including parents) if it does not exist.

    Uses ``exist_ok=True`` instead of an ``os.path.exists`` pre-check:
    the check-then-create pair is a TOCTOU race (another process could
    create the directory between the two calls and crash ``makedirs``).
    """
    os.makedirs(path, exist_ok=True)

def split_data(src_dir, dst_dir, train_prop, val_prop, test_prop):
    """Split a class-per-folder image tree into train/val/test copies.

    Keras-style layout: one sub-directory of *src_dir* per class.  Files
    are copied (not moved) into ``dst_dir/{train,val,test}/<class>/``.

    Parameters
    ----------
    src_dir : str
        Source root containing one sub-directory per class.
    dst_dir : str
        Destination root; the per-set/per-class tree is created if missing.
    train_prop, val_prop, test_prop : float
        Fractions of each class assigned to the three sets.
    """
    # Only real directories are classes; this skips stray files such as
    # .DS_Store that os.listdir would otherwise pick up and crash on.
    src_classes = [entry for entry in sorted(os.listdir(src_dir))
                   if os.path.isdir(os.path.join(src_dir, entry))]

    data_sets = ['train', 'val', 'test']

    for data_set in data_sets:
        for src_class in src_classes:
            os.makedirs(os.path.join(dst_dir, data_set, src_class), exist_ok=True)

    for src_class in src_classes:
        # Sort the file list: os.listdir order is unspecified, so without
        # sorting the split would differ between runs/filesystems.
        class_images = sorted(os.listdir(os.path.join(src_dir, src_class)))
        num = len(class_images)

        train_end = int(num * train_prop)
        val_end = int(num * (train_prop + val_prop))
        splits = {
            'train': class_images[:train_end],
            'val': class_images[train_end:val_end],
            'test': class_images[val_end:],
        }

        for data_set in data_sets:
            print("Copying class:{} to {} set!".format(src_class, data_set))
            for class_image in splits[data_set]:
                src = os.path.join(src_dir, src_class, class_image)
                dst = os.path.join(dst_dir, data_set, src_class, class_image)
                shutil.copyfile(src, dst)

        print("train_num:{}, val_num:{}, test_num:{}".format(
            len(splits['train']), len(splits['val']), len(splits['test'])))

    print('done!')


if __name__ == '__main__':
    # 50% train / 25% val / 25% test split of the crawled food images.
    source_root = "foodData"
    target_root = "foodData_cnn_split"
    split_data(source_root, target_root, 0.5, 0.25, 0.25)

2.2 keras CNN 构建及训练

点击查看代码
from keras import layers
from keras import models
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from tensorflow import optimizers
import tensorflow as tf
import os

# --- Model definition -------------------------------------------------------
# Small sequential CNN: four Conv2D/MaxPooling2D stages, then a dense
# classifier head with a softmax over the six food categories.

model = models.Sequential()

model.add(layers.Conv2D(32, (3, 3), activation='relu',input_shape=(224, 224, 3)))
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))


model.add(layers.Flatten())
model.add(layers.Dropout(0.1))  # light dropout to curb overfitting
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(6, activation='softmax'))  # 6 output classes

# categorical_crossentropy pairs with the one-hot labels produced by
# class_mode='categorical' in the generators below.
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()


# --- Data preprocessing with data augmentation ------------------------------
train_dir = './foodData_cnn_split/train'
validation_dir = './foodData_cnn_split/val'

# Augment only the training data; validation images are just rescaled.
train_datagen = ImageDataGenerator(
                    rescale=1./255,
                    rotation_range=40,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    horizontal_flip=True,
                    fill_mode='nearest')

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
                    train_dir,
                    target_size=(224, 224),
                    batch_size=30,  # 30 images transformed per batch
                    # save_to_dir = './train_gen', this would write the augmented images to disk
                    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
                    validation_dir,
                    target_size=(224, 224),
                    batch_size=30,
                    class_mode='categorical')

if not os.path.exists('./train_gen'):
    os.mkdir('./train_gen')

print(train_generator.class_indices)   # print the class-folder -> label mapping

# --- Training ---------------------------------------------------------------
history = model.fit(
                    train_generator,
                    steps_per_epoch=300*6//30,  # 1800 images total / 30 per batch = 60 steps per epoch
                    epochs=3,
                    validation_data=validation_generator,
                    validation_steps=150*6//30)

# Persist the trained model for the later prediction step.
model.save('six_class_cnn.h5')

# --- Plot training curves ---------------------------------------------------

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(1, len(acc) + 1)

plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

遍历80次的模型精准率,也不是很高,需继续提高遍历次数或者改善算法
精准率

3. 图片识别

最终结果

image
image
image
image

posted @   camellia*  阅读(139)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 一个费力不讨好的项目,让我损失了近一半的绩效!
· 实操Deepseek接入个人知识库
· CSnakes vs Python.NET:高效嵌入与灵活互通的跨语言方案对比
· Plotly.NET 一个为 .NET 打造的强大开源交互式图表库
· 【.NET】调用本地 Deepseek 模型
点击右上角即可分享
微信分享提示