# Ultralytics YOLO 🚀, AGPL-3.0 license
# This code is part of the Ultralytics YOLO project and is distributed under the AGPL-3.0 license,
# which requires that the source remain open when the software is used, modified, or distributed.


# 导入必要的库
import cv2  # OpenCV库,用于图像处理
import torch  # PyTorch深度学习库
from PIL import Image  # Python Imaging Library,用于图像处理

# 导入Ultralytics预测相关模块
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops

# 分类预测器类,继承自BasePredictor类
class ClassificationPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.classify import ClassificationPredictor

        args = dict(model='yolov8n-cls.pt', source=ASSETS)
        predictor = ClassificationPredictor(overrides=args)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initializes ClassificationPredictor setting the task to 'classify'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"
        # Substring used to recognise transform pipelines from the old `ultralytics.yolo` package layout.
        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

    def preprocess(self, img):
        """
        Converts input image(s) to a model-compatible tensor on the model's device.

        Args:
            img (torch.Tensor | iterable): Either an already-batched tensor, or an iterable of images that are
                passed one by one through `self.transforms`. In the non-legacy path each image goes through
                `cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`, so images are presumably BGR numpy arrays (confirm
                against the caller).

        Returns:
            (torch.Tensor): The batch on `self.model.device`, fp16 if `self.model.fp16` else fp32.
        """
        if not isinstance(img, torch.Tensor):
            # Legacy transform pipelines take the raw images; current pipelines take PIL RGB images.
            is_legacy_transform = any(
                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
            )
            if is_legacy_transform:
                img = torch.stack([self.transforms(im) for im in img], dim=0)
            else:
                img = torch.stack(
                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
                )
        # Both branches above produce a tensor, so `img` is always a torch.Tensor here.
        img = img.to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes predictions to return Results objects.

        Args:
            preds: Per-image class scores, stored in each Results as `probs`.
            img: The preprocessed batch (not used here).
            orig_imgs (list | torch.Tensor): Original images; a tensor batch is first converted with
                `ops.convert_torch2numpy_batch`.

        Returns:
            (list[Results]): One Results per image, paired with the paths in `self.batch[0]`.
        """
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        return [
            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]


# 导入PyTorch库
import torch

# 从Ultralytics库中导入所需的模块和类
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first

# 定义一个名为ClassificationTrainer的类,继承自BaseTrainer类,用于分类模型的训练
class ClassificationTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.models.yolo.classify import ClassificationTrainer

        args = dict(model='yolov8n-cls.pt', data='imagenet10', epochs=3)
        trainer = ClassificationTrainer(overrides=args)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.

        The task is forced to 'classify' and `imgsz` defaults to 224 when it is missing or None.
        The caller's `overrides` dict is shallow-copied first so it is not modified as a side effect.
        """
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Returns a ClassificationModel configured for training.

        Args:
            cfg: Model configuration forwarded to ClassificationModel.
            weights: Optional weights, loaded with `model.load(weights)` when truthy.
            verbose (bool): Print model info; only honoured when RANK == -1.

        Returns:
            (ClassificationModel): The model with every parameter set to require gradients.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            # Without pretraining, re-initialise every module that can reset itself.
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """
        Load, create or download the model.

        If `self.model` names an entry in `torchvision.models`, it is built from torchvision (with
        "IMAGENET1K_V1" weights when `self.args.pretrained`). Otherwise BaseTrainer.setup_model() is used.
        In both cases the output layer is reshaped to `self.data["nc"]` classes.

        Returns:
            The checkpoint returned by BaseTrainer.setup_model(), or None for torchvision models.
        """
        import torchvision  # scope for faster 'import ultralytics'

        if str(self.model) in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[self.model](
                weights="IMAGENET1K_V1" if self.args.pretrained else None
            )
            ckpt = None
        else:
            ckpt = super().setup_model()
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
        return ckpt

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Creates a ClassificationDataset instance given an image path, and mode (train/test etc.).

        Augmentation is enabled only for mode == "train". `batch` is accepted for interface parity but unused.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """
        Returns a PyTorch DataLoader for `dataset_path`.

        For any mode other than "train", the dataset's `torch_transforms` are attached to the model
        (to its `.module` when `is_parallel(self.model)`), for use at inference time.
        """
        # Under distributed training, let rank 0 build (and cache) the dataset first.
        with torch_distributed_zero_first(rank):
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch):
        """Moves the batch's images and class labels to `self.device`."""
        batch["img"] = batch["img"].to(self.device)
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def progress_string(self):
        """Returns the header row of the training progress display (epoch, GPU memory, losses, instances, size)."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Sets the loss names to ["loss"] and returns a ClassificationValidator for validation."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Returns a loss dict with labelled training loss items tensor.

        Not needed for classification but necessary for segmentation & detection.

        Args:
            loss_items: A single loss value convertible with float(); if None, only the keys are returned.
            prefix (str): Prefix for each key, e.g. "train" -> "train/loss".

        Returns:
            (list[str] | dict): The keys when `loss_items` is None, else {key: loss rounded to 5 decimals}.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def plot_metrics(self):
        """Plots metrics from the results CSV file (classification layout)."""
        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)

    def final_eval(self):
        """Strip optimizers from the last/best checkpoints, validate the best one and log the save directory."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")

    def plot_training_samples(self, batch, ni):
        """Plots training samples with their annotations to `train_batch{ni}.jpg` in the save directory."""
        plot_images(
            images=batch["img"],
            batch_idx=torch.arange(len(batch["img"])),  # one image per sample
            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )


# 导入PyTorch库
import torch

# 从Ultralytics库中导入相关模块和类
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images

# 创建一个分类验证器类,继承自BaseValidator基类
class ClassificationValidator(BaseValidator):
    """
    A class extending the BaseValidator class for validation based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.models.yolo.classify import ClassificationValidator

        args = dict(model='yolov8n-cls.pt', data='imagenet10')
        validator = ClassificationValidator(args=args)
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.targets = None  # per-batch ground-truth class tensors; reset to [] in init_metrics()
        self.pred = None  # per-batch top-k predicted class tensors; reset to [] in init_metrics()
        self.args.task = "classify"
        self.metrics = ClassifyMetrics()

    def get_desc(self):
        """Returns a formatted string summarizing classification metrics."""
        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")

    def init_metrics(self, model):
        """Initialize the confusion matrix, class names, and the prediction/target buffers."""
        self.names = model.names
        self.nc = len(model.names)
        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify")
        self.pred = []
        self.targets = []

    def preprocess(self, batch):
        """Moves images and class labels to the device; images become fp16 if `args.half` else fp32."""
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
        batch["cls"] = batch["cls"].to(self.device)
        return batch

    def update_metrics(self, preds, batch):
        """
        Record this batch's top-k predicted classes (k = min(nc, 5)) and its ground-truth classes.

        Args:
            preds (torch.Tensor): Class scores sorted along dim 1 (presumably (batch, nc) — confirm).
            batch (dict): Batch holding the "cls" labels.
        """
        n5 = min(len(self.names), 5)
        # Class indices by descending score, truncated to the top n5, stored on the CPU as int32.
        self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
        self.targets.append(batch["cls"].type(torch.int32).cpu())

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes metrics of the model such as confusion_matrix and speed, plotting the matrix if enabled."""
        self.confusion_matrix.process_cls_preds(self.pred, self.targets)
        if self.args.plots:
            for normalize in True, False:
                self.confusion_matrix.plot(
                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
                )
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix
        self.metrics.save_dir = self.save_dir

    def get_stats(self):
        """Returns a dictionary of metrics obtained by processing targets and predictions."""
        self.metrics.process(self.targets, self.pred)
        return self.metrics.results_dict

    def build_dataset(self, img_path):
        """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters."""
        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)

    def get_dataloader(self, dataset_path, batch_size):
        """Builds and returns a data loader for classification tasks with given parameters."""
        dataset = self.build_dataset(dataset_path)
        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

    def print_results(self):
        """Prints overall top-1 and top-5 accuracy for the classification model."""
        pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
        LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

    def plot_val_samples(self, batch, ni):
        """Plot validation image samples with their ground-truth classes."""
        plot_images(
            images=batch["img"],
            batch_idx=torch.arange(len(batch["img"])),
            cls=batch["cls"].view(-1),  # warning: use .view(), not .squeeze() for Classify models
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch, preds, ni):
        """Plots the top-1 predicted class for each input image and saves the result."""
        plot_images(
            batch["img"],
            batch_idx=torch.arange(len(batch["img"])),
            cls=torch.argmax(preds, dim=1),
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred


# 导入分类任务相关模块和类
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator

# 将这些类添加到模块的公开接口中,方便其他模块导入时使用
# Public API of the classify package: names re-exported by `from ... import *`.
__all__ = (
    "ClassificationPredictor",
    "ClassificationTrainer",
    "ClassificationValidator",
)


# 导入需要的模块和类
from ultralytics.engine.predictor import BasePredictor  # 从 Ultralytics 引擎中导入 BasePredictor 类
from ultralytics.engine.results import Results  # 从 Ultralytics 引擎中导入 Results 类
from ultralytics.utils import ops  # 从 Ultralytics 工具中导入 ops 模块

class DetectionPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a detection model.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.detect import DetectionPredictor

        args = dict(model='yolov8n.pt', source=ASSETS)
        predictor = DetectionPredictor(overrides=args)
        ```
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes predictions and returns a list of Results objects.

        Runs non-maximum suppression with the configured conf/iou thresholds, agnostic flag, max_det and class
        filter, then rescales the boxes (columns 0-3) from the network input size `img.shape[2:]` to each
        original image's shape.

        Returns:
            (list[Results]): One Results per image, paired with the paths in `self.batch[0]`.
        """
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results


# 导入必要的库和模块
import math
import random
from copy import copy

import numpy as np
import torch.nn as nn

# 导入 Ultralytics 的相关模块和函数
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first

# 定义一个名为 DetectionTrainer 的类,继承自 BaseTrainer 类,用于检测模型的训练
class DetectionTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a detection model.

    Example:
        ```python
        from ultralytics.models.yolo.detect import DetectionTrainer

        args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
        trainer = DetectionTrainer(overrides=args)
        ```
    """

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        # Dataset stride: the model's max stride, never below 32.
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)

    def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
        """Construct and return a dataloader; shuffling is used only in train mode and only when rect is off."""
        assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
        with torch_distributed_zero_first(rank):  # under DDP, let rank 0 build the dataset (*.cache) first
            dataset = self.build_dataset(dataset_path, mode, batch_size)
        shuffle = mode == "train"
        if getattr(dataset, "rect", False) and shuffle:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
            shuffle = False
        workers = self.args.workers if mode == "train" else self.args.workers * 2
        return build_dataloader(dataset, batch_size, workers, shuffle, rank)

    def preprocess_batch(self, batch):
        """
        Move images to the device as float, scale them by 1/255 and optionally apply multi-scale resizing.

        With `args.multi_scale`, a target size is drawn from [0.5 * imgsz, 1.5 * imgsz + stride) and rounded down
        to a multiple of `self.stride`. The batch is then bilinearly resized by (target / longest side), with each
        new side rounded up to a multiple of the stride.
        """
        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
        if self.args.multi_scale:
            imgs = batch["img"]
            # random.randrange() requires integer bounds (float bounds raise TypeError on Python >= 3.12).
            sz = (
                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
                // self.stride
                * self.stride
            )  # size
            sf = sz / max(imgs.shape[2:])  # scale factor
            if sf != 1:
                ns = [
                    math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
                ]  # new shape (stretched to gs-multiple)
                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
            batch["img"] = imgs
        return batch

    def set_model_attributes(self):
        """Attach the number of classes, class names and hyperparameters to the model."""
        # NOTE: disabled hyperparameter scaling by number of detection layers (nl), kept for reference:
        # nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
        # self.args.box *= 3 / nl  # scale to layers
        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
        self.model.nc = self.data["nc"]  # attach number of classes to model
        self.model.names = self.data["names"]  # attach class names to model
        self.model.args = self.args  # attach hyperparameters to model
        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a YOLO DetectionModel, loading `weights` into it when given."""
        model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """Sets the detection loss names and returns a DetectionValidator for YOLO model validation."""
        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
        return yolo.detect.DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items=None, prefix="train"):
        """
        Returns a loss dict with labelled training loss items tensor.

        Not needed for classification but necessary for segmentation & detection.

        Args:
            loss_items: Iterable of loss values convertible with float(); if None, only the keys are returned.
            prefix (str): Prefix for each key, e.g. "train" -> "train/box_loss".

        Returns:
            (list[str] | dict): The keys when `loss_items` is None, else {key: loss rounded to 5 decimals}.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
        return dict(zip(keys, loss_items))

    def progress_string(self):
        """Returns the header row of the training progress display (epoch, GPU memory, losses, instances, size)."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def plot_training_samples(self, batch, ni):
        """Plots training samples with their box annotations to `train_batch{ni}.jpg` in the save directory."""
        plot_images(
            images=batch["img"],
            batch_idx=batch["batch_idx"],
            cls=batch["cls"].squeeze(-1),
            bboxes=batch["bboxes"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plots metrics from the results CSV file."""
        plot_results(file=self.csv, on_plot=self.on_plot)  # save results.png

    def plot_training_labels(self):
        """Plot the distribution of all training-set boxes and classes."""
        boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
        cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
        plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)


# 导入所需的库和模块
import os
from pathlib import Path

import numpy as np
import torch

# 导入自定义的数据处理模块
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
# 导入自定义的验证器基类
from ultralytics.engine.validator import BaseValidator
# 导入自定义的工具模块
from ultralytics.utils import LOGGER, ops
# 导入检查要求的函数
from ultralytics.utils.checks import check_requirements
# 导入评估指标相关模块
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
# 导入绘图函数
from ultralytics.utils.plotting import output_to_target, plot_images

class DetectionValidator(BaseValidator):
    """
    A class extending the BaseValidator class for validation based on a detection model.

    Example:
        ```python
        from ultralytics.models.yolo.detect import DetectionValidator

        args = dict(model='yolov8n.pt', data='coco8.yaml')
        validator = DetectionValidator(args=args)
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize the detection validator with its metrics, IoU thresholds and dataset flags."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.nt_per_class = None  # number of labels per class, filled in get_stats()
        self.nt_per_image = None  # per class, number of images containing it, filled in get_stats()
        self.is_coco = False
        self.is_lvis = False
        self.class_map = None
        self.args.task = "detect"
        self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        self.iouv = torch.linspace(0.5, 0.95, 10)  # IoU thresholds 0.50:0.05:0.95 for mAP@0.5:0.95
        self.niou = self.iouv.numel()
        self.lb = []  # ground-truth rows for autolabelling, filled in preprocess() when save_hybrid is set
        if self.args.save_hybrid:
            LOGGER.warning(
                "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n"
                "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n"
            )

    def preprocess(self, batch):
        """
        Preprocesses a batch of images for YOLO validation.

        Images are moved to the device, cast to fp16 (`args.half`) or fp32 and divided by 255; "batch_idx",
        "cls" and "bboxes" are moved to the device. With `args.save_hybrid`, one tensor of [cls, scaled boxes]
        rows per image is cached in `self.lb`.
        """
        batch["img"] = batch["img"].to(self.device, non_blocking=True)
        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
        for k in ["batch_idx", "cls", "bboxes"]:
            batch[k] = batch[k].to(self.device)

        if self.args.save_hybrid:
            height, width = batch["img"].shape[2:]
            nb = len(batch["img"])
            # Multiply boxes by (w, h, w, h) — presumably normalized xywh to pixels; confirm.
            bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
            self.lb = [
                torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1)
                for i in range(nb)
            ]

        return batch
    def init_metrics(self, model):
        """
        Reset per-run validation state for `model`.

        Detects COCO/LVIS from the configured split path, chooses the class-index mapping, enables JSON
        saving for COCO/LVIS outside training, and resets names, the confusion matrix and stat buffers.
        """
        split_path = self.data.get(self.args.split, "")
        path_is_str = isinstance(split_path, str)
        coco_suffixes = (f"{os.sep}val2017.txt", f"{os.sep}test-dev2017.txt")
        self.is_coco = path_is_str and "coco" in split_path and split_path.endswith(coco_suffixes)
        self.is_lvis = path_is_str and "lvis" in split_path and not self.is_coco

        num_classes = len(model.names)
        if self.is_coco:
            self.class_map = converter.coco80_to_coco91_class()
        else:
            self.class_map = list(range(num_classes))

        json_wanted = (self.is_coco or self.is_lvis) and not self.training
        self.args.save_json |= json_wanted

        self.names = model.names
        self.nc = num_classes
        self.metrics.names = self.names
        self.metrics.plot = self.args.plots
        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
        self.seen = 0
        self.jdict = []
        self.stats = {key: [] for key in ("tp", "conf", "pred_cls", "target_cls", "target_img")}

    def get_desc(self):
        """Build the header row printed above the per-class detection metrics."""
        columns = ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
        row_format = "%22s" + "%11s" * (len(columns) - 1)
        return row_format % columns

    def postprocess(self, preds):
        """
        Apply non-maximum suppression to prediction outputs.

        Uses the configured conf/iou thresholds and max_det with multi-label NMS. Boxes are class-agnostic when
        `single_cls` or `agnostic_nms` is set. `self.lb` (filled only with save_hybrid) is passed as `labels`.
        """
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            multi_label=True,
            agnostic=self.args.single_cls or self.args.agnostic_nms,
            max_det=self.args.max_det,
        )

    def _prepare_batch(self, si, batch):
        """
        Extract image `si`'s ground truth from `batch`, with boxes converted to xyxy in original-image space.

        Returns:
            (dict): Keys "cls", "bbox", "ori_shape", "imgsz" and "ratio_pad".
        """
        mask = batch["batch_idx"] == si
        labels = batch["cls"][mask].squeeze(-1)
        boxes = batch["bboxes"][mask]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(labels):
            # Multiply by (w, h, w, h) of the network input, then map into the original image in place.
            scale = torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]
            boxes = ops.xywh2xyxy(boxes) * scale
            ops.scale_boxes(imgsz, boxes, ori_shape, ratio_pad=ratio_pad)
        return {"cls": labels, "bbox": boxes, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

    def _prepare_pred(self, pred, pbatch):
        """
        Return a copy of `pred` whose boxes (columns 0-3) are rescaled to the original image space.

        `pred` itself is not modified. Scaling uses pbatch["imgsz"], pbatch["ori_shape"] and pbatch["ratio_pad"].
        """
        predn = pred.clone()
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]
        )  # native-space pred
        return predn
    def update_metrics(self, preds, batch):
        """
        Update metrics based on predictions and batch data.

        Args:
            preds (list): Per-image predictions; columns 0-3 are boxes, 4 confidence and 5 class (as indexed below).
            batch (dict): Batch data dictionary.
        """
        for si, pred in enumerate(preds):
            self.seen += 1
            npr = len(pred)  # number of predictions for this image
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)  # number of ground-truth labels for this image
            stat["target_cls"] = cls
            stat["target_img"] = cls.unique()

            if npr == 0:
                # No predictions: still record the targets (with empty prediction stats) when labels exist.
                if nl:
                    for k in self.stats.keys():
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue  # nothing else to evaluate or save for this image

            if self.args.single_cls:
                pred[:, 5] = 0

            predn = self._prepare_pred(pred, pbatch)
            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, bbox, cls)

            for k in self.stats.keys():
                self.stats[k].append(stat[k])

            if self.args.save_json:
                self.pred_to_json(predn, batch["im_file"][si])
            if self.args.save_txt:
                # assumes save_one_txt(predn, save_conf, shape, file) is defined further down this class — confirm
                self.save_one_txt(
                    predn,
                    self.args.save_conf,
                    pbatch["ori_shape"],
                    self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
                )

    def finalize_metrics(self, *args, **kwargs):
        """
        Finalize metric values after all predictions are processed.

        Copies the accumulated speed statistics and the confusion matrix onto ``self.metrics`` so they are
        reported together with the other metrics.

        Args:
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        self.metrics.speed = self.speed
        self.metrics.confusion_matrix = self.confusion_matrix

    def get_stats(self):
        """
        Retrieve metrics statistics and results.

        Concatenates the per-batch statistics and counts ground-truth instances per class (``nt_per_class``).
        It also counts how many images each class appears in (``nt_per_image``, built from the per-image unique
        classes). The statistics are passed to ``self.metrics.process`` only when at least one true positive exists.

        Returns:
            (dict): Dictionary containing metrics statistics (``self.metrics.results_dict``).
        """
        # Convert the accumulated lists of tensors into single numpy arrays
        stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()}
        self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
        self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
        # target_img is only needed for the per-image counts above, not by metrics.process
        stats.pop("target_img", None)
        if len(stats) and stats["tp"].any():
            self.metrics.process(**stats)
        return self.metrics.results_dict

    def print_results(self):
        """
        Print training/validation set metrics, overall and optionally per class.

        When ``self.args.plots`` is enabled, both the normalized and the raw confusion matrix are plotted.
        """
        pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
        # Summary row: images seen, total instances, mean of every metric
        LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
        if self.nt_per_class.sum() == 0:
            LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")

        # Per-class rows: images containing the class, instances of the class, per-class metrics
        if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
            for i, c in enumerate(self.metrics.ap_class_index):
                LOGGER.info(
                    pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
                )

        if self.args.plots:
            # Plot both the normalized and the raw confusion matrix
            for normalize in True, False:
                self.confusion_matrix.plot(
                    save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
                )

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Return correct prediction matrix.

        Args:
            detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is
                (x1, y1, x2, y2, conf, class).
            gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each
                bounding box is of the format: (x1, y1, x2, y2).
            gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices.

        Returns:
            (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels.

        Note:
            The function does not return any value directly usable for metrics calculation. Instead, it provides an
            intermediate representation used for evaluating predictions against ground truth.
        """
        # IoU between every ground-truth box and every predicted box
        iou = box_iou(gt_bboxes, detections[:, :4])
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.

        Returns:
            The dataset produced by ``build_yolo_dataset`` from the validator's args, data config and stride.
        """
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)

    def get_dataloader(self, dataset_path, batch_size):
        """
        Construct and return the validation dataloader.

        Args:
            dataset_path (str): Path to the images to validate on.
            batch_size (int): Number of images per batch.

        Returns:
            A dataloader over the "val"-mode dataset, built without shuffling and with ``rank=-1``.
        """
        dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
        return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1)  # return dataloader

    def plot_val_samples(self, batch, ni):
        """Plot validation image samples with their ground-truth labels to ``val_batch{ni}_labels.jpg``."""
        plot_images(
            batch["img"],  # images
            batch["batch_idx"],  # image index of every label
            batch["cls"].squeeze(-1),  # class labels
            batch["bboxes"],  # ground-truth boxes
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch, preds, ni):
        """Plot predicted bounding boxes on the input images and save the result to ``val_batch{ni}_pred.jpg``."""
        plot_images(
            batch["img"],
            *output_to_target(preds, max_det=self.args.max_det),  # convert predictions to plot-target format
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def save_one_txt(self, predn, save_conf, shape, file):
        """
        Save YOLO detections to a txt file in normalized coordinates in a specific format.

        Args:
            predn (torch.Tensor): Predictions; the first 6 columns are used as boxes.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple): Image shape; only (height, width) are used for the placeholder image.
            file (str | Path): Destination txt file.
        """
        from ultralytics.engine.results import Results

        # A blank placeholder image supplies the shape needed to normalize the box coordinates
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            boxes=predn[:, :6],
        ).save_txt(file, save_conf=save_conf)

    def pred_to_json(self, predn, filename):
        """
        Serialize YOLO predictions to COCO json format and append them to ``self.jdict``.

        Args:
            predn (torch.Tensor): Predictions with columns (x1, y1, x2, y2, conf, cls, ...).
            filename (str | Path): Image file name; a numeric stem is used as an integer image id.
        """
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xyxy -> xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center -> top-left corner (COCO bbox format)
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    # LVIS category ids are offset by one relative to the class map
                    "category_id": self.class_map[int(p[5])] + (1 if self.is_lvis else 0),
                    "bbox": [round(x, 3) for x in b],
                    "score": round(p[4], 5),
                }
            )

    def eval_json(self, stats):
        """
        Evaluate YOLO output in JSON format with pycocotools (COCO) or lvis (LVIS) and update statistics.

        Args:
            stats (dict): Metrics dictionary; the mAP50-95 and mAP50 entries are overwritten on success.

        Returns:
            (dict): The (possibly updated) statistics dictionary. Evaluation failures are logged as warnings.
        """
        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
            pred_json = self.save_dir / "predictions.json"  # predictions
            anno_json = (
                self.data["path"]
                / "annotations"
                / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
            )  # annotations
            pkg = "pycocotools" if self.is_coco else "lvis"
            LOGGER.info(f"\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...")
            try:
                for x in pred_json, anno_json:
                    assert x.is_file(), f"{x} file not found"
                check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3")
                if self.is_coco:
                    from pycocotools.coco import COCO  # noqa
                    from pycocotools.cocoeval import COCOeval  # noqa

                    anno = COCO(str(anno_json))  # annotations api
                    pred = anno.loadRes(str(pred_json))  # predictions api (must pass str, not Path)
                    val = COCOeval(anno, pred, "bbox")
                else:
                    from lvis import LVIS, LVISEval

                    anno = LVIS(str(anno_json))  # annotations api
                    pred = anno._load_json(str(pred_json))  # predictions (must pass str, not Path)
                    val = LVISEval(anno, pred, "bbox")
                val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                val.evaluate()
                val.accumulate()
                val.summarize()
                if self.is_lvis:
                    val.print_results()  # LVIS summarize() does not print on its own
                # keys[-1] is mAP50-95 and keys[-2] is mAP50; both backends must supply values in (AP, AP50) order
                stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (
                    val.stats[:2] if self.is_coco else [val.results["AP"], val.results["AP50"]]
                )
            except Exception as e:
                LOGGER.warning(f"{pkg} unable to run: {e}")
        return stats


# 导入 DetectionPredictor、DetectionTrainer 和 DetectionValidator 模块
# 这些模块来自当前目录下的 predict.py、train.py 和 val.py 文件
from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator

# Export DetectionPredictor, DetectionTrainer and DetectionValidator as this package's public API
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"


# Ultralytics YOLO 🚀, AGPL-3.0 license

from pathlib import Path  # 导入路径操作模块Path

from ultralytics.engine.model import Model  # 导入Ultralytics的模型基类Model
from ultralytics.models import yolo  # 导入Ultralytics的YOLO模块
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel  # 导入Ultralytics的不同任务模型
from ultralytics.utils import ROOT, yaml_load  # 导入Ultralytics的工具函数ROOT和yaml_load

class YOLO(Model):
    """YOLO (You Only Look Once) object detection model."""

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, switching to YOLOWorld if the model filename contains '-world'.

        Args:
            model (str | Path): Model file name or path. Defaults to 'yolov8n.pt'.
            task (str, optional): Task type passed to ``Model``. Defaults to None.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        path = Path(model)
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # YOLO-World weights or config
            # Turn this object into a YOLOWorld instance in place
            new_instance = YOLOWorld(path, verbose=verbose)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)

    @property
    def task_map(self):
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,  # oriented bounding box task
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }


class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
        Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.

        Args:
            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
            verbose (bool): Whether to print verbose output. Defaults to False.
        """
        super().__init__(model=model, task="detect", verbose=verbose)

        # Fall back to the COCO class names when the model carries no names of its own
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """Map head to model, validator, predictor and trainer classes."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.world.WorldTrainer,
            }
        }

    def set_classes(self, classes):
        """
        Set classes.

        Args:
            classes (List(str)): A list of categories i.e. ["person"]. The list is not modified.
        """
        # The model receives the full list (including any background prompt)
        self.model.set_classes(classes)
        # Drop the background prompt from the user-facing names; work on a copy so the caller's list is untouched
        background = " "
        names = list(classes)
        if background in names:
            names.remove(background)
        self.model.names = names

        # Keep an existing predictor's class names in sync
        if self.predictor:
            self.predictor.model.names = names


# Ultralytics YOLO 🚀, AGPL-3.0 license

# 导入 PyTorch 库
import torch

# 导入 Ultralytics 相关模块和函数
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops

class OBBPredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.obb import OBBPredictor

        args = dict(model='yolov8n-obb.pt', source=ASSETS)
        predictor = OBBPredictor(overrides=args)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize OBBPredictor with optional model and data configuration overrides; sets task to 'obb'."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "obb"

    def postprocess(self, preds, img, orig_imgs):
        """Post-process predictions (rotated NMS, rescaling) and return a list of Results objects."""
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=len(self.model.names),
            classes=self.args.classes,
            rotated=True,
        )

        # Input images arrived as a torch.Tensor batch rather than a list of numpy arrays
        if not isinstance(orig_imgs, list):
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
            # Regularize rotated boxes (xywh + angle) and rescale xywh to the original image
            rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
            rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
            # xywh, r, conf, cls
            obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
            results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
        return results


# 导入必要的模块和类
from copy import copy  # 导入copy函数,用于复制对象

from ultralytics.models import yolo  # 导入yolo模型
from ultralytics.nn.tasks import OBBModel  # 导入OBBModel类
from ultralytics.utils import DEFAULT_CFG, RANK  # 导入默认配置和RANK变量

class OBBTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.

    Example:
        ```python
        from ultralytics.models.yolo.obb import OBBTrainer

        args = dict(model='yolov8n-obb.pt', data='dota8.yaml', epochs=3)
        trainer = OBBTrainer(overrides=args)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize an OBBTrainer, forcing the task to 'obb'."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "obb"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return an OBBModel built from ``cfg``, loading ``weights`` when given."""
        # Only the main process (RANK == -1) prints model info
        model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return an OBBValidator for validating the model, sharing this trainer's callbacks."""
        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
        return yolo.obb.OBBValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )


# 导入必要的库和模块,包括路径操作和PyTorch
from pathlib import Path
import torch

# 从Ultralytics中导入YOLO检测相关的类和函数
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
from ultralytics.utils.plotting import output_to_rotated_target, plot_images

# OBBValidator: a DetectionValidator specialised for Oriented Bounding Box (OBB) models
class OBBValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

    Example:
        ```python
        from ultralytics.models.yolo.obb import OBBValidator

        args = dict(model='yolov8n-obb.pt', data='dota8.yaml')
        validator = OBBValidator(args=args)
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.args.task = "obb"
        self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

    def init_metrics(self, model):
        """Initialize evaluation metrics for YOLO and detect whether the validation set is DOTA."""
        super().init_metrics(model)
        val = self.data.get(self.args.split, "")  # validation path
        self.is_dota = isinstance(val, str) and "DOTA" in val  # is DOTA dataset

    def postprocess(self, preds):
        """Apply rotated Non-maximum suppression to prediction outputs."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            nc=self.nc,
            multi_label=True,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
            rotated=True,
        )

    def _process_batch(self, detections, gt_bboxes, gt_cls):
        """
        Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.

        Args:
            detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
                data. Each detection is represented as (x, y, w, h, conf, class, angle).
            gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
                represented as (x, y, w, h, angle).
            gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.

        Returns:
            (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
                Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.

        Example:
            ```python
            detections = torch.rand(100, 7)  # 100 sample detections
            gt_bboxes = torch.rand(50, 5)  # 50 sample ground truth boxes
            gt_cls = torch.randint(0, 5, (50,))  # 50 ground truth class labels
            correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)  # validator: an OBBValidator
            ```

        Note:
            This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
        """
        # Probabilistic IoU between ground-truth and predicted rotated boxes (xywh + angle)
        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def _prepare_batch(self, si, batch):
        """Prepare and return the ground truth of image ``si`` for OBB validation, in native image space."""
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if len(cls):
            # Normalized xywh -> input-image pixels (x, w scale by width; y, h by height)
            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
            ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

    def _prepare_pred(self, pred, pbatch):
        """Return a copy of ``pred`` with boxes rescaled (and un-padded) to native image space."""
        predn = pred.clone()  # copy so the raw predictions are left untouched
        ops.scale_boxes(
            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
        )  # native-space pred
        return predn

    def plot_predictions(self, batch, preds, ni):
        """
        Plots predicted bounding boxes on input images and saves the result.

        Args:
            batch (dict): A dictionary containing batch data, including images and paths.
            preds (torch.Tensor): Predictions from the model.
            ni (int): Index of the batch.
        """
        plot_images(
            batch["img"],
            *output_to_rotated_target(preds, max_det=self.args.max_det),
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def pred_to_json(self, predn, filename):
        """
        Serialize YOLO predictions to COCO json format.

        Args:
            predn (torch.Tensor): Predictions in tensor format.
            filename (str): File name for saving JSON.
        """
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        # Rotated boxes (xywh + angle) and their 4-corner polygon form
        rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
        for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(predn[i, 5].item())],
                    "score": round(predn[i, 4].item(), 5),
                    "rbox": [round(x, 3) for x in r],
                    "poly": [round(x, 3) for x in b],
                }
            )

    def save_one_txt(self, predn, save_conf, shape, file):
        """
        Save YOLO detections to a txt file in normalized coordinates in a specific format.

        Args:
            predn (torch.Tensor): Predictions in tensor format.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple): Shape of the image.
            file (str): File path to save the txt file.
        """
        import numpy as np

        from ultralytics.engine.results import Results

        rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
        # xywh, r, conf, cls
        obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
        # A blank placeholder image supplies the shape needed for normalization
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            obb=obb,
        ).save_txt(file, save_conf=save_conf)


# 导入本地模块中的OBBPredictor类、OBBTrainer类和OBBValidator类
from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator

# Export OBBPredictor, OBBTrainer and OBBValidator as this package's public API (used by `from module import *`)
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"


# 导入所需模块和类
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops

class PosePredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on a pose model.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.pose import PosePredictor

        args = dict(model='yolov8n-pose.pt', source=ASSETS)
        predictor = PosePredictor(overrides=args)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize PosePredictor, set task to 'pose' and warn when 'mps' is the device."""
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "pose"
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def postprocess(self, preds, img, orig_imgs):
        """Return Results (boxes and keypoints) for a given input image or list of images."""
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
            nc=len(self.model.names),
        )

        # Input images arrived as a torch.Tensor batch rather than a list of numpy arrays
        if not isinstance(orig_imgs, list):
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
            # Rescale boxes to the original image size
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
            # Columns after the first 6 hold keypoints; reshape to kpt_shape and rescale likewise
            pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
            pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
            results.append(
                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
            )
        return results


# 导入所需的模块和函数
from copy import copy

# 导入 Ultralytics 中的 YOLO 模型相关内容
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results

# PoseTrainer: a DetectionTrainer specialised for pose (keypoint) models
class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a pose model.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseTrainer

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
        trainer = PoseTrainer(overrides=args)
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize a PoseTrainer object with specified configurations and overrides."""
        if overrides is None:
            overrides = {}
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Get pose estimation model with specified configuration and weights."""
        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
        if weights:
            model.load(weights)

        return model

    def set_model_attributes(self):
        """Set the standard detection model attributes, then the keypoint shape of the PoseModel."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """Returns an instance of the PoseValidator class for validation."""
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
        images = batch["img"]
        kpts = batch["keypoints"]
        cls = batch["cls"].squeeze(-1)
        bboxes = batch["bboxes"]
        paths = batch["im_file"]
        batch_idx = batch["batch_idx"]
        plot_images(
            images,
            batch_idx,
            cls,
            bboxes,
            kpts=kpts,
            paths=paths,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot training/val metrics and save them as results.png."""
        plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png


# 导入必要的模块和类
from pathlib import Path  # 提供处理路径的类
import numpy as np  # 提供数值计算支持
import torch  # 提供深度学习框架支持

# 导入 Ultralytics 内部模块和函数
from ultralytics.models.yolo.detect import DetectionValidator  # 导入检测模型验证器
from ultralytics.utils import LOGGER, ops  # 导入日志和操作函数
from ultralytics.utils.checks import check_requirements  # 导入检查依赖的函数
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou  # 导入评估指标相关函数和常量
from ultralytics.utils.plotting import output_to_target, plot_images  # 导入输出和绘图函数

class PoseValidator(DetectionValidator):
    """
    A DetectionValidator subclass for validating pose (keypoint) models.

    In addition to box metrics it computes keypoint metrics, matching predictions to targets with a
    keypoint-similarity IoU (``kpt_iou``), and can export predictions to COCO JSON or YOLO txt files.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseValidator

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
        validator = PoseValidator(args=args)
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """Initialize the validator, force ``args.task = "pose"`` and create the PoseMetrics container."""
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        self.sigma = None  # per-keypoint sigmas for kpt_iou; set in init_metrics()
        self.kpt_shape = None  # dataset keypoint shape, e.g. [17, 3]; set in init_metrics()
        self.args.task = "pose"
        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def preprocess(self, batch):
        """Run the parent preprocessing, then move batch["keypoints"] to ``self.device`` as float."""
        batch = super().preprocess(batch)
        batch["keypoints"] = batch["keypoints"].to(self.device).float()
        return batch

    def get_desc(self):
        """Return the column header printed above the box and pose evaluation metrics."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Pose(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def postprocess(self, preds):
        """Apply non-maximum suppression to raw predictions using the validator's conf/iou settings."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            multi_label=True,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
            nc=self.nc,  # columns after the class scores are keypoints, not classes
        )

    def init_metrics(self, model):
        """
        Initialize pose metrics for a YOLO model.

        Uses the COCO keypoint sigmas (``OKS_SIGMA``) when the dataset has the COCO keypoint shape
        ``[17, 3]``, and uniform sigmas ``1 / nkpt`` otherwise.
        """
        super().init_metrics(model)
        self.kpt_shape = self.data["kpt_shape"]
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]
        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])

    def _prepare_batch(self, si, batch):
        """
        Extract the ground truth for image ``si`` and add its keypoints in original-image pixel coordinates.

        Keypoint x/y are multiplied by the network input width/height, so the labels are expected to be
        normalized. They are then rescaled from ``imgsz`` to ``ori_shape`` (using ``ratio_pad``).
        """
        pbatch = super()._prepare_batch(si, batch)
        kpts = batch["keypoints"][batch["batch_idx"] == si]
        h, w = pbatch["imgsz"]
        kpts = kpts.clone()  # avoid mutating the shared batch tensor in place
        kpts[..., 0] *= w
        kpts[..., 1] *= h
        kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
        pbatch["kpts"] = kpts
        return pbatch

    def _prepare_pred(self, pred, pbatch):
        """
        Scale predicted boxes and keypoints to original-image coordinates.

        Returns:
            (tuple): ``(predn, pred_kpts)``, where ``pred_kpts`` is a ``(N, nk, -1)`` view of ``predn[:, 6:]``.
        """
        predn = super()._prepare_pred(pred, pbatch)
        nk = pbatch["kpts"].shape[1]
        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
        # The return value is discarded, so this relies on scale_coords modifying its input in place.
        # pred_kpts is a view of predn, so predn's keypoint columns are rescaled as well.
        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
        return predn, pred_kpts

    def update_metrics(self, preds, batch):
        """
        Accumulate per-image box and keypoint match statistics for one batch.

        For each image this records:
            - confidences, predicted classes, target classes and unique target classes;
            - true-positive matrices for boxes (``tp``) and for keypoints (``tp_p``).
        It also optionally updates the confusion matrix and saves predictions to JSON or txt.

        Args:
            preds (list[torch.Tensor]): Per-image post-NMS predictions. Columns 0-3 are the box, 4 the
                confidence, 5 the class, and 6 onwards the flattened keypoints.
            batch (dict): Batch dict with at least "keypoints", "batch_idx" and "im_file" entries.
        """
        for si, pred in enumerate(preds):
            self.seen += 1
            npr = len(pred)
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)
            stat["target_cls"] = cls
            stat["target_img"] = cls.unique()

            if npr == 0:
                # No predictions: record the targets (all missed) only when the image has labels.
                if nl:
                    for k in self.stats:
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue  # nothing to match or save for this image

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0  # collapse every prediction to class 0
            predn, pred_kpts = self._prepare_pred(pred, pbatch)
            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            # Evaluate against ground truth (box IoU and keypoint IoU)
            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, bbox, cls)

            for k in self.stats:
                self.stats[k].append(stat[k])

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch["im_file"][si])
            if self.args.save_txt:
                self.save_one_txt(
                    predn,
                    pred_kpts,
                    self.args.save_conf,
                    pbatch["ori_shape"],
                    self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt',
                )

    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
        """
        Return the correct-prediction matrix by matching detections to ground truth at each IoU threshold.

        When both ``pred_kpts`` and ``gt_kpts`` are given, the matching uses keypoint IoU (``kpt_iou``).
        Otherwise it uses box IoU.

        Args:
            detections (torch.Tensor): (N, 6+) detections; columns 0-3 are (x1, y1, x2, y2), 4 conf, 5 class.
            gt_bboxes (torch.Tensor): (M, 4) ground-truth boxes in (x1, y1, x2, y2) format.
            gt_cls (torch.Tensor): (M,) ground-truth class indices.
            pred_kpts (torch.Tensor | None): Predicted keypoints, shaped (N, nk, D) as built by ``_prepare_pred``.
            gt_kpts (torch.Tensor | None): Ground-truth keypoints, shaped (M, nk, D).

        Returns:
            (torch.Tensor): Boolean matrix of shape (N, niou), one column per IoU threshold.

        Note:
            The ``0.53`` area scale factor comes from
            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
        """
        if pred_kpts is not None and gt_kpts is not None:
            # Object area derived from the ground-truth box w*h, scaled by 0.53
            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
        else:  # boxes
            iou = box_iou(gt_bboxes, detections[:, :4])

        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def plot_val_samples(self, batch, ni):
        """Plot and save validation samples with their ground-truth boxes and keypoints."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            kpts=batch["keypoints"],
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch, preds, ni):
        """Plot the model's predicted boxes and keypoints for a validation batch."""
        # Keypoint columns of every image's predictions, reshaped to (num_preds, *kpt_shape) and concatenated
        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
        plot_images(
            batch["img"],
            *output_to_target(preds, max_det=self.args.max_det),
            kpts=pred_kpts,
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
        """Save one image's detections and keypoints to a txt file via ``Results.save_txt``."""
        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),  # placeholder image; only its shape is used
            path=None,
            names=self.names,
            boxes=predn[:, :6],
            keypoints=pred_kpts,
        ).save_txt(file, save_conf=save_conf)

    def pred_to_json(self, predn, filename):
        """Append one image's predictions to ``self.jdict`` as COCO-style result dicts."""
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(p[5])],
                    "bbox": [round(x, 3) for x in b],
                    "keypoints": p[6:],
                    "score": round(p[4], 5),
                }
            )

    def eval_json(self, stats):
        """
        Evaluate saved COCO-format predictions with pycocotools and overwrite the mAP entries in ``stats``.

        This runs only when ``save_json`` is enabled, the dataset is COCO and predictions were collected.
        Both the bbox and the keypoints evaluation write their mAP50-95 and mAP50 into ``stats``. Any
        failure, such as missing files or pycocotools being unavailable, is logged as a warning rather than raised.

        Args:
            stats (dict): Metric dict keyed by entries of ``self.metrics.keys``.

        Returns:
            (dict): The same ``stats`` dict, updated in place when evaluation succeeds.
        """
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
            pred_json = self.save_dir / "predictions.json"  # predictions
            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements("pycocotools>=2.0.6")
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                # Explicit check (not assert) so it still runs under `python -O`
                for x in anno_json, pred_json:
                    if not x.is_file():
                        raise FileNotFoundError(f"{x} file not found")
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                # The outer guard already requires self.is_coco, so always restrict evaluation to dataset images
                img_ids = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                for i, evaluator in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
                    evaluator.params.imgIds = img_ids
                    evaluator.evaluate()
                    evaluator.accumulate()
                    evaluator.summarize()
                    # NOTE(review): assumes metrics.keys lists (P, R, mAP50, mAP50-95) for boxes, then for pose
                    idx = i * 4 + 2
                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = evaluator.stats[
                        :2
                    ]  # update mAP50-95 and mAP50
            except Exception as e:
                LOGGER.warning(f"pycocotools unable to run: {e}")
        return stats
