【机器学习实战入门】使用CNN和Keras进行交通标志识别,准确率达到95% Traffic Signs Recognition with 95% Accuracy using CNN&Keras
什么是交通标志识别?
交通标志有很多种,如限速标志、禁止进入、交通信号灯、左转或右转、儿童过街、禁止重型车辆通行等。交通标志分类是识别交通标志属于哪一类的过程。
交通标志识别 – 关于 Python 项目
在这个 Python 项目示例中,我们将构建一个深度神经网络模型,能够将图像中的交通标志分类到不同的类别中。借助这个模型,我们可以读取并理解交通标志,这对于所有自动驾驶车辆来说是非常重要的任务。
交通标志识别 Python 项目创意
本项目的数据集
对于该项目,我们使用了 Kaggle 上的公开数据集:
链接: 使用CNN和Keras进行交通标志识别,准确率达到95% 源代码和数据集 Python-Project-Traffic-Sign-Classification
交通标志数据集
数据集中包含50,000多张不同交通标志的图像。它进一步划分为43个不同的类别。数据集非常多样,有些类别的图像很多,而有些类别的图像很少。数据集的大小约为300MB。数据集中有一个训练文件夹,其中包含每个类别的图像,还有一个测试文件夹,我们将使用它来测试我们的模型。
Python 项目数据集
先决条件
该项目需要具备 Keras、Matplotlib、Scikit-learn、Pandas、PIL 和图像分类的先前知识。
为了安装此 Python 数据科学项目所需的所有包,请在终端中输入以下命令:
pip install tensorflow keras sklearn matplotlib pandas pil
构建 Python 项目的步骤
要开始项目,请从此链接下载并解压文件 – 交通标志识别压缩文件
并将文件解压到一个文件夹,使你拥有一组训练集、测试集和元数据文件夹。
链接: 使用CNN和Keras进行交通标志识别,准确率达到95% 源代码和数据集 Python-Project-Traffic-Sign-Classification
在 Python 项目中探索数据集
我们的‘train’文件夹包含43个文件夹,每个代表不同类别。文件夹的编号范围从0到42。借助 OS 模块,我们遍历所有类别,并将图像及其相应的标签附加到数据和标签列表中。
PIL 库用于将图像内容打开为数组。
最后,我们将所有图像及其标签存储到列表(data 和 labels)中。
我们需要将列表转换为 NumPy 数组以便为模型提供输入。
数据的形状为 (39209, 30, 30, 3),这意味着有 39,209 张大小为 30×30 像素的图像,最后的 3 意味着数据包含彩色图像(RGB 值)。
使用 sklearn 包中的 train_test_split() 方法将训练和测试数据分开。
从 keras.utils 包中,我们使用 to_categorical 方法将 y_train 和 t_test 中的标签转换为独热编码。
在 Python 项目中拆分数据集
步骤 2:构建 CNN 模型
要将图像分类到其各自的类别中,我们将构建一个 CNN 模型(卷积神经网络)。CNN 在图像分类方面表现出色。
我们模型的架构为:
- 两个 Conv2D 层(filter=32, kernel_size=(5,5), activation=“relu”)
- 一个 MaxPool2D 层(pool_size=(2,2))
- 一个 Dropout 层(rate=0.25)
- 两个 Conv2D 层(filter=64, kernel_size=(3,3), activation=“relu”)
- 一个 MaxPool2D 层(pool_size=(2,2))
- 一个 Dropout 层(rate=0.25)
- 一个 Flatten 层,将各层压缩为一维
- 一个全连接层(256 个节点,activation=“relu”)
- 一个 Dropout 层(rate=0.5)
- 一个全连接层(43 个节点,activation=“softmax”)
我们使用 Adam 优化器编译模型,性能良好,损失函数为 “categorical_crossentropy”,因为我们有多类别需要分类。
在 Python 数据科学项目中构建 CNN 模型
步骤 3:训练并验证模型
构建完模型架构后,我们使用 model.fit() 训练模型。我尝试了 32 和 64 的批次大小。我们的模型在 64 的批次大小下表现更好,在 15 个周期后准确率趋于稳定。
项目中的模型训练
我们的模型在训练数据集上达到了 95% 的准确率。使用 Matplotlib,我们绘制了准确率和损失的图表。
在 Python 项目示例中绘制准确率
绘制准确率
在 Python 机器学习项目中的准确率和损失图表
准确率与损失图表
步骤 4:用测试数据集测试我们的模型
我们的数据集包含一个测试文件夹,在 test.csv 文件中,我们有关于图像路径及其相应类别标签的详细信息。我们使用 Pandas 提取图像路径和标签。为了预测模型,我们需要将图像调整为 30×30 像素,并形成一个包含所有图像数据的 NumPy 数组。从 sklearn.metrics 中,我们导入了 accuracy_score,并观察模型对实际标签的预测准确率。我们在该模型中同样实现了 95% 的准确率。
在高级 Python 项目中测试准确率
最后,我们将使用 Keras 的 model.save() 函数保存训练好的模型。
model.save('traffic_classifier.h5')
等一下!你是否查看了我们的最新 OpenCV & 计算机视觉教程?
交通标志分类器 GUI
现在我们将使用 Tkinter 为交通标志分类器构建一个图形用户界面。Tkinter 是 Python 标准库中的一个 GUI 工具包。在项目文件夹中新建一个文件,复制以下代码。将其保存为 gui.py,你可以在命令行中通过输入 python gui.py 来运行代码。
在这个文件中,我们首先使用 Keras 加载了训练好的模型‘traffic_classifier.h5’。然后我们构建了一个 GUI 用于上传图像,使用了一个按钮进行分类,该按钮调用了 classify() 函数。classify() 函数将图像转换为 (1, 30, 30, 3) 的维度。这是因为我们需要提供与构建模型时使用的维度相同的输入来预测交通标志。然后我们预测类别,model.predict_classes(image) 返回一个 0-42 之间的数字,这个数字代表它所属的类别。我们使用字典来获取类别的信息。以下为 gui.py 文件的代码。
代码:
import tkinter as tk
from tkinter import filedialog
from tkinter import *
from PIL import ImageTk, Image
import numpy
# 加载训练好的模型以对标志进行分类
from keras.models import load_model
model = load_model('traffic_classifier.h5')
# 用于标注所有交通标志类别的字典
classes = { 1: '限速 (20km/h)' ,
2: '限速 (30km/h)',
3: '限速 (50km/h)',
4: '限速 (60km/h)',
5: '限速 (70km/h)',
6: '限速 (80km/h)',
7: '解除限速 (80km/h)',
8: '限速 (100km/h)',
9: '限速 (120km/h)',
10: '禁止超车',
11: '禁止 3.5 吨以上车辆超车',
12: '交叉路口先行权',
13: '优先道路',
14: '让行',
15: '停车',
16: '禁止车辆通行',
17: '禁止 3.5 吨以上车辆通行',
18: '禁止入内',
19: '注意危险',
20: '左侧危险弯道',
21: '右侧危险弯道',
22: '双弯道',
23: '路面不平',
24: '路面湿滑',
25: '右侧路面变窄',
26: '施工道路',
27: '交通信号',
28: '行人过桥',
29: '儿童过街',
30: '自行车过街',
31: '注意冰/雪危险',
32: '野生动物横穿道路',
33: '解除限速和超车限制',
34: '向右转',
35: '向左转',
36: '直行',
37: '直行或向右',
38: '直行或向左',
39: '靠右行驶',
40: '靠左行驶',
41: '强制环岛',
42: '解除禁止超车',
43: '解除禁止 3.5 吨以上车辆超车' }
# 初始化 GUI
top = tk.Tk()
top.geometry('800x600')
top.title('交通标志分类')
top.configure(background='#CDCDCD')
label = Label(top, background='#CDCDCD', font=('arial',15,'bold'))
sign_image = Label(top)
def classify(file_path):
global label_packed
image = Image.open(file_path)
image = image.resize((30,30))
image = numpy.expand_dims(image, axis=0)
image = numpy.array(image)
pred = model.predict_classes([image])[0]
sign = classes[pred+1]
print(sign)
label.configure(foreground='#011638', text=sign)
def show_classify_button(file_path):
classify_b = Button(top, text="分类图像", command=lambda: classify(file_path), padx=10, pady=5)
classify_b.configure(background='#364156', foreground='white', font=('arial', 10, 'bold'))
classify_b.place(relx=0.79, rely=0.46)
def upload_image():
try:
file_path = filedialog.askopenfilename()
uploaded = Image.open(file_path)
uploaded.thumbnail(((top.winfo_width()/2.25), (top.winfo_height()/2.25)))
im = ImageTk.PhotoImage(uploaded)
sign_image.configure(image=im)
sign_image.image = im
label.configure(text='')
show_classify_button(file_path)
except:
pass
upload = Button(top, text="上传图像", command=upload_image, padx=10, pady=5)
upload.configure(background='#364156', foreground='white', font=('arial', 10, 'bold'))
upload.pack(side=BOTTOM, pady=50)
sign_image.pack(side=BOTTOM, expand=True)
label.pack(side=BOTTOM, expand=True)
heading = Label(top, text="了解你的交通标志", pady=20, font=('arial', 20, 'bold'))
heading.configure(background='#CDCDCD', foreground='#364156')
heading.pack()
top.mainloop()
在 Python 项目中为交通标志识别构建的图形用户界面
总结
在这个带有源代码的 Python 项目中,我们成功地实现了交通标志分类器,准确率达到 95%。此外,我们还通过图表可视化了随时间变化的准确率和损失的变化,从简单的 CNN 模型来看,这是一个相当不错的成绩。
参考资料
资料名称 | 链接 |
---|---|
Kaggle 交通标志数据集 | Kaggle |
TensorFlow 官方文档 | TF official |
Keras 官方文档 | Keras official |
Matplotlib 官方文档 | Matplotlib official |
Pandas 官方文档 | Pandas official |
Python Imaging Library (PIL) | PIL |
Sklearn 官方文档 | Sklearn official |
Python 官方文档 | Python official |
Tkinter 教程 | Tkinter tutorial |
计算机视觉简明教程 | CV intro |
交通标志识别技术综述 | 论文链接 |
深度学习与交通标志识别 | 论文链接 |
使用 CNN 进行交通标志识别的研究 | 论文链接 |
结合数据增强和卷积神经网络的交通标志识别方法 | 论文链接 |
机器学习交通标志识别方法 | 论文链接 |
交通标志分类的性能评估 | 论文链接 |
参考链接:https://data-flair.training/blogs/python-project-traffic-signs-recognition/
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)