yolov8模型转为onnx后的推理测试(分为两个py文件)

文件一：examples/YOLOV8.py（推理类，点击查看代码）
import torch

from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_requirements, check_yaml
import numpy as np
import cv2
import onnxruntime as ort

class YOLOv8:
    """YOLOv8 object detection model class for handling inference and visualization.

    Reads an image from disk, resizes it (without letterboxing) to the model's input
    size, runs the ONNX model through ONNX Runtime, filters the raw predictions by
    confidence, applies NMS and draws the surviving boxes on the original image.
    """

    def __init__(self, onnx_model, input_image, confidence_thres, iou_thres, classes_yaml="mycoco.yaml"):
        """
        Initializes an instance of the YOLOv8 class.

        Args:
            onnx_model: Path to the ONNX model.
            input_image: Path to the input image (may be None and set later, before ``main``).
            confidence_thres: Minimum best-class score a prediction needs to be kept.
            iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression.
            classes_yaml: Dataset YAML whose ``names`` entry holds the class names.
                Defaults to ``"mycoco.yaml"``, the file this class always used before.
        """
        self.onnx_model = onnx_model
        self.input_image = input_image
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres

        # Load the class names from the dataset YAML
        self.classes = yaml_load(check_yaml(classes_yaml))["names"]

        # One random drawing color per class (unseeded, so colors change between runs)
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))

        # ONNX Runtime session, created on the first ``main`` call and reused afterwards
        self._session = None
        self._session_path = None

    def draw_detections(self, img, box, score, class_id):
        """
        Draws one bounding box and its "class: score" label onto ``img`` in place.

        Args:
            img: The image to draw on (modified in place).
            box: Bounding box as ``[left, top, width, height]`` in image pixels.
            score: Detection score shown in the label.
            class_id: Class index used for the label text and the box color.

        Returns:
            list[int]: ``[x, y, w, h]`` of the box with its top-left corner clamped to be
            non-negative (width/height shrink by the amount clamped off).
        """
        # Extract the coordinates of the bounding box
        x1, y1, w, h = box
        # Retrieve the color for the class ID
        color = self.color_palette[class_id]
        # Draw the bounding box on the image
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
        # Create the label text with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"
        # Calculate the dimensions of the label text
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        # Place the label above the box, or just inside it when there is no room above.
        # OpenCV drawing calls need integer points, so cast explicitly.
        label_x = int(x1)
        label_y = int(y1 - 10 if y1 - 10 > label_height else y1 + 10)
        # Draw a filled rectangle as the background for the label text
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )
        # Draw the label text on the image
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

        # Clamp the top-left corner to the image origin; keep the bottom-right corner fixed
        x = max(int(x1), 0)
        y = max(int(y1), 0)
        w = int(x1 + w) - x
        h = int(y1 + h) - y

        return [x, y, w, h]

    def preprocess(self):
        """
        Reads ``self.input_image`` and converts it into the model's input tensor.

        Side effects:
            Sets ``self.img`` (the original BGR image) and ``self.img_height`` /
            ``self.img_width``, which ``postprocess`` uses to rescale boxes.

        Returns:
            numpy.ndarray: float32 tensor of shape (1, 3, input_height, input_width) in [0, 1].

        Raises:
            FileNotFoundError: If OpenCV cannot read the image at ``self.input_image``.
        """
        # Read the input image using OpenCV (returns None instead of raising on failure)
        self.img = cv2.imread(self.input_image)
        if self.img is None:
            raise FileNotFoundError(f"Could not read input image: {self.input_image}")

        # Get the height and width of the input image
        self.img_height, self.img_width = self.img.shape[:2]

        # Convert the image color space from BGR to RGB
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)

        # Resize the image to match the input shape (cv2.resize takes (width, height))
        img = cv2.resize(img, (self.input_width, self.input_height))

        # Normalize the image data by dividing it by 255.0
        image_data = np.array(img) / 255.0

        # HWC -> CHW (channel first)
        image_data = np.transpose(image_data, (2, 0, 1))

        # Add the batch axis and cast to the float32 the session is fed with
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        return image_data

    def _decode_outputs(self, output):
        """
        Turns the raw network output into candidate boxes above the confidence threshold.

        Args:
            output: Sequence of model outputs; only ``output[0]`` is used. After squeezing
                it is treated as ``(4 + num_classes, num_predictions)`` whose rows are
                ``cx, cy, w, h`` followed by per-class scores (matches the indexing below;
                confirm for other export formats).

        Returns:
            tuple: ``(boxes, scores, class_ids)`` as plain lists. ``boxes`` are
            ``[left, top, width, height]`` ints in original-image pixels, ``scores`` the
            best class score of each kept prediction and ``class_ids`` that class's index.
        """
        # Drop the batch axis and make each row one prediction
        predictions = np.transpose(np.squeeze(output[0]))

        class_scores = predictions[:, 4:]
        max_scores = np.amax(class_scores, axis=1)
        keep = max_scores >= self.confidence_thres

        kept = predictions[keep]
        x, y, w, h = kept[:, 0], kept[:, 1], kept[:, 2], kept[:, 3]

        # Boxes are predicted in model-input pixels; scale them back to the original image
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        # astype(int) truncates toward zero, matching int() on each value
        left = ((x - w / 2) * x_factor).astype(int)
        top = ((y - h / 2) * y_factor).astype(int)
        width = (w * x_factor).astype(int)
        height = (h * y_factor).astype(int)

        boxes = np.stack([left, top, width, height], axis=1).tolist()
        scores = max_scores[keep].tolist()
        class_ids = np.argmax(class_scores[keep], axis=1).tolist()
        return boxes, scores, class_ids

    def postprocess(self, input_image, output):
        """
        Filters the model output, applies NMS and draws the surviving detections.

        Args:
            input_image (numpy.ndarray): The image to draw on (modified in place).
            output (list[numpy.ndarray]): The outputs returned by ``session.run``.

        Returns:
            tuple: ``(input_image, rectangle_params)`` where ``rectangle_params`` is the
            clamped ``[x, y, w, h]`` of the LAST drawn detection, or ``[]`` when nothing
            survives filtering/NMS.
        """
        boxes, scores, class_ids = self._decode_outputs(output)

        rectangle_params = []
        if not boxes:
            # Nothing above the confidence threshold: skip NMS entirely
            return input_image, rectangle_params

        indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)
        # Depending on the OpenCV version, NMSBoxes returns a flat array, an (N, 1) array
        # or an empty tuple; normalize to a 1-D int array before indexing.
        for i in np.asarray(indices, dtype=int).reshape(-1):
            rectangle_params = self.draw_detections(input_image, boxes[i], scores[i], class_ids[i])

        return input_image, rectangle_params

    def _get_session(self):
        """Returns the ONNX Runtime session for ``self.onnx_model``, creating it only when needed."""
        if getattr(self, "_session", None) is None or getattr(self, "_session_path", None) != self.onnx_model:
            self._session = ort.InferenceSession(
                self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            self._session_path = self.onnx_model
        return self._session

    def main(self):
        """
        Runs inference on ``self.input_image`` and draws the detections.

        Returns:
            tuple: ``(output_img, rectangle_params)`` as returned by ``postprocess``.
        """
        # Reuse one session across calls; building it reloads the whole model
        session = self._get_session()
        model_inputs = session.get_inputs()
        # preprocess feeds a (1, C, H, W) tensor, so index 2 is the height, index 3 the width.
        # NOTE(review): with dynamic-axis exports these entries may be symbolic strings.
        input_shape = model_inputs[0].shape
        self.input_height = input_shape[2]
        self.input_width = input_shape[3]
        # Preprocess the image data
        img_data = self.preprocess()
        # Run inference using the preprocessed image data
        outputs = session.run(None, {model_inputs[0].name: img_data})
        # Post-process the outputs to obtain the annotated output image
        return self.postprocess(self.img, outputs)

class YOLOv8Inference:
    """Convenience wrapper that owns a single ``YOLOv8`` detector and runs it per image path."""

    def __init__(self, model_path, conf_thres, iou_thres):
        """
        Args:
            model_path: Path to the ONNX model file.
            conf_thres: Confidence threshold forwarded to ``YOLOv8``.
            iou_thres: NMS IoU threshold forwarded to ``YOLOv8``.
        """
        # Ask ultralytics' requirement checker for the ONNX Runtime package that
        # matches CUDA availability as reported by torch.
        runtime_pkg = "onnxruntime" if not torch.cuda.is_available() else "onnxruntime-gpu"
        check_requirements(runtime_pkg)

        # The image path is left unset here; process_image fills it in per call.
        self.detection = YOLOv8(model_path, None, conf_thres, iou_thres)

    def process_image(self, img_path):
        """
        Runs detection on one image.

        Args:
            img_path: Path to the image to analyse.

        Returns:
            tuple: ``(output_image, rectangles)`` as produced by ``YOLOv8.main``.
        """
        print("207:img_path", img_path)
        detector = self.detection
        detector.input_image = img_path

        annotated, rects = detector.main()
        return annotated, rects
文件二：推理测试主程序（调用上面的 YOLOv8Inference，点击查看代码）
# Ultralytics YOLO 🚀, AGPL-3.0 license

import argparse
import time

import cv2
import datetime
from examples.YOLOV8 import YOLOv8Inference




if __name__ == "__main__":
    import os

    # Default sample image used when --img is not given on the command line
    img_path = "D:/work/wugong/ultralytics-main/TESTONNX/INPUTimg/11.png"
    # Timestamp used as the output file name
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create an argument parser to handle command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str,
                        default="./runs/detect/train12/weights/best.onnx",
                        help="Input your ONNX model.")
    parser.add_argument("--img", type=str, default=img_path, help="Path to input image.")
    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
    parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
    parser.add_argument("--out-dir", type=str, default="./TESTONNX/OUTPUTimg",
                        help="Directory the annotated image is written to (created if missing).")
    parser.add_argument("--show", action="store_true",
                        help="Also display the annotated image in a window.")
    args = parser.parse_args()

    # Create an instance of YOLOv8Inference with the specified arguments
    yolo_inference = YOLOv8Inference(args.model, args.conf_thres, args.iou_thres)

    # perf_counter is monotonic and high-resolution, which suits interval timing
    start = time.perf_counter()
    output_image, rectangles = yolo_inference.process_image(args.img)
    end = time.perf_counter()
    print("38rectangles:", rectangles)
    print(f"total cost time:{end - start}")

    # cv2.imwrite signals failure (e.g. a missing directory) only through its return
    # value, so create the directory first and check the result explicitly.
    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, f"{current_time}.png")
    if not cv2.imwrite(out_path, output_image):
        raise SystemExit(f"Failed to write output image to {out_path}")

    if args.show:
        # Display the output image and wait for a key press before closing
        cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
        cv2.imshow("Output", output_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
posted @   little_cute  阅读(99)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· 如何使用 Uni-app 实现视频聊天(源码,支持安卓、iOS)
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
点击右上角即可分享
微信分享提示