借用Ultralytics Yolo快速训练一个物体检测器

借用Ultralytics Yolo快速训练一个物体检测器

 

https://github.com/ultralytics/ultralytics

 

Step-1 准备数据集

你需要一些待检测物体比如安全帽, 把它从各个角度拍摄一下. 再找一些不相关的背景图片. 然后把安全帽给放大缩小旋转等等贴到背景图片上去, 生成一堆训练数据.

 

配置文件:

 
复制代码
extract_cfg:
  output_dir: '/datasets/images'
  fps: 0.25

screen_images_path: '/datasets/待检测图片'
max_scale: 1.0
min_scale: 0.1
manual_scale: [ {name: 'logo', min_scale: 0.05, max_scale: 0.3},
                {name: 'logo', min_scale: 0.1, max_scale: 0.5},
                {name: '箭头', min_scale: 0.1, max_scale: 0.5}
]
data_cfgs: [ {id: 0, name: 'logo', min_scale: 0.05, max_scale: 0.3, gen_num: 2},
            {id: 1, name: '截屏', min_scale: 0.1, max_scale: 1.0, gen_num: 3, need_full_screen: true},
            {id: 2, name: '红包', min_scale: 0.1, max_scale: 0.5, gen_num: 2},
            {id: 3, name: '箭头', min_scale: 0.1, max_scale: 0.5, gen_num: 2, rotate_aug: true},
]
save_oss_dir: /datasets/gen_datasets/
gen_num_per_image: 2
max_bg_img_sample: 1
复制代码

数据集生成:

 

  

运行后, 可以在outputs文件夹下生成符合要求的训练数据.

 

image 就是背景+检测物体

labels 中的内容就是这样的文件:

1
2
1 0.6701388888888888 0.289453125 0.5736111111111111 0.57421875
# 类型 box

  

 

Step-2 训练模型

 

这个更简单, 在官网下载一个模型权重, 比如yolo8s.pt, 对付安全帽这种东西, 几M大的模型就够了.

训练配置文件:

1
2
3
4
5
6
7
8
names:
  0: logo
  1: 截屏
  2: 红包
path: /outputs
test: images/test
train: images/train
val: images/val

  

训练代码:

没错就这么一点

1
2
3
4
from ultralytics import YOLO
 
model = YOLO('./yolo8s.pt')
model.train(data='dataset.yaml', epochs=100, imgsz=1280)

 

然后就可以自动化训练了, 结束后会自动保存模型与评估检测效果.

 

 

Step-3 检测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Special_Obj_Detect(object):
 
    def __init__(self, cfg) -> None:
        model_path = cfg.model_path
        self.model = YOLO(model_path)
        self.model.requires_grad_ = False
        self.cls_names = {0: 'logo', 1: '截屏', 2: '红包'}
 
    # 单帧图像检测
    def detect_image(self, img_path):
        results = self.model(img_path)
        objects = []
        objects_cnt = dict()
        objects_area_pct = dict()
        for result in results:
            result = result.cpu()
            boxes = list(result.boxes)
            for box in boxes:
                if box.conf < 0.8: continue
                boxcls = box.cls[0].item()
                objects.append(self.cls_names[boxcls])
                objects_cnt[self.cls_names[boxcls]] = objects_cnt.get(self.cls_names[boxcls], 0) + 1
                area_p = sum([ (xywh[2]*xywh[3]).item()  for xywh in box.xywhn])
                area_p = min(1, area_p)
                objects_area_pct[self.cls_names[boxcls]] = area_p
        objects = list(set(objects))
        return objects, objects_cnt, objects_area_pct

  

收工.

 

本文作者:JiangOil

本文链接: https://www.codebonobo.tech/post/14

 

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!

 

posted @   酱_油  阅读(284)  评论(5编辑  收藏  举报
相关博文:
阅读排行:
· 2分钟学会 DeepSeek API,竟然比官方更好用!
· .NET 使用 DeepSeek R1 开发智能 AI 客户端
· 10亿数据,如何做迁移?
· 推荐几款开源且免费的 .NET MAUI 组件库
· c# 半导体/led行业 晶圆片WaferMap实现 map图实现入门篇
历史上的今天:
2021-10-31 WPS中导入endnote插件
点击右上角即可分享
微信分享提示