机器视觉-SAHI

SAHI 资料

yolov8示例代码: https://github.com/obss/sahi/blob/main/demo/inference_for_yolov8.ipynb
测试图像: https://github.com/obss/sahi/blob/main/tests/data/small-vehicles1.jpeg
原理介绍: https://learnopencv.com/slicing-aided-hyper-inference/
sahi命令行使用说明: https://github.com/obss/sahi/blob/main/docs/cli.md#predict-command-usage

步骤1: 模型初始化

SAHI 默认支持 YOLOv5/YOLOv8/MMDetection 等多种检测模型, 我们可以直接使用 YOLOv8 的预训练模型文件. 下面是集成 YOLOv8 模型的示例代码:

# Wrap a YOLOv8 checkpoint in SAHI's unified detection-model interface.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",           # which backend adapter SAHI should use
    model_path=yolov8_model_path,  # path to the YOLOv8 .pt weights file
    device="cpu",                  # or "cuda:0" for GPU inference
    confidence_threshold=0.3,      # predictions scoring below this are discarded
)

步骤2: 进行推理

SAHI 不仅提供了切片推理函数 get_sliced_prediction(), 也提供了对原始 YOLO 推理的简单封装 get_prediction(). 这两个函数的返回类型都是 sahi.prediction.PredictionResult, 因此可以方便地在两种推理方式之间切换.

步骤3: 使用推理结果对象做进一步处理

预测函数返回类 sahi.prediction.PredictionResult 成员:

  • export_visuals()函数, 可以将推理结果保存为png图片
  • object_prediction_list 成员: 检测结果列表, 其中每个元素都是一个 ObjectPrediction 对象.
  • ObjectPrediction类成员:
    . bbox: BoundingBox: <(321.0, 322.0, 383.0, 363.0), w: 62.0, h: 41.0>, 其中元组为 (x_min, y_min, x_max, y_max), w/h 为框的宽和高,
    . mask: None,
    . score: PredictionScore: <value: 0.9093314409255981>,
    . category: Category: <id: 2, name: car>

代码

import os
from IPython import display
import ultralytics
from ultralytics import YOLO, settings
from os import path
from sahi import AutoDetectionModel
from sahi.utils.cv import read_image
from sahi.predict import get_prediction, get_sliced_prediction
from IPython.display import Image

def yolov8_predict(
    image_file=r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg",
    model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt",
):
    """Run plain Ultralytics YOLOv8 (no slicing) on one image and print results.

    Ultralytics is asked to save annotated output and label files including
    confidences (``save=True, save_txt=True, save_conf=True``); nothing is
    shown in a window (``show=False``).

    Args:
        image_file: path of the image to run detection on.
        model_path: path of the YOLOv8 ``.pt`` weights.

    Returns:
        The list returned by ``YOLO.predict`` (one entry per source image).
    """
    model = YOLO(model_path)
    results_list = model.predict(
        source=[image_file],
        show=False,
        save=True,
        save_conf=True,
        save_txt=True,
    )
    for results in results_list:
        print("====")
        print(results.path)      # source image path
        print(results.names)     # class-id -> class-name mapping
        print(results.tojson())  # detections serialized as JSON text
    return results_list

def sahi_orginal_predict(
    image_file=r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg",
    model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt",
    export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
    confidence_threshold=0.2,
    device="cpu",
    print_categories=False,
):
    """Run SAHI's non-sliced ``get_prediction`` with a YOLOv8 model on one image.

    The visualization is exported via ``PredictionResult.export_visuals`` to
    ``export_dir`` under the file name ``prediction_visual3``.

    Args:
        image_file: path of the image to run detection on.
        model_path: path of the YOLOv8 ``.pt`` weights.
        export_dir: directory the visualization is written to.
        confidence_threshold: score threshold handed to SAHI's model wrapper.
        device: ``"cpu"`` or a CUDA device such as ``"cuda:0"``.
        print_categories: if true, print the category of every detection.

    Returns:
        The ``sahi.prediction.PredictionResult``; iterate its
        ``object_prediction_list`` for per-detection bbox/score/category.
    """
    # NOTE: the model factory exposes very few YOLO options. Further settings
    # must be made in
    #   D:\my_workspace\py_code\yolo8\Lib\site-packages\ultralytics\cfg\default.yaml
    # e.g. ``classes: [2]`` to output only the "car" class.
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="yolov8",
        model_path=model_path,
        confidence_threshold=confidence_threshold,
        device=device,
    )

    result = get_prediction(
        image=image_file,
        detection_model=detection_model,
    )

    if print_categories:
        for obj in result.object_prediction_list:
            print(obj.category)

    result.export_visuals(
        export_dir=export_dir,
        file_name="prediction_visual3",
        hide_labels=False,
        hide_conf=False,
    )
    return result


def sahi_sliced_predict(
    image_file=r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg",
    model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt",
    export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
    confidence_threshold=0.2,
    device="cpu",
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.25,
    overlap_width_ratio=0.25,
    postprocess_type="NMS",
    print_categories=False,
):
    """Run SAHI sliced inference (``get_sliced_prediction``) with YOLOv8.

    The image is cut into ``slice_height`` x ``slice_width`` tiles that
    overlap by the given ratios, and each tile is run through the detector.
    The per-tile detections are then merged with ``postprocess_type``. The
    visualization is exported to ``export_dir`` as ``prediction_visual4``.

    Args:
        image_file: path of the image to run detection on.
        model_path: path of the YOLOv8 ``.pt`` weights.
        export_dir: directory the visualization is written to.
        confidence_threshold: score threshold handed to SAHI's model wrapper.
        device: ``"cpu"`` or a CUDA device such as ``"cuda:0"``.
        slice_height: tile height in pixels.
        slice_width: tile width in pixels.
        overlap_height_ratio: vertical overlap between neighboring tiles.
        overlap_width_ratio: horizontal overlap between neighboring tiles.
        postprocess_type: SAHI merge strategy for detections from tiles.
        print_categories: if true, print the category of every detection.

    Returns:
        The ``sahi.prediction.PredictionResult``; iterate its
        ``object_prediction_list`` for per-detection bbox/score/category.
    """
    # NOTE: the model factory exposes very few YOLO options. Further settings
    # must be made in site-packages\ultralytics\cfg\default.yaml,
    # e.g. ``classes: [2]`` to output only the "car" class.
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="yolov8",
        model_path=model_path,
        confidence_threshold=confidence_threshold,
        device=device,
    )

    result = get_sliced_prediction(
        image=image_file,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        verbose=2,  # SAHI's most detailed console logging level
    )
    result.export_visuals(
        export_dir=export_dir,
        file_name="prediction_visual4",
        hide_labels=False,
        hide_conf=False,
    )

    if print_categories:
        for obj in result.object_prediction_list:
            print(obj.category)
    return result

if __name__ == '__main__':
    # Housekeeping: clear previous (notebook) output, then let Ultralytics
    # print its environment check.
    display.clear_output()
    ultralytics.checks()

    # Demo to run. Alternatives defined above: yolov8_predict, sahi_orginal_predict.
    selected_demo = sahi_sliced_predict
    selected_demo()

posted @ 2024-02-19 15:11  harrychinese  阅读(133)  评论(1)  编辑  收藏  举报