机器视觉-SAHI
SAHI 资料
yolov8示例代码: https://github.com/obss/sahi/blob/main/demo/inference_for_yolov8.ipynb
测试图像: https://github.com/obss/sahi/blob/main/tests/data/small-vehicles1.jpeg
原理介绍: https://learnopencv.com/slicing-aided-hyper-inference/
sahi命令行使用说明: https://github.com/obss/sahi/blob/main/docs/cli.md#predict-command-usage
步骤1: 模型初始化
SAHI 默认支持yolov5/yolov8/mmdet等多种预测网络, 我们可以直接使用yolov8的预训练模型文件, 下面是集成yolov8模型的示例代码:
# Build a SAHI detection model that wraps the YOLOv8 weights file.
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path=yolov8_model_path,
    confidence_threshold=0.3,
    device="cpu",  # or "cuda:0"
)
步骤2: 进行推理:
SAHI 不仅提供了切片(slice)版推理函数 get_sliced_prediction(), 而且也提供了对原始 YOLO 推理的简单封装函数 get_prediction(). 这两个函数的返回类型统一为 sahi.prediction.PredictionResult, 这样我们可以方便地在不同的 predict 函数之间切换.
步骤3: 使用推理结果对象做进一步处理
预测函数返回类 sahi.prediction.PredictionResult 成员:
- export_visuals()函数, 可以将推理结果保存为png图片
- object_prediction_list 成员: 得到 detection object list, 每个 detection object 的类型都为 ObjectPrediction 类.
- ObjectPrediction 类成员:
. bbox: BoundingBox: <(321.0, 322.0, 383.0, 363.0), w: 62.0, h: 41.0>,
. mask: None,
. score: PredictionScore: <value: 0.9093314409255981>,
. category: Category: <id: 2, name: car>
代码
import os
from IPython import display
import ultralytics
from ultralytics import YOLO, settings
from os import path
from sahi import AutoDetectionModel
from sahi.utils.cv import read_image
from sahi.predict import get_prediction, get_sliced_prediction
from IPython.display import Image
def yolov8_predict():
    """Run plain Ultralytics YOLOv8 inference on the sample image and print the results.

    Loads the weights from the hard-coded model path, predicts on the
    hard-coded image with ``save``, ``save_conf`` and ``save_txt`` enabled,
    then prints, for every returned result, a separator line followed by the
    result's ``path``, its ``names`` attribute and its ``tojson()`` output.
    """
    image_file1 = r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg"
    yolov8_model_path = r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt"

    model = YOLO(yolov8_model_path)
    results_list = model.predict(
        source=[image_file1],
        show=False,
        save=True,
        save_conf=True,
        save_txt=True,
    )

    for results in results_list:
        # Named results_json (not `json`) so the local never shadows the
        # standard-library module name.
        results_json = results.tojson()
        print("====")
        print(results.path)
        print(results.names)
        print(results_json)
def sahi_orginal_predict():
    """Run SAHI's non-sliced ``get_prediction()`` on the sample image.

    Builds a YOLOv8 detection model through ``AutoDetectionModel``, runs a
    single full-image prediction, walks the returned
    ``object_prediction_list`` and exports a visualisation named
    ``prediction_visual3`` into the hard-coded export directory.
    """
    image_file1 = r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg"
    yolov8_model_path = r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt"

    # Author's note: the model factory exposes very few YOLO parameters, so
    # further tuning has to be done in the ultralytics defaults file
    #   D:\my_workspace\py_code\yolo8\Lib\site-packages\ultralytics\cfg\default.yaml
    # e.g. setting classes=[2] to output only the "car" class.
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="yolov8",
        model_path=yolov8_model_path,
        confidence_threshold=0.2,
        device="cpu",  # or "cuda:0"
    )

    result = get_prediction(
        image=image_file1,
        detection_model=detection_model,
    )

    # Demonstrates how to reach each detection's category; uncomment the
    # prints to inspect them.
    for obj in result.object_prediction_list:
        category = obj.category
        # print("====")
        # print(category)

    result.export_visuals(
        export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
        file_name="prediction_visual3",
        hide_labels=False,
        hide_conf=False,
    )
    # Image("demo_data/prediction_visual3.png")
def sahi_sliced_predict():
    """Run SAHI's sliced ``get_sliced_prediction()`` on the sample image.

    The image is processed in 256x256 slices with a 0.25 overlap ratio in both
    directions, using the "NMS" postprocess type. A visualisation named
    ``prediction_visual4`` is exported to the hard-coded export directory, and
    the resulting ``object_prediction_list`` is then walked.
    """
    source_image = r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg"
    weights_file = r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt"

    # Author's note: the model factory exposes very few YOLO parameters, so
    # further tuning has to be done in site-packages\ultralytics\cfg\default.yaml,
    # e.g. setting classes=[2] to output only the "car" class.
    detector = AutoDetectionModel.from_pretrained(
        model_type="yolov8",
        model_path=weights_file,
        confidence_threshold=0.2,
        device="cpu",  # or "cuda:0"
    )

    # Slicing and merging options, grouped so they can be read at a glance.
    slicing_options = dict(
        slice_height=256,
        slice_width=256,
        overlap_height_ratio=0.25,
        overlap_width_ratio=0.25,
        postprocess_type="NMS",
        verbose=2,
    )
    sliced_result = get_sliced_prediction(
        image=source_image,
        detection_model=detector,
        **slicing_options,
    )

    sliced_result.export_visuals(
        export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
        file_name="prediction_visual4",
        hide_labels=False,
        hide_conf=False,
    )

    # Demonstrates how to reach each detection's category; uncomment the
    # prints to inspect them.
    for detection in sliced_result.object_prediction_list:
        category = detection.category
        # print("====")
        # print(category)
    # Image("demo_data/prediction_visual4.png")
if __name__ == "__main__":
    # Clear any previous IPython output, then run the ultralytics
    # environment check before doing any inference.
    display.clear_output()
    ultralytics.checks()

    # Choose which demo to run by leaving exactly one call uncommented.
    # yolov8_predict()
    # sahi_orginal_predict()
    sahi_sliced_predict()