YOLOv5 ONNX model deployment code, Python version
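The script below loads a YOLOv5 model exported to ONNX and runs it with onnxruntime on the CPU. Each image is resized to 640x640, converted from BGR/HWC to RGB/CHW, normalized to [0, 1], and fed to the session; the predictions are thresholded on objectness, reduced by per-class NMS, and the largest surviving box is rescaled to the original image size, drawn, and saved.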



import os
import cv2
import numpy as np
import onnxruntime
import time
from tqdm import tqdm
from matplotlib import pyplot as plt
import math
CLASSES = ['jump_cap2', 'jump_cap4']  # class names of the custom-trained model

class YOLOV5():
    def __init__(self, onnxpath):
        self.onnx_session = onnxruntime.InferenceSession(onnxpath, providers=['CPUExecutionProvider'])
        self.input_name = self.get_input_name()
        self.output_name = self.get_output_name()


    def get_input_name(self):
        input_name = []
        for node in self.onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_output_name(self):
        output_name = []
        for node in self.onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name


    def get_input_feed(self, img_tensor):
        input_feed = {}
        for name in self.input_name:
            input_feed[name] = img_tensor
        return input_feed


    def inference(self, img_path):
        img = cv2.imread(img_path)
        img_o = img.copy()                           # keep the original image for drawing
        or_img = cv2.resize(img, (640, 640))         # resize to the 640x640 network input
        img = or_img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = img.astype(dtype=np.float32)
        img /= 255.0                                 # normalize to [0, 1]
        img = np.expand_dims(img, axis=0)            # add batch dimension
        input_feed = self.get_input_feed(img)
        pred = self.onnx_session.run(None, input_feed)[0]
        return pred, img_o

def nms(dets, thresh):
    # Greedy non-maximum suppression on boxes given as [x1, y1, x2, y2, score, ...].
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]

    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
    scores = dets[:, 4]
    keep = []
    index = scores.argsort()[::-1]  # box indices sorted by descending score

    while index.size > 0:
        i = index[0]                # keep the highest-scoring remaining box
        keep.append(i)

        # intersection of box i with every remaining box
        x11 = np.maximum(x1[i], x1[index[1:]])
        y11 = np.maximum(y1[i], y1[index[1:]])
        x22 = np.minimum(x2[i], x2[index[1:]])
        y22 = np.minimum(y2[i], y2[index[1:]])

        w = np.maximum(0, x22 - x11 + 1)
        h = np.maximum(0, y22 - y11 + 1)

        overlaps = w * h

        ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
        # drop boxes whose IoU with the kept box exceeds the threshold
        idx = np.where(ious <= thresh)[0]
        index = index[idx + 1]
    return keep


def xywh2xyxy(x):
    # [x, y, w, h] to [x1, y1, x2, y2]
    y = np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2
    y[:, 1] = x[:, 1] - x[:, 3] / 2
    y[:, 2] = x[:, 0] + x[:, 2] / 2
    y[:, 3] = x[:, 1] + x[:, 3] / 2
    return y


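# Note: assuming a standard YOLOv5 ONNX export, each prediction row is laid out as
# [cx, cy, w, h, objectness, class_score_0, class_score_1, ...]; with the two classes
# above, the raw output has shape (1, N, 7). filter_box below keeps rows whose
# objectness exceeds conf_thres, assigns each row its argmax class, and runs NMS per class.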
def filter_box(org_box, conf_thres, iou_thres):

    org_box = np.squeeze(org_box)                # drop the batch dimension
    conf = org_box[..., 4] > conf_thres          # keep rows whose objectness passes the threshold
    box = org_box[conf]

    cls_cinf = box[..., 5:]                      # per-class scores
    cls = []
    for i in range(len(cls_cinf)):
        cls.append(int(np.argmax(cls_cinf[i])))
    all_cls = list(set(cls))
    output = []

    # run NMS separately for each predicted class
    for i in range(len(all_cls)):
        curr_cls = all_cls[i]
        curr_cls_box = []
        curr_out_box = []
        for j in range(len(cls)):
            if cls[j] == curr_cls:
                box[j][5] = curr_cls
                curr_cls_box.append(box[j][:6])
        curr_cls_box = np.array(curr_cls_box)

        curr_cls_box = xywh2xyxy(curr_cls_box)
        curr_out_box = nms(curr_cls_box, iou_thres)
        for k in curr_out_box:
            output.append(curr_cls_box[k])
    output = np.array(output)
    return output


def draw(image, box_data):
    # Rescale the detections from the 640x640 network input back to the original
    # image size, keep only the largest box, and draw it on the image.
    boxes = box_data[..., :4].astype(np.int32)
    scores = box_data[..., 4]
    classes = box_data[..., 5].astype(np.int32)
    print("boxes:", boxes)

    img_height_o = image.shape[0]
    img_width_o = image.shape[1]
    x_ratio = img_width_o / 640
    y_ratio = img_height_o / 640

    max_area = 0
    best_rec = None
    best_cls = None
    for box, score, cl in zip(boxes, scores, classes):
        x1, y1, x2, y2 = box

        x1 = int(x1 * x_ratio)
        x2 = int(x2 * x_ratio)
        y1 = int(y1 * y_ratio)
        y2 = int(y2 * y_ratio)
        print("scaled box:", x1, y1, x2, y2)

        area = (x2 - x1) * (y2 - y1)
        if area > max_area:
            max_area = area
            best_rec = [x1, y1, x2, y2]
            best_cls = cl

    # If nothing passed filtering, best_rec is None and the unpack below raises;
    # the callers catch the exception and skip the image.
    x1, y1, x2, y2 = best_rec
    cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 1)
    print("best_rec:", best_rec)

    # plt.imshow(image)
    # plt.show()

    return best_rec, best_cls


def onnx_infer(input, output, onnx_path, pin_rec_list, single_inference=False, conf_thres=0.1, iou_thres=0.1):
    # pin_rec_list is accepted for interface compatibility but unused in this listing.
    if single_inference:
        # input is a single image path
        model = YOLOV5(onnx_path)
        outbox, or_img = model.inference(input)
        outbox = filter_box(outbox, conf_thres, iou_thres)

        best_rec, best_cls = draw(or_img, outbox)
        file_name = os.path.basename(input)
        output = os.path.join(output, file_name)
        cv2.imwrite(output, or_img)

    else:
        # input is a directory of images
        model = YOLOV5(onnx_path)          # create the session once, not per image
        img_list = os.listdir(input)
        for img in tqdm(img_list):
            img_path = os.path.join(input, img)
            outbox, or_img = model.inference(img_path)
            outbox = filter_box(outbox, conf_thres, iou_thres)
            try:
                best_rec, best_cls = draw(or_img, outbox)
            except Exception:
                continue                   # skip images with no surviving detections
            save_path = os.path.join(output, img)
            cv2.imwrite(save_path, or_img)


if __name__ == "__main__":

    start_time = time.time()
    onnx_path = 'model/best.onnx'
    model = YOLOV5(onnx_path)
    img_dir_path = 'testpicture'
    # img_dir_path = 'testpic'
    img_list = os.listdir(img_dir_path)
    save_dir_path = 'detimg/'

    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)
    for img_name in tqdm(img_list):
        img_path = os.path.join(img_dir_path, img_name)
        output, or_img = model.inference(img_path)
        outbox = filter_box(output, 0.25, 0.7)
        try:
            draw(or_img, outbox)
        except Exception:
            print('no detection:', img_name)
            continue
        save_img_path = os.path.join(save_dir_path, img_name)
        cv2.imwrite(save_img_path, or_img)

    end_time = time.time()
    print('total inference time: {:.2f}s'.format(end_time - start_time))
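
For a one-off image, the onnx_infer helper defined above can also be called directly instead of the batch loop in __main__. A minimal sketch, assuming placeholder paths (pin_rec_list is passed as an empty list because the function does not use it in this listing):

onnx_infer(
    input='testpicture/sample.jpg',   # placeholder input image path
    output='detimg/',                 # existing directory for the annotated result
    onnx_path='model/best.onnx',
    pin_rec_list=[],                  # accepted but unused by onnx_infer here
    single_inference=True,
    conf_thres=0.25,
    iou_thres=0.7,
)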