Yolov8 Source Code Analysis (43)

.\yolov8\ultralytics\utils\patches.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Monkey patches to update/extend functionality of existing functions."""

import time
from pathlib import Path

import cv2  # import the OpenCV library
import numpy as np  # import the NumPy library
import torch  # import the PyTorch library

# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
_imshow = cv2.imshow  # keep a reference to cv2.imshow to avoid recursion errors


def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
    """
    Read an image from a file.

    Args:
        filename (str): Path to the file to read.
        flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.

    Returns:
        (np.ndarray): The read image.
    """
    return cv2.imdecode(np.fromfile(filename, np.uint8), flags)  # read the file as raw bytes and decode, so non-ASCII paths work


def imwrite(filename: str, img: np.ndarray, params=None):
    """
    Write an image to a file.

    Args:
        filename (str): Path to the file to write.
        img (np.ndarray): Image to write.
        params (list of ints, optional): Additional parameters. See OpenCV documentation.

    Returns:
        (bool): True if the file was written, False otherwise.
    """
    try:
        cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)  # encode the image and write the raw bytes to file
        return True
    except Exception:
        return False


def imshow(winname: str, mat: np.ndarray):
    """
    Displays an image in the specified window.

    Args:
        winname (str): Name of the window.
        mat (np.ndarray): Image to be shown.
    """
    _imshow(winname.encode("unicode_escape").decode(), mat)  # display the image, escaping non-ASCII characters in the window name
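
The three wrappers above exist so that OpenCV file I/O and window display also work with non-ASCII (for example Chinese) paths and window names. A minimal usage sketch, with hypothetical file names:

```python
from ultralytics.utils.patches import imread, imwrite, imshow

img = imread("测试图片.jpg")       # plain cv2.imread can fail on non-ASCII paths on some platforms
ok = imwrite("输出结果.png", img)  # encodes with cv2.imencode, then writes the raw bytes with tofile
imshow("预览窗口", img)            # the window name is escaped before being passed to cv2.imshow
```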


# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_load = torch.load  # keep references to torch.load / torch.save to avoid recursion errors
_torch_save = torch.save


def torch_load(*args, **kwargs):
    """
    Load a PyTorch model with updated arguments to avoid warnings.

    This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.

    Args:
        *args (Any): Variable length argument list to pass to torch.load.
        **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.

    Returns:
        (Any): The loaded PyTorch object.

    Note:
        For PyTorch versions 1.13 and above, this function automatically sets 'weights_only=False'
        if the argument is not provided, to avoid deprecation warnings.
    """
    from ultralytics.utils.torch_utils import TORCH_1_13  # flag for whether the installed PyTorch is >= 1.13

    if TORCH_1_13 and "weights_only" not in kwargs:
        kwargs["weights_only"] = False  # on PyTorch 1.13+, default weights_only to False when not specified

    return _torch_load(*args, **kwargs)  # delegate to the original torch.load
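
A quick illustration of the wrapper (the checkpoint path is hypothetical):

```python
from ultralytics.utils.patches import torch_load

# On torch >= 1.13 this is equivalent to torch.load("model.pt", map_location="cpu", weights_only=False)
ckpt = torch_load("model.pt", map_location="cpu")
```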


def torch_save(*args, use_dill=True, **kwargs):
    """
    Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
    exponential standoff in case of save failure.

    Args:
        *args (tuple): Positional arguments to pass to torch.save.
        use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
        **kwargs (Any): Keyword arguments to pass to torch.save.
    """
    # Try the dill serialization library if available (it can pickle lambdas that the stdlib pickle cannot), else use pickle
    try:
        assert use_dill
        import dill as pickle
    except (AssertionError, ImportError):
        import pickle

    # Default to the chosen serializer if the caller did not pass a pickle_module
    if "pickle_module" not in kwargs:
        kwargs["pickle_module"] = pickle

    # Attempt the save up to 4 times (initial try plus 3 retries) to handle transient failures
    for i in range(4):  # 3 retries
        try:
            # Delegate to the original torch.save
            return _torch_save(*args, **kwargs)
        except RuntimeError as e:  # unable to save, possibly waiting for device to flush or antivirus scan
            # Re-raise the original RuntimeError on the final attempt
            if i == 3:
                raise e
            # Back off exponentially before retrying (device flush, antivirus scan, etc.)
            time.sleep((2**i) / 2)  # exponential standoff: 0.5s, 1.0s, 2.0s
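
A short sketch of what the dill fallback enables; serializing the lambda below only works if dill is actually installed (an assumption here), otherwise the stdlib pickle is used:

```python
from ultralytics.utils.patches import torch_save

state = {"epoch": 3, "lr_fn": lambda x: 0.9 ** x}  # a lambda that the stdlib pickle cannot serialize
torch_save(state, "checkpoint.pt")  # uses dill if available, and retries on transient RuntimeError
```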

.\yolov8\ultralytics\utils\plotting.py

# Import required libraries
import contextlib  # context-manager utilities
import math  # math functions
import warnings  # warning handling
from pathlib import Path  # filesystem path handling
from typing import Callable, Dict, List, Optional, Union  # type hints

import cv2  # OpenCV image processing library
import matplotlib.pyplot as plt  # matplotlib plotting interface
import numpy as np  # numerical computing library
import torch  # PyTorch deep learning framework
from PIL import Image, ImageDraw, ImageFont  # Python Imaging Library (Pillow)

from PIL import __version__ as pil_version  # Pillow version string

from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded  # internal utilities
from ultralytics.utils.checks import check_font, check_version, is_ascii  # validation helpers
from ultralytics.utils.files import increment_path  # path incrementing helper

# Color class containing the Ultralytics default palette and conversion helpers
class Colors:
    """
    Ultralytics default color palette https://ultralytics.com/.

    This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
    RGB values.

    Attributes:
        palette (list of tuple): List of RGB color values.
        n (int): The number of colors in the palette.
        pose_palette (np.ndarray): A specific color palette array with dtype np.uint8.
    """

    def __init__(self):
        """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
        hexs = (
            "042AFF",
            "0BDBEB",
            "F3F3F3",
            "00DFB7",
            "111F68",
            "FF6FDD",
            "FF444F",
            "CCED00",
            "00F344",
            "BD00FF",
            "00B4FF",
            "DD00BA",
            "00FFFF",
            "26C000",
            "01FFB3",
            "7D24FF",
            "7B0068",
            "FF1B6C",
            "FC6D2F",
            "A2FF0B",
        )
        # Build the color palette by converting the hex codes above to RGB tuples
        self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
        self.n = len(self.palette)
        # Predefined pose palette used for keypoint/limb plotting
        self.pose_palette = np.array(
            [
                [255, 128, 0],
                [255, 153, 51],
                [255, 178, 102],
                [230, 230, 0],
                [255, 153, 255],
                [153, 204, 255],
                [255, 102, 255],
                [255, 51, 255],
                [102, 178, 255],
                [51, 153, 255],
                [255, 153, 153],
                [255, 102, 102],
                [255, 51, 51],
                [153, 255, 153],
                [102, 255, 102],
                [51, 255, 51],
                [0, 255, 0],
                [0, 0, 255],
                [255, 0, 0],
                [255, 255, 255],
            ],
            dtype=np.uint8,
        )

    def __call__(self, i, bgr=False):
        """Converts hex color codes to RGB values."""
        # 返回调色板中第i个颜色的RGB值,支持BGR格式
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):
        """Converts hex color codes to RGB values (i.e. default PIL order)."""
        # Convert a '#RRGGBB' hex string to an (R, G, B) tuple (PIL's default channel order)
        return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()  # shared palette instance used for plotting colors
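
A small sanity check of the palette helpers, using the module-level colors instance defined above (values follow directly from the hex codes listed in __init__):

```python
print(colors.hex2rgb("#FF6FDD"))  # (255, 111, 221)
print(colors(0))                  # first palette color from hex "042AFF": (4, 42, 255)
print(colors(0, bgr=True))        # same color with channels reversed for OpenCV: (255, 42, 4)
```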

class Annotator:
    """
    Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations.
    """
    # 定义类属性,用于图像注释
    Attributes:
        # im是要注释的图像,可以是PIL图像(Image.Image)或者numpy数组
        im (Image.Image or numpy array): The image to annotate.
        # pil标志指示是否使用PIL库进行注释,而不是cv2
        pil (bool): Whether to use PIL or cv2 for drawing annotations.
        # font用于文本注释的字体,可以是ImageFont.truetype或ImageFont.load_default
        font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations.
        # lw是用于绘制注释的线条宽度
        lw (float): Line width for drawing.
        # skeleton是关键点的骨架结构的列表,其中每个元素是一个列表,表示连接的两个关键点的索引
        skeleton (List[List[int]]): Skeleton structure for keypoints.
        # limb_color是绘制骨架连接的颜色调色板,以RGB整数列表形式表示
        limb_color (List[int]): Color palette for limbs.
        # kpt_color是绘制关键点的颜色调色板,以RGB整数列表形式表示
        kpt_color (List[int]): Color palette for keypoints.
    """
    # Initialize the Annotator with an image, line width, font size, font name, a PIL flag and an example label
    def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"):
        """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
        # Use PIL if the example label contains non-ASCII characters
        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
        # Check whether the input image is already a PIL Image
        input_is_pil = isinstance(im, Image.Image)
        # Decide whether to draw with PIL or cv2
        self.pil = pil or non_ascii or input_is_pil
        # Line width: default to half the sum of the image dimensions times 0.003, with a minimum of 2
        self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)

        if self.pil:  # use PIL
            # Use the image directly if it is a PIL Image, otherwise convert the numpy array
            self.im = im if input_is_pil else Image.fromarray(im)
            # Create an ImageDraw object for drawing
            self.draw = ImageDraw.Draw(self.im)
            try:
                # Pick a Unicode font for non-ASCII labels, otherwise the requested (Latin) font
                font = check_font("Arial.Unicode.ttf" if non_ascii else font)
                # Font size: default to half the sum of the image dimensions times 0.035, with a minimum of 12
                size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
                # Load the chosen font at the computed size
                self.font = ImageFont.truetype(str(font), size)
            except Exception:
                # Fall back to the default font if loading fails
                self.font = ImageFont.load_default()
            # For Pillow >= 9.2.0, emulate the removed getsize() via getbbox() (returns text width, height)
            if check_version(pil_version, "9.2.0"):
                self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, height
        else:  # use cv2
            # The image must be contiguous, otherwise ask the caller to apply np.ascontiguousarray
            assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
            # Copy the image if it is not writeable
            self.im = im if im.flags.writeable else im.copy()
            # Font thickness: line width minus 1, with a minimum of 1
            self.tf = max(self.lw - 1, 1)  # font thickness
            # Font scale: one third of the line width
            self.sf = self.lw / 3  # font scale
        
        # Keypoint connection pairs (skeleton) for human pose
        self.skeleton = [
            [16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12],
            [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3],
            [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7],
        ]

        # Colors for the skeleton (limb) connections
        self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
        # Colors for the keypoints themselves
        self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]

        # Background colors that should be labelled with dark text
        self.dark_colors = {
            (235, 219, 11), (243, 243, 243), (183, 223, 0), (221, 111, 255),
            (0, 237, 204), (68, 243, 0), (255, 255, 0), (179, 255, 1),
            (11, 255, 162),
        }
        # Background colors that should be labelled with light (white) text
        self.light_colors = {
            (255, 42, 4), (79, 68, 255), (255, 0, 189), (255, 180, 0),
            (186, 0, 221), (0, 192, 38), (255, 36, 125), (104, 0, 123),
            (108, 27, 255), (47, 109, 252), (104, 31, 17),
        }
    def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)):
        """Assign text color based on background color."""
        # Backgrounds listed in dark_colors get a fixed dark text color
        if color in self.dark_colors:
            # Predefined dark text color for bright backgrounds
            return 104, 31, 17
        elif color in self.light_colors:
            # White text for backgrounds listed in light_colors
            return 255, 255, 255
        else:
            # Otherwise keep the caller-supplied text color
            return txt_color

    def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2):
        """
        Draws a label with a background circle centered within a given bounding box.

        Args:
            box (tuple): The bounding box coordinates (x1, y1, x2, y2).
            label (str): The text label to be displayed.
            color (tuple, optional): The background color of the rectangle (R, G, B).
            txt_color (tuple, optional): The color of the text (R, G, B).
            margin (int, optional): The margin between the text and the rectangle border.
        """

        # If the label is longer than 3 characters, warn and keep only the first three for the circle annotation
        if len(label) > 3:
            print(
                f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!"
            )
            label = label[:3]

        # Center of the bounding box
        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
        # Size of the rendered text
        text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]
        # Radius required to fit the text plus the margin
        required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin
        # Draw the filled circle
        cv2.circle(self.im, (x_center, y_center), required_radius, color, -1)
        # Position the text so that it is centered in the circle
        text_x = x_center - text_size[0] // 2
        text_y = y_center + text_size[1] // 2
        # Draw the text
        cv2.putText(
            self.im,
            str(label),
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            self.sf - 0.15,
            # Text color chosen automatically from the background color
            self.get_txt_color(color, txt_color),
            self.tf,
            lineType=cv2.LINE_AA,
        )
    def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5):
        """
        Draws a label with a background rectangle centered within a given bounding box.

        Args:
            box (tuple): The bounding box coordinates (x1, y1, x2, y2).
            label (str): The text label to be displayed.
            color (tuple, optional): The background color of the rectangle (R, G, B).
            txt_color (tuple, optional): The color of the text (R, G, B).
            margin (int, optional): The margin between the text and the rectangle border.
        """

        # Calculate the center of the bounding box
        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
        # Get the size of the text
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0]
        # Calculate the top-left corner of the text (to center it)
        text_x = x_center - text_size[0] // 2
        text_y = y_center + text_size[1] // 2
        # Calculate the coordinates of the background rectangle
        rect_x1 = text_x - margin
        rect_y1 = text_y - text_size[1] - margin
        rect_x2 = text_x + text_size[0] + margin
        rect_y2 = text_y + margin
        # Draw the background rectangle
        cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1)
        # Draw the text on top of the rectangle
        cv2.putText(
            self.im,  # target image to draw on
            label,  # text to draw
            (text_x, text_y),  # bottom-left corner of the text
            cv2.FONT_HERSHEY_SIMPLEX,  # font face
            self.sf - 0.1,  # font scale factor
            self.get_txt_color(color, txt_color),  # text color
            self.tf,  # text thickness
            lineType=cv2.LINE_AA,  # anti-aliased line type
        )
    def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
        """
        Plot masks on image.

        Args:
            masks (tensor): Predicted masks on cuda, shape: [n, h, w]
            colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n]
            im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1]
            alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque
            retina_masks (bool): Whether to use high resolution masks or not. Defaults to False.
        """

        # If drawing with PIL, convert to a numpy array first
        if self.pil:
            self.im = np.asarray(self.im).copy()

        # If there are no masks, just copy the (de-normalised) image into self.im
        if len(masks) == 0:
            self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255

        # Move the image to the masks' device if they differ
        if im_gpu.device != masks.device:
            im_gpu = im_gpu.to(masks.device)

        # Convert colors to a float tensor normalised to [0, 1]
        colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3)

        # Add singleton dimensions so the colors broadcast over each mask
        colors = colors[:, None, None]  # shape(n,1,1,3)

        # Add a trailing channel dimension to the masks
        masks = masks.unsqueeze(3)  # shape(n,h,w,1)

        # Colorize the masks and scale by alpha for transparency
        masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

        # Cumulative inverse-alpha masks used to blend the original image with masks_color
        inv_alpha_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)

        # Per-pixel maximum over all colored masks
        mcs = masks_color.max(dim=0).values  # shape(h,w,3)

        # Flip channel order from RGB to BGR
        im_gpu = im_gpu.flip(dims=[0])

        # Rearrange from (3,h,w) to (h,w,3)
        im_gpu = im_gpu.permute(1, 2, 0).contiguous()

        # Blend the image with the colored masks
        im_gpu = im_gpu * inv_alpha_masks[-1] + mcs

        # Scale back to 0-255 and move to a numpy array
        im_mask = im_gpu * 255
        im_mask_np = im_mask.byte().cpu().numpy()

        # Either use the high-resolution masks directly or rescale to the annotator image shape
        self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)

        # If drawing with PIL, convert the numpy result back to a PIL image and refresh the draw handle
        if self.pil:
            self.fromarray(self.im)
    def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25):
        """
        Plot keypoints on the image.

        Args:
            kpts (tensor): Predicted keypoints with shape [17, 3]. Each keypoint has (x, y, confidence).
            shape (tuple): Image shape as a tuple (h, w), where h is the height and w is the width.
            radius (int, optional): Radius of the drawn keypoints. Default is 5.
            kpt_line (bool, optional): If True, the function will draw lines connecting keypoints
                                       for human pose. Default is True.
            conf_thres (float, optional): Confidence threshold for drawing keypoints. Default is 0.25.

        Note:
            `kpt_line=True` currently only supports human pose plotting.
        """

        if self.pil:
            # If working with PIL image, convert to numpy array for processing
            self.im = np.asarray(self.im).copy()  # Convert PIL image to numpy array
        
        # Get the number of keypoints and dimensions from the input tensor
        nkpt, ndim = kpts.shape
        # Check if the keypoints represent a human pose (17 keypoints with 2 or 3 dimensions)
        is_pose = nkpt == 17 and ndim in {2, 3}
        # Adjust kpt_line based on whether it's a valid human pose and the argument value
        kpt_line &= is_pose  # `kpt_line=True` for now only supports human pose plotting

        # Loop through each keypoint and plot a circle on the image
        for i, k in enumerate(kpts):
            # Determine color for the keypoint based on whether it's a pose or not
            color_k = [int(x) for x in self.kpt_color[i]] if is_pose else colors(i)
            x_coord, y_coord = k[0], k[1]
            # Check if the keypoint coordinates are within image boundaries
            if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
                # If confidence score is provided (3 dimensions), skip keypoints below threshold
                if len(k) == 3:
                    conf = k[2]
                    if conf < conf_thres:
                        continue
                # Draw a circle on the image at the keypoint location
                cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)

        # If kpt_line is True, draw lines connecting keypoints (for human pose)
        if kpt_line:
            ndim = kpts.shape[-1]
            # Iterate over predefined skeleton connections and draw lines between keypoints
            for i, sk in enumerate(self.skeleton):
                pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
                pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
                # If confidence scores are provided, skip lines for keypoints below threshold
                if ndim == 3:
                    conf1 = kpts[(sk[0] - 1), 2]
                    conf2 = kpts[(sk[1] - 1), 2]
                    if conf1 < conf_thres or conf2 < conf_thres:
                        continue
                # Check if keypoints' positions are within image boundaries before drawing lines
                if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
                    continue
                if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
                    continue
                # Draw a line connecting two keypoints on the image
                cv2.line(self.im, pos1, pos2, [int(x) for x in self.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)

        if self.pil:
            # Convert numpy array (image) back to PIL image format and update self.im
            self.fromarray(self.im)  # Convert numpy array back to PIL image

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Add rectangle to image (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)
    def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False):
        """Adds text to an image using PIL or cv2."""
        # If the anchor is "bottom", start the y coordinate from the font bottom
        if anchor == "bottom":  # start y from font bottom
            w, h = self.font.getsize(text)  # text width and height
            xy[1] += 1 - h
        if self.pil:
            # Box-style text: draw a filled background rectangle first
            if box_style:
                w, h = self.font.getsize(text)  # text width and height
                # Background rectangle filled with txt_color
                self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
                # Then draw the text itself in white on top of the filled background
                txt_color = (255, 255, 255)
            # Multi-line text: draw each line and advance the y coordinate
            if "\n" in text:
                lines = text.split("\n")  # split into individual lines
                _, h = self.font.getsize(text)  # line height
                for line in lines:
                    self.draw.text(xy, line, fill=txt_color, font=self.font)  # draw one line
                    xy[1] += h  # move down to the next line
            else:
                self.draw.text(xy, text, fill=txt_color, font=self.font)  # draw a single line
        else:
            # Box-style text with cv2
            if box_style:
                w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]  # text width and height
                h += 3  # add pixels to pad the text
                outside = xy[1] >= h  # whether the label fits above the anchor point
                p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
                cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA)  # filled background rectangle
                # Then draw the text itself in white on top of the filled background
                txt_color = (255, 255, 255)
            cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)  # draw the text with cv2

    def fromarray(self, im):
        """Update self.im from a numpy array."""
        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)  # wrap numpy arrays in a PIL Image
        self.draw = ImageDraw.Draw(self.im)  # recreate the PIL drawing handle

    def result(self):
        """Return annotated image as array."""
        return np.asarray(self.im)  # convert the PIL image back to a numpy array

    def show(self, title=None):
        """Show the annotated image."""
        Image.fromarray(np.asarray(self.im)[..., ::-1]).show(title)  # reverse BGR to RGB before displaying with PIL

    def save(self, filename="image.jpg"):
        """Save the annotated image to 'filename'."""
        cv2.imwrite(filename, np.asarray(self.im))  # write the numpy array to disk with OpenCV

    def get_bbox_dimension(self, bbox=None):
        """
        Calculate the dimensions and area of a bounding box.

        Args:
            bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).

        Returns:
            (tuple): Width, height and area of the bounding box.
        """
        x_min, y_min, x_max, y_max = bbox  # unpack the bounding box coordinates
        width = x_max - x_min  # box width
        height = y_max - y_min  # box height
        return width, height, width * height  # width, height and area
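
A tiny worked example of the return values, using the Annotator class defined above (the blank image is only there to construct an instance):

```python
import numpy as np

ann = Annotator(np.zeros((100, 100, 3), dtype=np.uint8))
w, h, area = ann.get_bbox_dimension((10, 20, 60, 80))
# w = 50, h = 60, area = 3000
```
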
    def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
        """
        Draw region line.

        Args:
            reg_pts (list): Region Points (for line 2 points, for region 4 points)
            color (tuple): Region Color value
            thickness (int): Region area thickness value
        """

        # Draw the region as a closed polyline through reg_pts
        cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)

    def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
        """
        Draw centroid point and track trails.

        Args:
            track (list): object tracking points for trails display
            color (tuple): tracks line color
            track_thickness (int): track line thickness value
        """

        # Connect the track points into a polyline and draw it on the image
        points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)

        # Draw a filled circle at the last track point to mark the object's current position
        cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)

    def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)):
        """
        Displays queue counts on an image centered at the points with customizable font size and colors.

        Args:
            label (str): queue counts label
            points (tuple): region points for center point calculation to display text
            region_color (RGB): queue region color
            txt_color (RGB): text display color
        """

        # Compute the center of the region from its points
        x_values = [point[0] for point in points]
        y_values = [point[1] for point in points]
        center_x = sum(x_values) // len(points)
        center_y = sum(y_values) // len(points)

        # Measure the label text
        text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
        text_width = text_size[0]
        text_height = text_size[1]

        rect_width = text_width + 20
        rect_height = text_height + 20
        rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
        rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)

        # Draw a filled background rectangle centered on the region
        cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)

        text_x = center_x - text_width // 2
        text_y = center_y + text_height // 2

        # Draw the label text at the computed position
        cv2.putText(
            self.im,
            label,
            (text_x, text_y),
            0,
            fontScale=self.sf,
            color=txt_color,
            thickness=self.tf,
            lineType=cv2.LINE_AA,
        )
    def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin):
        """
        Display the bounding boxes labels in parking management app.

        Args:
            im0 (ndarray): inference image
            text (str): object/class name
            txt_color (bgr color): display color for text foreground
            bg_color (bgr color): display color for text background
            x_center (float): x position center point for bounding box
            y_center (float): y position center point for bounding box
            margin (int): gap between text and rectangle for better display
        """

        # Calculate the size of the text to be displayed
        text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
        # Calculate the x and y coordinates for placing the text centered at (x_center, y_center)
        text_x = x_center - text_size[0] // 2
        text_y = y_center + text_size[1] // 2

        # Calculate the coordinates of the rectangle surrounding the text
        rect_x1 = text_x - margin
        rect_y1 = text_y - text_size[1] - margin
        rect_x2 = text_x + text_size[0] + margin
        rect_y2 = text_y + margin
        # Draw a filled rectangle with specified background color around the text
        cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
        # Draw the text on the image at (text_x, text_y)
        cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)

    def display_analytics(self, im0, text, txt_color, bg_color, margin):
        """
        Display the overall statistics for parking lots.

        Args:
            im0 (ndarray): inference image
            text (dict): labels dictionary
            txt_color (bgr color): display color for text foreground
            bg_color (bgr color): display color for text background
            margin (int): gap between text and rectangle for better display
        """

        # Calculate horizontal and vertical gaps based on image dimensions
        horizontal_gap = int(im0.shape[1] * 0.02)
        vertical_gap = int(im0.shape[0] * 0.01)
        text_y_offset = 0  # Initialize offset for vertical placement of text

        # Iterate through each label and value pair in the provided dictionary
        for label, value in text.items():
            txt = f"{label}: {value}"  # Format the label and value into a string
            # Calculate the size of the text to be displayed
            text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
            # Ensure minimum size for text dimensions to avoid errors
            if text_size[0] < 5 or text_size[1] < 5:
                text_size = (5, 5)
            # Calculate the x and y coordinates for placing the text on the image
            text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
            text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
            # Calculate the coordinates of the rectangle surrounding the text
            rect_x1 = text_x - margin * 2
            rect_y1 = text_y - text_size[1] - margin * 2
            rect_x2 = text_x + text_size[0] + margin * 2
            rect_y2 = text_y + margin * 2
            # Draw a filled rectangle with specified background color around the text
            cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
            # Draw the text on the image at (text_x, text_y)
            cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
            # Update the vertical offset for placing the next text block
            text_y_offset = rect_y2

    @staticmethod
    def estimate_pose_angle(a, b, c):
        """
        Calculate the pose angle between three points.

        Args:
            a (tuple): The coordinates of pose point a
            b (tuple): The coordinates of pose point b
            c (tuple): The coordinates of pose point c

        Returns:
            angle (degree): Degree value of the angle between the points
        """

        # Convert input points to numpy arrays for calculations
        a, b, c = np.array(a), np.array(b), np.array(c)

        # Calculate the angle using arctangent and convert from radians to degrees
        radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
        angle = np.abs(radians * 180.0 / np.pi)

        # Normalize angle to be within [0, 180] degrees
        if angle > 180.0:
            angle = 360 - angle

        return angle
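
A worked example: with b as the vertex, a on the positive x-axis and c on the positive y-axis, the method returns 90 degrees.

```python
angle = Annotator.estimate_pose_angle((10, 0), (0, 0), (0, 10))
# atan2(10, 0) - atan2(0, 10) = pi/2 - 0, so angle == 90.0
```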

    def draw_specific_points(self, keypoints, indices=None, shape=(640, 640), radius=2, conf_thres=0.25):
        """
        Draw specific keypoints on an image.

        Args:
            keypoints (list): List of keypoints to be plotted
            indices (list): Indices of keypoints to be plotted
            shape (tuple): Size of the image (width, height)
            radius (int): Radius of the keypoints
            conf_thres (float): Confidence threshold for keypoints
        """

        # If indices are not provided, default to drawing keypoints 2, 5, and 7
        if indices is None:
            indices = [2, 5, 7]

        # Iterate through keypoints and draw circles for specific indices
        for i, k in enumerate(keypoints):
            if i in indices:
                x_coord, y_coord = k[0], k[1]

                # Check if the keypoints are within image bounds
                if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
                    if len(k) == 3:
                        conf = k[2]
                        # Skip drawing if confidence is below the threshold
                        if conf < conf_thres:
                            continue

                    # Draw a circle on the image at the keypoint coordinates
                    cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, (0, 255, 0), -1, lineType=cv2.LINE_AA)

        # Return the image with keypoints drawn
        return self.im

    def plot_angle_and_count_and_stage(
        self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255)
    ):
        """
        Plot angle, count, and stage information on the image.

        Args:
            angle_text (str): Text to display for the angle
            count_text (str): Text to display for the count
            stage_text (str): Text to display for the stage
            center_kpt (tuple): Center keypoint coordinates
            color (tuple): Color of the plotted elements
            txt_color (tuple): Color of the text
        """

        # Implementation details are missing in the provided snippet.
        # The function definition is incomplete.
        pass

    def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)):
        """
        Draw a segmented object with a bounding box on the image.

        Args:
            mask (list): List of mask data points for the segmented object
            mask_color (RGB): Color for the mask
            label (str): Text label for the detection
            txt_color (RGB): Text color
        """

        # Draw the polygonal lines around the mask region
        cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)

        # Calculate text size for label and draw a rectangle around the label
        text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf)
        cv2.rectangle(
            self.im,
            (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
            (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)),
            mask_color,
            -1,
        )

        # Draw the label text on the image
        if label:
            cv2.putText(
                self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf
            )
    def plot_distance_and_line(self, distance_m, distance_mm, centroids, line_color, centroid_color):
        """
        Plot the distance and line on frame.

        Args:
            distance_m (float): Distance between two bbox centroids in meters.
            distance_mm (float): Distance between two bbox centroids in millimeters.
            centroids (list): Bounding box centroids data.
            line_color (RGB): Distance line color.
            centroid_color (RGB): Bounding box centroid color.
        """

        # 计算 "Distance M" 文本的宽度和高度
        (text_width_m, text_height_m), _ = cv2.getTextSize(f"Distance M: {distance_m:.2f}m", 0, self.sf, self.tf)
        # 绘制包围 "Distance M" 文本的矩形框
        cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 10, 25 + text_height_m + 20), line_color, -1)
        # 在图像中绘制 "Distance M" 文本
        cv2.putText(
            self.im,
            f"Distance M: {distance_m:.2f}m",
            (20, 50),
            0,
            self.sf,
            centroid_color,
            self.tf,
            cv2.LINE_AA,
        )

        # 计算 "Distance MM" 文本的宽度和高度
        (text_width_mm, text_height_mm), _ = cv2.getTextSize(f"Distance MM: {distance_mm:.2f}mm", 0, self.sf, self.tf)
        # 绘制包围 "Distance MM" 文本的矩形框
        cv2.rectangle(self.im, (15, 75), (15 + text_width_mm + 10, 75 + text_height_mm + 20), line_color, -1)
        # 在图像中绘制 "Distance MM" 文本
        cv2.putText(
            self.im,
            f"Distance MM: {distance_mm:.2f}mm",
            (20, 100),
            0,
            self.sf,
            centroid_color,
            self.tf,
            cv2.LINE_AA,
        )

        # 在图像中绘制两个中心点之间的直线
        cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
        # 在图像中绘制第一个中心点
        cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
        # 在图像中绘制第二个中心点
        cv2.circle(self.im, centroids[1], 6, centroid_color, -1)

    def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)):
        """
        Function for pinpoint human-vision eye mapping and plotting.

        Args:
            box (list): Bounding box coordinates
            center_point (tuple): center point for vision eye view
            color (tuple): object centroid and line color value
            pin_color (tuple): visioneye point color value
        """

        # Center of the bounding box
        center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
        # Draw the visioneye (viewer) point
        cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
        # Draw the bounding box center point
        cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
        # Draw the line connecting the visioneye point to the bounding box center
        cv2.line(self.im, center_point, center_bbox, color, self.tf)
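
Putting a few of the methods above together, a minimal sketch on a blank BGR image (all coordinates, labels and colors are illustrative):

```python
import numpy as np

frame = np.full((480, 640, 3), 255, dtype=np.uint8)  # blank white image (OpenCV-style numpy array)
ann = Annotator(frame, line_width=2)  # numpy input with an ASCII example -> cv2 drawing path
ann.text_label((100, 100, 220, 180), label="car", color=(255, 42, 4))  # centered label with background
ann.circle_label((300, 200, 380, 260), label="1", color=(0, 192, 38))  # circled label
ann.draw_region([(50, 400), (590, 400), (590, 460), (50, 460)])  # closed polygon region
annotated = ann.result()  # annotated image back as a numpy array
```
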
@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings()  # apply the Ultralytics matplotlib settings for this plot
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
    """Plot training labels including class histograms and box statistics."""
    import pandas  # data handling for the box statistics
    import seaborn  # statistical plotting

    # Filter matplotlib>=3.7.2 warnings and Seaborn use_inf / is_categorical FutureWarnings
    warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Plot dataset labels
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
    nc = int(cls.max() + 1)  # number of classes
    boxes = boxes[:1000000]  # limit to 1M boxes
    x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])  # DataFrame of box coordinates

    # Seaborn correlogram
    seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)  # save the correlogram
    plt.close()

    # Matplotlib labels
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)  # class histogram
    for i in range(nc):
        y[2].patches[i].set_color([x / 255 for x in colors(i)])  # color each histogram bar
    ax[0].set_ylabel("instances")  # y-axis label
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)  # class names as x tick labels
    else:
        ax[0].set_xlabel("classes")  # generic x-axis label

    seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)  # x/y center distribution
    seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)  # width/height distribution

    # Rectangles
    boxes[:, 0:2] = 0.5  # move all boxes to the image center
    boxes = ops.xywh2xyxy(boxes) * 1000  # convert to absolute xyxy coordinates on a 1000x1000 canvas
    img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)  # blank white image
    for cls, box in zip(cls[:500], boxes[:500]):
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # draw each box outline
    ax[1].imshow(img)  # show the canvas with the box outlines
    ax[1].axis("off")  # hide the axes

    for a in [0, 1, 2, 3]:
        for s in ["top", "right", "left", "bottom"]:
            ax[a].spines[s].set_visible(False)  # hide the subplot frames

    fname = save_dir / "labels.jpg"
    plt.savefig(fname, dpi=200)  # save the final label plot
    plt.close()  # close the figure
    if on_plot:
        on_plot(fname)  # invoke the callback with the saved filename, if provided
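
A hypothetical call just to illustrate the expected inputs: normalized xywh boxes, integer class indices, a class-name dictionary and an existing save directory.

```python
import numpy as np

save_dir = Path("runs/labels_demo")
save_dir.mkdir(parents=True, exist_ok=True)  # plot_labels writes its figures into this directory
boxes = np.random.rand(200, 4).astype(np.float32)  # normalized (x, y, width, height) boxes
cls = np.random.randint(0, 3, 200)  # one class index per box
plot_labels(boxes, cls, names={0: "person", 1: "car", 2: "dog"}, save_dir=save_dir)
```
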
# Crop the input image to the given xyxy box (with optional gain/pad/square adjustments), optionally save it, and return it
def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True):
    """
    Save an image crop defined by a bounding box, with optional gain, padding and squaring, and return the crop.

    Args:
        xyxy (torch.Tensor or list): Bounding box in xyxy format.
        im (numpy.ndarray): Input image.
        file (Path, optional): Path where the cropped image is saved. Defaults to 'im.jpg'.
        gain (float, optional): Multiplicative gain applied to the box size. Defaults to 1.02.
        pad (int, optional): Pixels added to the box width and height. Defaults to 10.
        square (bool, optional): If True, the box is turned into a square. Defaults to False.
        BGR (bool, optional): If True, the image is saved in BGR format, otherwise in RGB. Defaults to False.
        save (bool, optional): If True, the cropped image is saved to disk. Defaults to True.

    Returns:
        (numpy.ndarray): The cropped image.

    Example:
        ```python
        from ultralytics.utils.plotting import save_one_box

        xyxy = [50, 50, 150, 150]
        im = cv2.imread('image.jpg')
        cropped_im = save_one_box(xyxy, im, file='cropped.jpg', square=True)
        ```
    """

    if not isinstance(xyxy, torch.Tensor):  # may be a list
        xyxy = torch.stack(xyxy)  # convert to a torch.Tensor

    b = ops.xyxy2xywh(xyxy.view(-1, 4))  # convert the xyxy box to xywh format
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # expand the rectangle to a square

    b[:, 2:] = b[:, 2:] * gain + pad  # scale width/height by gain and add pad pixels
    xyxy = ops.xywh2xyxy(b).long()  # back to integer xyxy coordinates
    xyxy = ops.clip_boxes(xyxy, im.shape)  # clip the box to the image bounds
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]  # crop the image

    if save:
        file.parent.mkdir(parents=True, exist_ok=True)  # make the destination directory
        f = str(increment_path(file).with_suffix(".jpg"))  # incremented filename with a .jpg suffix
        # cv2.imwrite(f, crop)  # saving as BGR has chroma subsampling issues
        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save as RGB

    return crop  # return the cropped image

# Decorated as a threaded function so plotting can run without blocking the caller
@threaded
# Plot an image grid with labels, bounding boxes, masks and keypoints
def plot_images(
    # Batch of images, torch.Tensor or np.ndarray of shape (batch_size, channels, height, width)
    images: Union[torch.Tensor, np.ndarray],
    # Batch index of each detection, shape (num_detections,)
    batch_idx: Union[torch.Tensor, np.ndarray],
    # Class label of each detection, shape (num_detections,)
    cls: Union[torch.Tensor, np.ndarray],
    # Bounding boxes, shape (num_detections, 4) or (num_detections, 5) for rotated boxes
    bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32),
    # Confidence score of each detection, shape (num_detections,)
    confs: Optional[Union[torch.Tensor, np.ndarray]] = None,
    # Instance segmentation masks, shape (num_detections, height, width) or (1, height, width)
    masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8),
    # Keypoints of each detection, shape (num_detections, 51)
    kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32),
    # File paths corresponding to each image in the batch
    paths: Optional[List[str]] = None,
    # Output filename for the plotted image grid
    fname: str = "images.jpg",
    # Mapping from class indices to class names
    names: Optional[Dict[int, str]] = None,
    # Optional callback invoked after the plot is saved
    on_plot: Optional[Callable] = None,
    # Maximum size of the output image grid
    max_size: int = 1920,
    # Maximum number of subplots in the image grid
    max_subplots: int = 16,
    # Whether to save the plotted image grid to a file
    save: bool = True,
    # Confidence threshold for displaying detections
    conf_thres: float = 0.25,
) -> Optional[np.ndarray]:
    """
    Plot image grid with labels, bounding boxes, masks, and keypoints.

    Args:
        images: Batch of images to plot. Shape: (batch_size, channels, height, width).
        batch_idx: Batch indices for each detection. Shape: (num_detections,).
        cls: Class labels for each detection. Shape: (num_detections,).
        bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes.
        confs: Confidence scores for each detection. Shape: (num_detections,).
        masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width).
        kpts: Keypoints for each detection. Shape: (num_detections, 51).
        paths: List of file paths for each image in the batch.
        fname: Output filename for the plotted image grid.
        names: Dictionary mapping class indices to class names.
        on_plot: Optional callback function to be called after saving the plot.
        max_size: Maximum size of the output image grid.
        max_subplots: Maximum number of subplots in the image grid.
        save: Whether to save the plotted image grid to a file.
        conf_thres: Confidence threshold for displaying detections.

    Returns:
        np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise.

    Note:
        This function supports both tensor and numpy array inputs. It will automatically
        convert tensor inputs to numpy arrays for processing.
    """

    # Convert tensor inputs to numpy arrays
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(kpts, torch.Tensor):
        kpts = kpts.cpu().numpy()
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()

    # Batch size, channels, height and width of the images
    bs, _, h, w = images.shape  # batch size, _, height, width
    # Limit the number of plotted images to max_subplots
    bs = min(bs, max_subplots)
    # Number of tiles per row/column (square grid, rounded up)
    ns = np.ceil(bs**0.5)

    # If pixel values are normalised to [0, 1], scale them back to 0-255
    if np.max(images[0]) <= 1:
        images *= 255  # de-normalise (optional)
    # Build the mosaic: start from a white canvas large enough for ns x ns tiles
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)

    # Paste each image into its tile position
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # top-left corner of this tile
        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)  # place the image in the mosaic

    # Optionally resize the mosaic so that it does not exceed max_size
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)  # resized tile height
        w = math.ceil(scale * w)  # resized tile width
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))  # resize the full mosaic

    # Annotate the mosaic
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)  # annotator for the mosaic
    if not save:
        return np.asarray(annotator.im)  # return the annotated mosaic as an array instead of saving
    annotator.im.save(fname)  # save the annotated mosaic
    if on_plot:
        on_plot(fname)  # invoke the callback with the saved filename, if provided
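
A minimal sketch that only exercises the mosaic-building path shown above (no boxes, masks or keypoints); note that the call returns immediately because of the @threaded decorator and the file is written in a background thread.

```python
imgs = torch.rand(4, 3, 320, 320)  # values in [0, 1] are de-normalised to 0-255 internally
batch_idx = torch.zeros(0)  # no detections
cls = torch.zeros(0)
plot_images(imgs, batch_idx, cls, names={0: "person"}, fname="mosaic_demo.jpg")
```
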
@plt_settings()
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
    """
    Plot training results from a results CSV file. The function supports various types of data including segmentation,
    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.

    Args:
        file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
        dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
        segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
        pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
        classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
            Defaults to None.

    Example:
        ```python
        from ultralytics.utils.plotting import plot_results

        plot_results('path/to/results.csv', segment=True)
        ```
    """
    import pandas as pd  # CSV handling
    from scipy.ndimage import gaussian_filter1d  # Gaussian smoothing of the curves

    # Directory where the plot will be saved
    save_dir = Path(file).parent if file else Path(dir)

    # Choose the subplot layout and column indices according to the task type
    if classify:
        fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True)  # classification layout
        index = [1, 4, 2, 3]  # column order for the subplots
    elif segment:
        fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)  # segmentation layout
        index = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]  # column order for the subplots
    elif pose:
        fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True)  # pose estimation layout
        index = [1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 18, 8, 9, 12, 13]  # column order for the subplots
    else:
        fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)  # default detection layout
        index = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]  # column order for the subplots

    ax = ax.ravel()  # flatten the subplot array for easy iteration

    files = list(save_dir.glob("results*.csv"))  # find all results CSV files in the directory
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."  # require at least one file

    for f in files:
        try:
            data = pd.read_csv(f)  # read the CSV data
            s = [x.strip() for x in data.columns]  # strip whitespace from the column names
            x = data.values[:, 0]  # x-axis values (usually the first, epoch column)

            # Plot each selected column as raw data plus a smoothed curve
            for i, j in enumerate(index):
                y = data.values[:, j].astype("float")  # y-axis values as floats
                # y[y == 0] = np.nan  # optionally hide zero values

                # Raw results
                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8)
                # Smoothed curve
                ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2)
                ax[i].set_title(s[j], fontsize=12)  # subplot title from the column name

                # Optionally share y axes between train and val loss subplots
                # if j in {8, 9, 10}:
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])

        except Exception as e:
            LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")  # log any plotting error and continue

    ax[1].legend()  # add a legend to the second subplot
    # Save the figure as 'results.png' next to the CSV
    fname = save_dir / "results.png"
    # Save at 200 DPI
    fig.savefig(fname, dpi=200)
    # Close the figure to free resources
    plt.close()
    # Invoke the callback with the saved filename, if provided
    if on_plot:
        on_plot(fname)

def plot_tune_results(csv_file="tune_results.csv"):
    """
    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.

    Args:
        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.

    Examples:
        >>> plot_tune_results('path/to/tune_results.csv')
    """

    import pandas as pd  # CSV handling
    from scipy.ndimage import gaussian_filter1d  # Gaussian smoothing

    def _save_one_file(file):
        """Save one matplotlib plot to 'file'."""
        plt.savefig(file, dpi=200)  # save the current figure at 200 DPI
        plt.close()  # close the figure
        LOGGER.info(f"Saved {file}")  # log the saved filename

    # Scatter plots for each hyperparameter
    csv_file = Path(csv_file)  # treat the CSV path as a Path object
    data = pd.read_csv(csv_file)  # read the tuning results
    num_metrics_columns = 1  # number of leading metric columns to skip (here only the fitness column)
    keys = [x.strip() for x in data.columns][num_metrics_columns:]  # hyperparameter column names, stripped
    x = data.values  # all values from the CSV
    fitness = x[:, 0]  # fitness column
    j = np.argmax(fitness)  # index of the best (maximum) fitness
    n = math.ceil(len(keys) ** 0.5)  # grid size (rows = columns), rounded up
    plt.figure(figsize=(10, 10), tight_layout=True)  # 10x10 inch figure with a tight layout
    for i, k in enumerate(keys):
        v = x[:, i + num_metrics_columns]  # values of the current hyperparameter
        mu = v[j]  # value of this hyperparameter at the best fitness
        plt.subplot(n, n, i + 1)  # select subplot i+1 in the n x n grid
        plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")  # scatter colored by fitness
        plt.plot(mu, fitness.max(), "k+", markersize=15)  # mark the best single result
        plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9})  # title with the parameter name and its best value
        plt.tick_params(axis="both", labelsize=8)  # smaller tick labels
        if i % n != 0:
            plt.yticks([])  # hide y ticks except in the first column of each row

    _save_one_file(csv_file.with_name("tune_scatter_plots.png"))  # save the scatter plot grid

    # Fitness vs iteration
    # x-axis values from 1 to the number of iterations
    x = range(1, len(fitness) + 1)
    # 10x6 inch figure with a tight layout
    plt.figure(figsize=(10, 6), tight_layout=True)
    # Plot the raw fitness values as unconnected markers
    plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
    # Plot a Gaussian-smoothed fitness curve
    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2)  # smoothing line
    # Title and axis labels
    plt.title("Fitness vs Iteration")
    plt.xlabel("Iteration")
    plt.ylabel("Fitness")
    # Grid and legend
    plt.grid(True)
    plt.legend()
    # Save the fitness-vs-iteration plot
    _save_one_file(csv_file.with_name("tune_fitness.png"))

def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")):
    """
    Visualize feature maps of a given model module during inference.

    Args:
        x (torch.Tensor): Features to be visualized.
        module_type (str): Module type.
        stage (int): Module stage within the model.
        n (int, optional): Maximum number of feature maps to plot. Defaults to 32.
        save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp').
    """
    # Skip model head modules; only backbone/neck feature maps are visualized
    for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}:  # all model heads
        if m in module_type:
            return

    # Only visualize tensors with spatial dimensions
    if isinstance(x, torch.Tensor):
        _, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:
            f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # output filename

            # Split the first sample of the batch into per-channel blocks
            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0 and split by channel
            n = min(n, channels)  # number of feature maps to plot, at most the channel count
            _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # ceil(n/8) rows x 8 columns of subplots
            ax = ax.ravel()
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # show one feature map, dropping the singleton channel dim
                ax[i].axis("off")  # hide the axes

            LOGGER.info(f"Saving {f}... ({n}/{channels})")  # log the save
            plt.savefig(f, dpi=300, bbox_inches="tight")  # save the figure at 300 DPI with tight bounds
            plt.close()  # close the figure
            np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy())  # also save the raw features as a .npy file
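
A hypothetical call with a random feature tensor, just to show the expected shapes; the module_type string and output directory are made up:

```python
feats = torch.rand(1, 64, 80, 80)  # (batch, channels, height, width) from some backbone stage
out_dir = Path("runs/feature_demo")
out_dir.mkdir(parents=True, exist_ok=True)  # plt.savefig and np.save need an existing directory
feature_visualization(feats, module_type="model.4.C2f", stage=4, n=16, save_dir=out_dir)
```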

.\yolov8\ultralytics\utils\tal.py

# Import PyTorch modules
import torch
import torch.nn as nn

# Import required helpers from local modules
from .checks import check_version
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy

# Check whether the installed PyTorch version is at least 1.10.0
TORCH_1_10 = check_version(torch.__version__, "1.10.0")

# Task-aligned assigner used for object detection
class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
        """Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
        super().__init__()
        self.topk = topk  # number of top candidates to keep per gt box
        self.num_classes = num_classes  # number of object classes
        self.bg_idx = num_classes  # background class index (one past the last real class)
        self.alpha = alpha  # exponent for the classification term of the alignment metric
        self.beta = beta  # exponent for the localization (IoU) term of the alignment metric
        self.eps = eps  # small value to avoid division by zero

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment. Reference code is available at
        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.

        Args:
            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
                预测得分张量,形状为(bs, num_total_anchors, num_classes)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
                预测边界框张量,形状为(bs, num_total_anchors, 4)
            anc_points (Tensor): shape(num_total_anchors, 2)
                锚点坐标张量,形状为(num_total_anchors, 2)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
                真实标签张量,形状为(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
                真实边界框张量,形状为(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
                真实边界框掩码张量,形状为(bs, n_max_boxes, 1)

        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
                目标标签张量,形状为(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
                目标边界框张量,形状为(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
                目标得分张量,形状为(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
                前景掩码张量,形状为(bs, num_total_anchors)
            target_gt_idx (Tensor): shape(bs, num_total_anchors)
                目标真实边界框索引张量,形状为(bs, num_total_anchors)
        """
        self.bs = pd_scores.shape[0]  # batch size
        self.n_max_boxes = gt_bboxes.shape[1]  # maximum number of gt boxes per image

        if self.n_max_boxes == 0:  # no ground-truth boxes in this batch
            device = gt_bboxes.device
            return (
                torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),  # all anchors assigned to background
                torch.zeros_like(pd_bboxes).to(device),  # zero target boxes
                torch.zeros_like(pd_scores).to(device),  # zero target scores
                torch.zeros_like(pd_scores[..., 0]).to(device),  # empty foreground mask
                torch.zeros_like(pd_scores[..., 0]).to(device),  # empty target gt indices
            )

        # Positive mask, alignment metric and IoU overlaps between predictions and gt boxes
        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        # Resolve anchors matched to multiple gt boxes by keeping the one with the highest IoU
        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize
        align_metric *= mask_pos  # zero the metric for non-positive anchor/gt pairs
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # max alignment metric per gt box
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # max IoU per gt box
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric  # rescale target scores by the normalized alignment metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """Get in_gts mask, (b, max_num_obj, h*w)."""
        # Select candidates within ground truth bounding boxes
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        
        # Compute alignment metric and overlaps between predicted and ground truth boxes
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        
        # Select top-k candidates based on alignment metric
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        
        # Merge masks to get the final positive mask
        mask_pos = mask_topk * mask_in_gts * mask_gt
        
        # Return the final positive mask, alignment metric, and overlaps
        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """Compute alignment metric given predicted and ground truth bounding boxes."""
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        
        # Initialize tensors for overlaps and bbox scores
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
        
        # Create indices tensor for accessing scores based on ground truth labels
        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        
        # Assign predicted scores to corresponding locations in bbox_scores
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w
        
        # Extract predicted and ground truth bounding boxes where mask_gt is True
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]  # (b, max_num_obj, 1, 4)
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]  # (b, 1, h*w, 4)
        
        # Compute IoU between selected boxes
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
        
        # Calculate alignment metric using bbox_scores and overlaps
        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for horizontal bounding boxes."""
        # Calculate IoU using bbox_iou function with specified parameters
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
    def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
        """Select the top-k candidate anchors for each gt box based on the given metrics."""
        # (b, max_num_obj, topk) metric values and indices of the top self.topk candidates per gt box
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)

        # If topk_mask is not provided, derive it from the maximum metric value (keep only valid gt boxes)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)

        # Zero out indices that are masked off
        topk_idxs.masked_fill_(~topk_mask, 0)

        # Count how many times each anchor is selected across the top-k indices
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)

        # Filter out anchors that were selected more than once (invalid candidates)
        count_tensor.masked_fill_(count_tensor > 1, 0)

        # Cast the count tensor back to the metrics dtype and return it
        return count_tensor.to(metrics.dtype)

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive
                                    anchor points, with shape (b, h*w), where h*w is the total
                                    number of anchor points.
            fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive
                              (foreground) anchor points.

        Returns:
            (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors:
                - target_labels (Tensor): Shape (b, h*w), containing the target labels for
                                          positive anchor points.
                - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes
                                          for positive anchor points.
                - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores
                                          for positive anchor points, where num_classes is the number
                                          of object classes.
        """

        # Assigned target labels, (b, 1)
        # Create batch indices for indexing into gt_labels
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        # Adjust target_gt_idx to point to the correct location in the flattened gt_labels
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        # Extract target labels from gt_labels using flattened indices
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        # Reshape gt_bboxes to (b * max_num_obj, 4) and then index using target_gt_idx
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)  # Clamp target_labels to ensure non-negative values

        # 10x faster than F.one_hot()
        # Initialize target_scores tensor with zeros and then scatter ones at target_labels indices
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        # Mask target_scores based on fg_mask to only keep scores for foreground anchor points
        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select the positive anchor centers that fall inside gt boxes.

        Args:
            xy_centers (Tensor): Anchor center coordinates, shape (h*w, 2).
            gt_bboxes (Tensor): Ground-truth bounding boxes, shape (b, n_boxes, 4).

        Returns:
            (Tensor): Boolean tensor of shape (b, n_boxes, h*w) indicating which anchor centers lie inside each box.
        """
        n_anchors = xy_centers.shape[0]  # number of anchor points
        bs, n_boxes, _ = gt_bboxes.shape  # batch size and number of gt boxes
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        # distances from each anchor center to the four sides of each gt box
        return bbox_deltas.amin(3).gt_(eps)  # anchor is positive if its minimum distance to any side exceeds eps

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        If an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.

        Args:
            mask_pos (Tensor): Boolean positive mask, shape (b, n_max_boxes, h*w).
            overlaps (Tensor): IoU between each anchor and each gt box, shape (b, n_max_boxes, h*w).
            n_max_boxes (int): Maximum number of gt boxes per image.

        Returns:
            target_gt_idx (Tensor): Index of the best-matching gt box per anchor, shape (b, h*w).
            fg_mask (Tensor): Mask of anchors assigned to any gt box, shape (b, h*w).
            mask_pos (Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
        """
        # (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)  # number of gt boxes each anchor is assigned to

        if fg_mask.max() > 1:  # an anchor is assigned to more than one gt box
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)  # recompute the assignment counts

        # Find which gt box (index) each anchor serves
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos
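
To make the tensor shapes concrete, here is a minimal sketch that runs the assigner on random inputs. The values are meaningless; the point is the expected shapes of the six inputs and five outputs (a batch of 2 images, 8400 anchors, 80 classes and up to 4 gt boxes per image are assumed for illustration).

```py
import torch
from ultralytics.utils.tal import TaskAlignedAssigner

bs, na, nc, n_max = 2, 8400, 80, 4
assigner = TaskAlignedAssigner(topk=10, num_classes=nc, alpha=0.5, beta=6.0)

pd_scores = torch.rand(bs, na, nc)              # predicted class scores
pxy = torch.rand(bs, na, 2) * 560
pd_bboxes = torch.cat([pxy, pxy + 80], dim=-1)  # predicted boxes, valid xyxy
anc_points = torch.rand(na, 2) * 640            # anchor-point centers
gt_labels = torch.randint(0, nc, (bs, n_max, 1))
gxy = torch.rand(bs, n_max, 2) * 500
gt_bboxes = torch.cat([gxy, gxy + 80], dim=-1)  # ground-truth boxes, valid xyxy
mask_gt = torch.ones(bs, n_max, 1)              # all gt boxes marked as valid

labels, bboxes, scores, fg_mask, gt_idx = assigner(
    pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt
)
print(labels.shape, bboxes.shape, scores.shape, fg_mask.shape, gt_idx.shape)
# torch.Size([2, 8400]) torch.Size([2, 8400, 4]) torch.Size([2, 8400, 80]) torch.Size([2, 8400]) torch.Size([2, 8400])
```
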
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """IoU calculation for rotated bounding boxes."""
        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Select the positive anchor center in gt for rotated bounding boxes.

        Args:
            xy_centers (Tensor): shape(h*w, 2) - Anchor centers to consider.
            gt_bboxes (Tensor): shape(b, n_boxes, 5) - Ground-truth rotated bounding boxes.

        Returns:
            (Tensor): shape(b, n_boxes, h*w) - Boolean mask indicating positive anchor centers.
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2) - Rearrange bounding box coordinates.
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # (b, n_boxes, 1, 2) - Extract corner points a, b, and d from corners.
        a, b, _, d = corners.split(1, dim=-2)
        ab = b - a  # Compute vectors ab and ad from corner points.
        ad = d - a

        # (b, n_boxes, h*w, 2) - Calculate vector ap from anchor centers to point a.
        ap = xy_centers - a
        norm_ab = (ab * ab).sum(dim=-1)  # Calculate norms and dot products for IoU calculation.
        norm_ad = (ad * ad).sum(dim=-1)
        ap_dot_ab = (ap * ab).sum(dim=-1)
        ap_dot_ad = (ap * ad).sum(dim=-1)
        return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None  # Ensure features are not None.
    dtype, device = feats[0].dtype, feats[0].device  # Determine data type and device from the first feature.
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape  # Retrieve height and width of feature map.
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # Generate x offsets.
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # Generate y offsets.
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)  # Create grid points.
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))  # Stack grid points into anchor points.
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))  # Create stride tensor.
    return torch.cat(anchor_points), torch.cat(stride_tensor)  # Concatenate anchor points and strides.
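
For instance, a 640x640 input yields 80x80, 40x40 and 20x20 feature maps at strides 8/16/32, so the function returns 8400 grid-cell centers. A minimal sketch (channel counts are arbitrary placeholders):

```py
import torch
from ultralytics.utils.tal import make_anchors

feats = [torch.zeros(1, 256, 80, 80), torch.zeros(1, 512, 40, 40), torch.zeros(1, 1024, 20, 20)]
anchor_points, stride_tensor = make_anchors(feats, strides=[8, 16, 32], grid_cell_offset=0.5)
print(anchor_points.shape, stride_tensor.shape)  # torch.Size([8400, 2]) torch.Size([8400, 1])
print(anchor_points[:2])  # tensor([[0.5000, 0.5000], [1.5000, 0.5000]]) -> cell centers in grid units
```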


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)  # Split distance tensor into left-top and right-bottom.
    x1y1 = anchor_points - lt  # Compute top-left corner coordinates.
    x2y2 = anchor_points + rb  # Compute bottom-right corner coordinates.
    if xywh:
        c_xy = (x1y1 + x2y2) / 2  # Compute center coordinates.
        wh = x2y2 - x1y1  # Compute width and height.
        return torch.cat((c_xy, wh), dim)  # xywh bbox - Concatenate center and size.
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox - Concatenate top-left and bottom-right.


def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox(xyxy) to dist(ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)  # Split bbox tensor into x1y1 and x2y2.
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)
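
The two transforms are inverses of each other as long as the distances stay below `reg_max`; a small hand-checked round trip:

```py
import torch
from ultralytics.utils.tal import bbox2dist, dist2bbox

anchor_points = torch.tensor([[10.0, 10.0]])
bbox = torch.tensor([[6.0, 7.0, 14.0, 15.0]])           # xyxy box around the anchor
dist = bbox2dist(anchor_points, bbox, reg_max=16)        # ltrb distances: [4., 3., 4., 5.]
recovered = dist2bbox(dist, anchor_points, xywh=False)   # back to xyxy: [6., 7., 14., 15.]
xywh = dist2bbox(dist, anchor_points, xywh=True)         # center/size form: [10., 11., 8., 8.]
print(dist, recovered, xywh)
```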


def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """Decode predicted rotated bounding boxes from anchor points, distances (ltrb) and angles."""
    lt, rb = pred_dist.split(2, dim=dim)  # split distances into left-top and right-bottom offsets
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)  # cosine and sine of the predicted angle
    xf, yf = ((rb - lt) / 2).split(1, dim=dim)  # center offset relative to the anchor, in the box frame
    x, y = xf * cos - yf * sin, xf * sin + yf * cos  # rotate the center offset by the predicted angle
    xy = torch.cat([x, y], dim=dim) + anchor_points  # rotated box center in image coordinates
    return torch.cat([xy, lt + rb], dim=dim)  # concatenate center (x, y) with box size (lt + rb)

.\yolov8\ultralytics\utils\torch_utils.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import gc
import math
import os
import random
import time
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils import (
    DEFAULT_CFG_DICT,
    DEFAULT_CFG_KEYS,
    LOGGER,
    NUM_THREADS,
    PYTHON_VERSION,
    TORCHVISION_VERSION,
    __version__,
    colorstr,
)
from ultralytics.utils.checks import check_version

try:
    import thop  # optional dependency used for FLOPs profiling
except ImportError:
    thop = None  # fall back to 0 GFLOPs if thop is not installed

# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
TORCH_2_0 = check_version(torch.__version__, "2.0.0")
TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first."""
    initialized = dist.is_available() and dist.is_initialized()  # distributed training enabled and initialized
    if initialized and local_rank not in {-1, 0}:  # non-master processes wait here
        dist.barrier(device_ids=[local_rank])  # wait for the local master (rank 0) to finish its task
    yield  # run the wrapped block
    if initialized and local_rank == 0:  # the master then releases the other processes
        dist.barrier(device_ids=[0])


def smart_inference_mode():
    """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""

    def decorate(fn):
        """Applies appropriate torch decorator for inference mode based on torch version."""
        if TORCH_1_9 and torch.is_inference_mode_enabled():
            return fn  # already in inference_mode, act as a pass-through
        else:
            return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)  # pick the decorator by torch version

    return decorate


def autocast(enabled: bool, device: str = "cuda"):
    """
    Get the appropriate autocast context manager based on PyTorch version and AMP setting.

    This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
    older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.

    Args:
        enabled (bool): Whether to enable automatic mixed precision.
        device (str, optional): The device to use for autocast. Defaults to 'cuda'.

    Returns:
        (torch.amp.autocast): The appropriate autocast context manager.

    Note:
        - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
        - For older versions, it uses `torch.cuda.amp.autocast`.

    Example:
        ```py
        with autocast(enabled=True):
            # Your mixed precision operations here
            pass
        ```
    """
    if TORCH_1_13:
        return torch.amp.autocast(device, enabled=enabled)  # torch>=1.13: device-aware torch.amp.autocast
    else:
        return torch.cuda.amp.autocast(enabled)  # older torch: CUDA-only torch.cuda.amp.autocast
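
A minimal usage sketch (it assumes a CUDA device may or may not be present; with `enabled=False` the context manager is a no-op, so the same code also runs on CPU-only machines):

```py
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import autocast

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)

with autocast(enabled=use_amp):
    y = model(x)  # runs under AMP (typically FP16) on GPU, plain FP32 on CPU
print(y.dtype)    # torch.float16 with AMP enabled, torch.float32 otherwise
```

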
def get_cpu_info():
    """Return a string with system CPU information, i.e. 'Apple M2'."""
    import cpuinfo  # pip install py-cpuinfo

    k = "brand_raw", "hardware_raw", "arch_string_raw"  # keys in order of preference (not all are always available)
    info = cpuinfo.get_cpu_info()  # dict of CPU information
    string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")  # first available key
    return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")  # strip decorations and return


def select_device(device="", batch=0, newline=False, verbose=True):
    """
    Selects the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object.
            Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
            the first available GPU, or CPU if no GPU is available.
        batch (int, optional): Batch size being used in your model. Defaults to 0.
        newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
        verbose (bool, optional):
    elif device:  # non-CPU device requested
        if device == "cuda":
            device = "0"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        os.environ["CUDA_VISIBLE_DEVICES"] = device  # set environment variable - must be before checking availability
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available and CPU/MPS were not requested
        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
        n = len(devices)  # device count
        if n > 1:  # multi-GPU
            if batch < 1:
                raise ValueError(
                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
                    "please specify a valid batch size, i.e. batch=16."
                )
            if batch >= 0 and batch % n != 0:  # batch size must be divisible by the device count
                raise ValueError(
                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
                )
        space = " " * (len(s) + 1)  # indentation for aligned log lines
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # GPU info
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available (Apple Silicon)
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # default to CPU
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if arg in {"cpu", "mps"}:
        torch.set_num_threads(NUM_THREADS)  # set the thread count for CPU training
    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)


def time_sync():
    """PyTorch-accurate time."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for all CUDA kernels to finish before reading the clock
    return time.time()


def fuse_conv_and_bn(conv, bn):
    """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)  # no gradients needed; the fused layer is inference-only
        .to(conv.weight.device)  # keep the fused layer on the same device as the original conv
    )

    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))  # BN scale as a diagonal matrix
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))  # fused weight = w_bn @ w_conv

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))  # BN bias term
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)  # fused bias

    return fusedconv
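
A quick numerical check of the fusion (a sketch, not part of the library): in eval mode the fused layer should reproduce the Conv2d → BatchNorm2d output up to floating-point tolerance.

```py
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(32).eval()
bn.running_mean.uniform_(-1, 1)   # give the BN layer non-trivial running statistics
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(1, 16, 64, 64)
with torch.no_grad():
    y_ref = bn(conv(x))                       # separate Conv + BN
    y_fused = fuse_conv_and_bn(conv, bn)(x)   # single fused Conv
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```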


def fuse_deconv_and_bn(deconv, bn):
    """Fuse ConvTranspose2d() and BatchNorm2d() layers."""
    fuseddconv = (
        nn.ConvTranspose2d(
            deconv.in_channels,
            deconv.out_channels,
            kernel_size=deconv.kernel_size,
            stride=deconv.stride,
            padding=deconv.padding,
            output_padding=deconv.output_padding,
            dilation=deconv.dilation,
            groups=deconv.groups,
            bias=True,
        )
        .requires_grad_(False)  # no gradients needed; the fused layer is inference-only
        .to(deconv.weight.device)  # keep the fused layer on the same device as the original deconv
    )

    # Prepare filters
    w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))  # BN scale as a diagonal matrix
    fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fuseddconv


def model_info(model, detailed=False, verbose=True, imgsz=640):
    """
    Model information.

    imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320].
    """
    if not verbose:
        return
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    n_l = len(list(model.modules()))  # number of layers
    if detailed:
        # Log a per-parameter table: layer index, name, requires_grad, count, shape, mean, std and dtype
        LOGGER.info(
            f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}"
        )
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace("module_list.", "")
            LOGGER.info(
                "%5g %40s %9s %12g %20s %10.3g %10.3g %10s"
                % (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std(), p.dtype)
            )

    flops = get_flops(model, imgsz)  # estimate model FLOPs
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""  # note whether Conv+BN layers are fused
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"  # i.e. yolov8n.yaml -> YOLOv8n
    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
    return n_l, n_p, n_g, flops


def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    return sum(x.numel() for x in model.parameters())


def get_num_gradients(model):
    """Return the total number of parameters with gradients in a YOLO model."""
    return sum(x.numel() for x in model.parameters() if x.requires_grad)


def model_info_for_loggers(trainer):
    """Return a dict of useful model information for loggers."""
    if trainer.args.profile:  # profile ONNX and TensorRT times
        from ultralytics.utils.benchmarks import ProfileModels

        results = ProfileModels([trainer.last], device=trainer.device).profile()[0]
        results.pop("model/name")  # drop the model name from the results
    else:  # otherwise return only the most recently validated PyTorch times
        results = {
            "model/parameters": get_num_params(trainer.model),
            "model/GFLOPs": round(get_flops(trainer.model), 3),
        }
    results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)  # PyTorch inference speed
    return results


def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    if not thop:
        return 0.0  # if thop is not installed, return 0.0 GFLOPs

    try:
        model = de_parallel(model)
        p = next(model.parameters())
        if not isinstance(imgsz, list):
            imgsz = [imgsz, imgsz]  # expand int/float imgsz to a [h, w] list

        try:
            # Profile a stride-sized input and scale the result, which is much cheaper than a full-size forward pass
            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max model stride
            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # stride-sized BCHW input
            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
            return flops * imgsz[0] / stride * imgsz[1] / stride  # scale to imgsz GFLOPs
        except Exception:
            # Fallback for models that require a full-size input (e.g. RTDETR)
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # full-size BCHW input
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0  # return 0.0 GFLOPs on any failure


def get_flops_with_torch_profiler(model, imgsz=640):
    """Compute model FLOPs with the torch profiler (alternative to thop, but typically 2-10x slower)."""
    if not TORCH_2_0:  # requires torch>=2.0
        return 0.0
    model = de_parallel(model)
    p = next(model.parameters())
    if not isinstance(imgsz, list):
        imgsz = [imgsz, imgsz]  # expand int/float imgsz to a [h, w] list
    try:
        # Profile a stride-sized input and scale the result
        stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # stride-sized BCHW input
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9  # profiled GFLOPs at stride size
        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # scale to imgsz GFLOPs, i.e. 640x640
    except Exception:
        # Fallback for models that require a full-size input (e.g. RTDETR)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # full-size BCHW input
        with torch.profiler.profile(with_flops=True) as prof:
            model(im)
        flops = sum(x.flops for x in prof.key_averages()) / 1e9  # profiled GFLOPs at imgsz
    return flops


def initialize_weights(model):
    """Initialize model weights to random values."""
    # Iterate over all modules in the model
    for m in model.modules():
        t = type(m)
        # Check if the module is a 2D convolutional layer
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # Check if the module is a 2D batch normalization layer
        elif t is nn.BatchNorm2d:
            # Set epsilon (eps) and momentum parameters
            m.eps = 1e-3
            m.momentum = 0.03
        # Check if the module is one of the specified activation functions
        elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
            # Enable inplace operation for the activation function
            m.inplace = True


def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    """Scales and pads an image tensor of shape img(bs,3,y,x) based on given ratio and grid size gs, optionally
    retaining the original shape.
    """
    # If ratio is 1.0, return the original image tensor
    if ratio == 1.0:
        return img
    # Retrieve height and width from the image tensor shape
    h, w = img.shape[2:]
    # Compute the new scaled size based on the given ratio
    s = (int(h * ratio), int(w * ratio))  # new size
    # Resize the image tensor using bilinear interpolation
    img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
    # If not retaining the original shape, pad or crop the image tensor
    if not same_shape:
        # Calculate the padded height and width based on the ratio and grid size
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    # Pad the image tensor to match the calculated dimensions
    return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean


def copy_attr(a, b, include=(), exclude=()):
    """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
    # Iterate through attributes in object 'b'
    for k, v in b.__dict__.items():
        # Skip attributes based on conditions: not in include list, starts with '_', or in exclude list
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            # Set attribute 'k' in object 'a' to the value 'v' from object 'b'
            setattr(a, k, v)


def get_latest_opset():
    """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity."""
    # Check if using PyTorch version 1.13 or newer
    if TORCH_1_13:
        # Dynamically compute the second-most recent ONNX opset version supported
        return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1
    # For PyTorch versions <= 1.12, return predefined opset versions
    version = torch.onnx.producer_version.rsplit(".", 1)[0]  # i.e. '2.3'
    return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12)


def intersect_dicts(da, db, exclude=()):
    """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
    # Create a dictionary comprehension to filter keys based on conditions
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}


def is_parallel(model):
    """Returns True if model is of type DP or DDP."""
    # Check if the model is an instance of DataParallel or DistributedDataParallel
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def de_parallel(model):
    """De-parallelize a model: returns single-GPU model if model is of type DP or DDP."""
    # Return the underlying module of a DataParallel or DistributedDataParallel model
    return model.module if is_parallel(model) else model


def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
    # Generate a lambda function that implements a sinusoidal ramp
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
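
In the trainer this lambda is typically handed to a `LambdaLR` scheduler as the per-epoch learning-rate multiplier; a small sketch of the values it produces (the SGD optimizer below is just a placeholder):

```py
import torch
from ultralytics.utils.torch_utils import one_cycle

lf = one_cycle(y1=1.0, y2=0.01, steps=100)   # cosine ramp of the LR factor from 1.0 down to 0.01
print(lf(0), lf(50), lf(100))                # 1.0, ~0.505, 0.01

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # effective lr = 0.01 * lf(epoch)
```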


def init_seeds(seed=0, deterministic=False):
    """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic:
        if TORCH_2_0:
            torch.use_deterministic_algorithms(True, warn_only=True)  # warn if deterministic is not possible
            torch.backends.cudnn.deterministic = True
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # required for deterministic CUBLAS kernels
            os.environ["PYTHONHASHSEED"] = str(seed)
        else:
            LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.")
    else:
        # Allow non-deterministic algorithms for better performance
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False


class ModelEMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
    average of everything in the model state_dict (parameters and buffers)

    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    To disable EMA set the `enabled` attribute to `False`.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """Initialize EMA for 'model' with given arguments."""
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.enabled = True

    def update(self, model):
        """Update EMA parameters."""
        if self.enabled:
            self.updates += 1
            d = self.decay(self.updates)

            msd = de_parallel(model).state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # true for FP16 and FP32
                    v *= d
                    v += (1 - d) * msd[k].detach()
                    # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype},  model {msd[k].dtype}'

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        """Updates attributes and saves stripped model with optimizer removed."""
        if self.enabled:
            copy_attr(self.ema, model, include, exclude)
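
A minimal training-loop sketch (toy model and data) showing where `update()` is called; the smoothed copy in `ema.ema` is what is typically validated and saved:

```py
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import ModelEMA

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
ema = ModelEMA(model, decay=0.9999, tau=2000)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for step in range(10):
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # EMA weights lag behind the raw weights

print(ema.updates)                               # 10
print(next(ema.ema.parameters()).requires_grad)  # False - the EMA copy is inference-only
```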


def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Args:
        f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

    Returns:
        None

    Example:
        ```py
        from pathlib import Path
        from ultralytics.utils.torch_utils import strip_optimizer

        for f in Path('path/to/model/checkpoints').rglob('*.pt'):
            strip_optimizer(f)
        ```

    Note:
        Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]`
    """
    try:
        x = torch.load(f, map_location=torch.device("cpu"))
        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
        assert "model" in x, "'model' missing from checkpoint"
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
        return

    updates = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False  # no gradients needed in a finalized checkpoint

    # Update other keys
    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine default config with the checkpoint's train args
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys to strip
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save
    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size in MB
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def convert_optimizer_state_dict_to_fp16(state_dict):
    """Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key's tensor values."""
    for state in state_dict["state"].values():
        for k, v in state.items():
            # Skip the 'step' counter; convert all other FP32 tensors to FP16
            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                state[k] = v.half()

    return state_dict
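
A small sketch of the effect on an Adam state dict (the exact state keys depend on the optimizer and torch version; 'step' is intentionally left untouched):

```py
import torch
import torch.nn as nn
from ultralytics.utils.torch_utils import convert_optimizer_state_dict_to_fp16

model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
opt.step()  # populate exp_avg / exp_avg_sq state tensors

sd = convert_optimizer_state_dict_to_fp16(opt.state_dict())
first_state = next(iter(sd["state"].values()))
print({k: getattr(v, "dtype", type(v)) for k, v in first_state.items()})
# e.g. {'step': torch.float32, 'exp_avg': torch.float16, 'exp_avg_sq': torch.float16}
```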


def profile(input, ops, n=10, device=None):
    """Ultralytics speed, memory and FLOPs profiler."""
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )
    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)  # move the input to the selected device
        x.requires_grad = True  # track gradients so backward timing can be measured

        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, "to") else m  # move the op to the device when possible
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # forward time, backward time, timestamps

            try:
                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()  # start of forward pass
                    y = m(x)
                    t[1] = time_sync()  # end of forward pass
                    try:
                        (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()  # end of backward pass
                    except Exception:  # no backward method
                        t[2] = float("nan")
                    tf += (t[1] - t[0]) * 1000 / n  # average forward time per op (ms)
                    tb += (t[2] - t[1]) * 1000 / n  # average backward time per op (ms)
                mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0  # GPU memory (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameter count
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)
                results.append(None)
            gc.collect()  # attempt to free unused memory
            torch.cuda.empty_cache()

    return results


class EarlyStopping:
    """Early stopping class that stops training when a specified number of epochs have passed without improvement."""

    def __init__(self, patience=50):
        """
        Initialize early stopping object.

        Args:
            patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
        """
        self.best_fitness = 0.0  # best fitness (i.e. best mAP) observed so far
        self.best_epoch = 0  # epoch at which the best fitness was observed
        self.patience = patience or float("inf")  # epochs to wait after fitness stops improving (infinite if 0/None)
        self.possible_stop = False  # whether training may stop on the next epoch

    def __call__(self, epoch, fitness):
        """
        Check whether to stop training.

        Args:
            epoch (int): Current epoch of training
            fitness (float): Fitness value of current epoch

        Returns:
            (bool): True if training should stop, False otherwise
        """
        if fitness is None:  # check if fitness=None (happens when val=False)
            return False

        if fitness >= self.best_fitness:  # fitness improved (or equalled), update the best values
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # training may stop on the next epoch
        stop = delta >= self.patience  # stop training if patience is exceeded
        if stop:
            prefix = colorstr("EarlyStopping: ")
            LOGGER.info(
                f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
                f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
                f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
                f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
            )
        return stop  # whether training should stop
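
A minimal sketch of how the stopper is driven from a training loop (the fitness values below are dummies; in the real trainer, fitness is the validation mAP-based score):

```py
from ultralytics.utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=5)
fitness_per_epoch = [0.10, 0.20, 0.25, 0.25, 0.24, 0.23, 0.22, 0.21, 0.20, 0.19]

for epoch, fitness in enumerate(fitness_per_epoch):
    if stopper(epoch, fitness):
        break  # no improvement for `patience` epochs -> stop training

print(epoch, stopper.best_epoch, stopper.best_fitness)  # 8 3 0.25
```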