YOLOv8 Source Code Walkthrough (Part 30)

.\yolov8\ultralytics\data\__init__.py

# Import the BaseDataset base class from this package's base module
# Import build_dataloader and build_grounding from the build module,
# along with build_yolo_dataset and load_inference_source
# Import the dataset classes from this package's dataset module: ClassificationDataset,
# GroundingDataset (visual grounding), SemanticDataset, YOLOConcatDataset,
# YOLODataset, and YOLOMultiModalDataset

from .base import BaseDataset
from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source
from .dataset import (
    ClassificationDataset,
    GroundingDataset,
    SemanticDataset,
    YOLOConcatDataset,
    YOLODataset,
    YOLOMultiModalDataset,
)

__all__ = (
    "BaseDataset",
    "ClassificationDataset",
    "SemanticDataset",
    "YOLODataset",
    "YOLOMultiModalDataset",
    "YOLOConcatDataset",
    "GroundingDataset",
    "build_yolo_dataset",
    "build_grounding",
    "build_dataloader",
    "load_inference_source",
)
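
Because every public name is re-exported here and listed in __all__, downstream code can import the dataset classes and builder functions straight from the package. A minimal sketch (assuming the ultralytics package is installed):

import ultralytics.data as data

print(data.__all__)      # the names exposed by "from ultralytics.data import *"
print(data.YOLODataset)  # re-exported from ultralytics.data.dataset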

.\yolov8\ultralytics\engine\exporter.py

# Import required libraries and modules

import gc  # garbage collection, for reclaiming objects no longer in use
import json  # JSON serialization and parsing
import os  # operating-system utilities
import shutil  # file operations: copy, move, delete
import subprocess  # run external commands in subprocesses
import time  # time-related functions
import warnings  # warning management

from copy import deepcopy  # create full copies of objects
from datetime import datetime  # date and time handling
from pathlib import Path  # object-oriented filesystem paths

import numpy as np  # multi-dimensional arrays and matrix operations
import torch  # PyTorch deep learning library

from ultralytics.cfg import TASK2DATA, get_cfg  # task-to-dataset mapping and config loader
from ultralytics.data import build_dataloader  # dataloader construction
from ultralytics.data.dataset import YOLODataset  # YOLO dataset class
from ultralytics.data.utils import check_cls_dataset, check_det_dataset  # dataset validation helpers
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
from ultralytics.utils import (
    ARM64,
    DEFAULT_CFG,
    IS_JETSON,
    LINUX,
    LOGGER,
    MACOS,
    PYTHON_VERSION,
    ROOT,
    WINDOWS,
    __version__,
    callbacks,
    colorstr,
    get_default_args,
    yaml_save,
)
from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
from ultralytics.utils.files import file_size, spaces_in_path
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode


# Return the table of supported YOLOv8 export formats
def export_formats():
    """YOLOv8 export formats."""
    # Import pandas locally to keep top-level 'import ultralytics' fast
    import pandas  # scope for faster 'import ultralytics'

    # Supported formats: [name, CLI argument, file suffix, CPU support, GPU support]
    x = [
        ["PyTorch", "-", ".pt", True, True],
        ["TorchScript", "torchscript", ".torchscript", True, True],
        ["ONNX", "onnx", ".onnx", True, True],
        ["OpenVINO", "openvino", "_openvino_model", True, False],
        ["TensorRT", "engine", ".engine", False, True],
        ["CoreML", "coreml", ".mlpackage", True, False],
        ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
        ["TensorFlow GraphDef", "pb", ".pb", True, True],
        ["TensorFlow Lite", "tflite", ".tflite", True, False],
        ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
        ["TensorFlow.js", "tfjs", "_web_model", True, False],
        ["PaddlePaddle", "paddle", "_paddle_model", True, True],
        ["NCNN", "ncnn", "_ncnn_model", True, True],
    ]
    # Return the format table as a DataFrame with named columns
    return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
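
Since export_formats() returns a plain pandas DataFrame, the format table can be filtered like any other table. A small sketch, for example listing the CLI arguments of the GPU-capable formats:

from ultralytics.engine.exporter import export_formats

fmts = export_formats()
print(fmts[fmts["GPU"]]["Argument"].tolist())
# ['-', 'torchscript', 'onnx', 'engine', 'saved_model', 'pb', 'paddle', 'ncnn']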


# Extract the output node names of a TensorFlow GraphDef model
def gd_outputs(gd):
    """TensorFlow GraphDef model output node names."""
    # Collect all node names and all input names
    name_list, input_list = [], []
    # Walk the GraphDef nodes, recording each node's name and its inputs
    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
        name_list.append(node.name)
        input_list.extend(node.input)
    # Output nodes are names never used as an input; exclude NoOp nodes and append ':0'
    return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
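
The output-node logic is a simple set difference: any node name that never appears as another node's input is an output. A self-contained toy sketch of the same idea, with SimpleNamespace standing in for NodeDef (no TensorFlow needed; the three-node graph is hypothetical):

from types import SimpleNamespace

nodes = [
    SimpleNamespace(name="images", input=[]),
    SimpleNamespace(name="conv", input=["images"]),
    SimpleNamespace(name="output0", input=["conv"]),
]
name_list = [n.name for n in nodes]
input_list = [i for n in nodes for i in n.input]
print(sorted(f"{x}:0" for x in set(name_list) - set(input_list) if not x.startswith("NoOp")))
# ['output0:0']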


# Decorator that wraps a YOLOv8 export function with timing and logging
def try_export(inner_func):
    """YOLOv8 export decorator, i.e. @try_export."""
    # Capture the wrapped function's default arguments (used for the log prefix)
    inner_args = get_default_args(inner_func)

    def outer_func(*args, **kwargs):
        """Export a model."""
        # Log prefix, e.g. 'ONNX:'
        prefix = inner_args["prefix"]
        try:
            # Time the export with Profile
            with Profile() as dt:
                # Run the wrapped export, getting the output file and model
                f, model = inner_func(*args, **kwargs)
            # Log success with elapsed time and file size
            LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)")
            # Return the exported file path and model object
            return f, model
        except Exception as e:
            # Log the failure and re-raise
            LOGGER.info(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}")
            raise e

    # Return the wrapping function
    return outer_func
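
The decorator pattern itself is generic: time the wrapped export, log success or failure, and re-raise on error. A simplified standalone sketch of the same pattern, with plain time and print standing in for Profile and LOGGER:

import time


def try_export_sketch(inner_func):
    """Minimal stand-in for @try_export: time the call and report the outcome."""

    def outer_func(*args, **kwargs):
        t0 = time.time()
        try:
            f, model = inner_func(*args, **kwargs)
            print(f"export success in {time.time() - t0:.1f}s, saved as '{f}'")
            return f, model
        except Exception as e:
            print(f"export failure after {time.time() - t0:.1f}s: {e}")
            raise

    return outer_func


@try_export_sketch
def export_dummy(prefix="Dummy:"):
    return "model.onnx", None


export_dummy()  # export success in 0.0s, saved as 'model.onnx'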


# The Exporter class drives model export to the formats listed above
class Exporter:
    """
    A class for exporting a model.

    Attributes:
        args (SimpleNamespace): Configuration for the exporter.
        callbacks (list, optional): List of callback functions. Defaults to None.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the Exporter class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
            _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
        """
        # Load the configuration into self.args
        self.args = get_cfg(cfg, overrides)

        # For 'coreml'/'mlmodel' formats, work around a protobuf<3.20.x error
        if self.args.format.lower() in {"coreml", "mlmodel"}:
            # Must be set before any TensorBoard callback runs
            os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

        # Use the provided callbacks, or fall back to the defaults
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        # Register integration callbacks
        callbacks.add_integration_callbacks(self)

    @smart_inference_mode()
    def get_int8_calibration_dataloader(self, prefix=""):
        """Build and return a dataloader suitable for calibration of INT8 models."""
        # Log that INT8 calibration images are being collected from the configured dataset
        LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'")

        # Pick the dataset checker that matches the task
        data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data)

        # Build a YOLO dataset for calibration
        dataset = YOLODataset(
            data[self.args.split or "val"],  # use the requested split, defaulting to 'val'
            data=data,
            task=self.model.task,
            imgsz=self.imgsz[0],  # image size
            augment=False,
            batch_size=self.args.batch * 2,  # TensorRT INT8 calibration should use 2x batch size
        )

        # Warn if fewer than the recommended 300 calibration images are found
        n = len(dataset)
        if n < 300:
            LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")

        # Build and return the dataloader
        return build_dataloader(dataset, batch=self.args.batch * 2, workers=0)  # required batch-loading settings
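
This dataloader is built internally whenever INT8 export is requested; at the public API level the same path is exercised through the export call. A hedged usage sketch (assuming yolov8n.pt and the coco8.yaml dataset are available and TensorRT is installed for the 'engine' format):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
# Requesting INT8 triggers the calibration-image collection shown above
model.export(format="engine", int8=True, data="coco8.yaml")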

    @try_export
    def export_torchscript(self, prefix=colorstr("TorchScript:")):
        """YOLOv8 TorchScript model export."""
        # Log the start of TorchScript export with the torch version
        LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")

        # Output file path
        f = self.file.with_suffix(".torchscript")

        # Trace the model into a TorchScript representation
        ts = torch.jit.trace(self.model, self.im, strict=False)

        # Extra files to bundle with the model (metadata as JSON)
        extra_files = {"config.txt": json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()

        # Optionally optimize for mobile
        if self.args.optimize:
            LOGGER.info(f"{prefix} optimizing for mobile...")
            from torch.utils.mobile_optimizer import optimize_for_mobile

            # Optimize for mobile and save in a format loadable by the Lite interpreter
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
        else:
            # Save the TorchScript model directly
            ts.save(str(f), _extra_files=extra_files)

        # Return the exported file path and None (no in-memory model object)
        return f, None
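
Because the metadata travels in the extra file config.txt, it can be recovered at load time via torch.jit.load's _extra_files argument. A small sketch (the file name yolov8n.torchscript is an assumption):

import json

import torch

extra_files = {"config.txt": ""}  # keys to extract; values are filled in on load
ts = torch.jit.load("yolov8n.torchscript", _extra_files=extra_files)
metadata = json.loads(extra_files["config.txt"])
print(metadata.get("description"))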

    @try_export
    # Export the model to ONNX; 'prefix' is the log-prefix string
    def export_onnx(self, prefix=colorstr("ONNX:")):
        """YOLOv8 ONNX export."""
        # Required third-party dependencies
        requirements = ["onnx>=1.12.0"]
        # If simplification is requested, add its dependencies
        if self.args.simplify:
            requirements += ["onnxslim>=0.1.31", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
        # Verify the requirements are installed
        check_requirements(requirements)
        import onnx  # noqa

        # Use the requested opset, or the latest supported one
        opset_version = self.args.opset or get_latest_opset()
        # Log the onnx version and opset
        LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
        # Output file path
        f = str(self.file.with_suffix(".onnx"))

        # Output names depend on the model type (segmentation adds mask protos)
        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
        # Dynamic-shape flag
        dynamic = self.args.dynamic
        # When dynamic shapes are enabled, map axis indices to symbolic names per model type
        if dynamic:
            dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
            if isinstance(self.model, SegmentationModel):
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 116, 8400)
                dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"}  # shape(1,32,160,160)
            elif isinstance(self.model, DetectionModel):
                dynamic["output0"] = {0: "batch", 2: "anchors"}  # shape(1, 84, 8400)

        # Export the ONNX model
        torch.onnx.export(
            self.model.cpu() if dynamic else self.model,  # move the model to CPU first when exporting dynamic shapes
            self.im.cpu() if dynamic else self.im,  # likewise for the example input
            f,
            verbose=False,
            opset_version=opset_version,
            do_constant_folding=True,  # apply constant-folding optimization
            input_names=["images"],  # input node name
            output_names=output_names,  # output node names
            dynamic_axes=dynamic or None,  # dynamic-axis mapping, or None when disabled
        )

        # Load the exported model back
        model_onnx = onnx.load(f)
        # onnx.checker.check_model(model_onnx)  # validate the ONNX model

        # Optionally simplify the model with onnxslim
        if self.args.simplify:
            try:
                import onnxslim

                LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...")
                model_onnx = onnxslim.slim(model_onnx)

                # onnx-simplifier (deprecated; required compiling with 'cmake' in Conda CI)
                # import onnxsim
                # model_onnx, check = onnxsim.simplify(model_onnx)
                # assert check, "Simplified ONNX model could not be validated"
            except Exception as e:
                # Warn if simplification fails
                LOGGER.warning(f"{prefix} simplifier failure: {e}")

        # Attach metadata to the model
        for k, v in self.metadata.items():
            meta = model_onnx.metadata_props.add()
            meta.key, meta.value = k, str(v)

        # Save the final ONNX model and return its path with the model object
        onnx.save(model_onnx, f)
        return f, model_onnx
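
The exported graph uses the input name "images" and output names "output0" (plus "output1" for segmentation), so it can be driven directly with ONNX Runtime. A minimal inference sketch (assuming a static-shape yolov8n.onnx at 640x640):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("yolov8n.onnx", providers=["CPUExecutionProvider"])
im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy normalized image batch
outputs = session.run(None, {"images": im})
print(outputs[0].shape)  # e.g. (1, 84, 8400) for a detection model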

    @try_export
    def export_paddle(self, prefix=colorstr("PaddlePaddle:")):
        """YOLOv8 Paddle export."""
        # Verify required dependencies
        check_requirements(("paddlepaddle", "x2paddle"))
        import x2paddle  # noqa
        from x2paddle.convert import pytorch2paddle  # noqa

        # Log export start with the X2Paddle version
        LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...")
        # Output directory: replace the file suffix with '_paddle_model'
        f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}")

        # Convert the PyTorch model to Paddle via tracing
        pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im])  # export
        # Save metadata as YAML
        yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
        # Return the export directory and None
        return f, None

    @try_export
    def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
        """YOLOv8 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow."""
        # Import TensorFlow and the variable-freezing helper
        import tensorflow as tf  # noqa
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2  # noqa

        # Log export start with the TensorFlow version
        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
        # Output file path: replace the suffix with '.pb'
        f = self.file.with_suffix(".pb")

        # Wrap the Keras model in a tf.function
        m = tf.function(lambda x: keras_model(x))  # full model
        # Get a concrete function so it can be converted to constants
        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
        # Freeze variables into constants in the graph
        frozen_func = convert_variables_to_constants_v2(m)
        frozen_func.graph.as_graph_def()
        # Write the frozen GraphDef to disk
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
        # Return the exported file path and None
        return f, None

    @try_export
    def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
        """YOLOv8 TensorFlow Lite export."""
        # BUG https://github.com/ultralytics/ultralytics/issues/13436
        import tensorflow as tf  # noqa

        # Log export start with the TensorFlow version
        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
        # SavedModel directory: replace the suffix with '_saved_model'
        saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model"))
        # Pick the TFLite file that matches the requested precision
        if self.args.int8:
            f = saved_model / f"{self.file.stem}_int8.tflite"  # fp32 in/out
        elif self.args.half:
            f = saved_model / f"{self.file.stem}_float16.tflite"  # fp32 in/out
        else:
            f = saved_model / f"{self.file.stem}_float32.tflite"
        # Return the TFLite file path and None
        return str(f), None
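
Once conversion has produced the *.tflite file, it can be exercised with TensorFlow's built-in interpreter. A minimal sketch (the model path is an assumption based on the naming scheme above):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="yolov8n_saved_model/yolov8n_float32.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
interpreter.set_tensor(inp["index"], np.zeros(inp["shape"], dtype=np.float32))
interpreter.invoke()
out = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])
print(out.shape)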

    @try_export
    def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")):
        """YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/."""
        # Warn about a known Edge TPU bug
        LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185")

        # Command that reports the Edge TPU compiler version
        cmd = "edgetpu_compiler --version"
        help_url = "https://coral.ai/docs/edgetpu/compiler/"
        # Export is only supported on Linux
        assert LINUX, f"export only supported on Linux. See {help_url}"

        # A non-zero return code means the compiler is not installed
        if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
            # Attempt to install the Edge TPU compiler from the help URL
            LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}")
            # Check whether sudo is available on the system
            sudo = subprocess.run("sudo --version >/dev/null", shell=True).returncode == 0  # sudo installed on system
            # Run the install commands, dropping 'sudo ' when it is unavailable
            for c in (
                "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -",
                'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | '
                "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list",
                "sudo apt-get update",
                "sudo apt-get install edgetpu-compiler",
            ):
                subprocess.run(c if sudo else c.replace("sudo ", ""), shell=True, check=True)

        # Read the compiler version
        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

        # Log export start with the compiler version
        LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")

        # Output file name for the Edge TPU model
        f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model

        # Build, log, and run the compiler command
        cmd = f'edgetpu_compiler -s -d -k 10 --out_dir "{Path(f).parent}" "{tflite_model}"'
        LOGGER.info(f"{prefix} running '{cmd}'")
        subprocess.run(cmd, shell=True)

        # Attach TFLite metadata to the compiled model
        self._add_tflite_metadata(f)

        # Return the Edge TPU model path and None
        return f, None
    @try_export
    def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
        """YOLOv8 TensorFlow.js export."""
        # Verify tensorflowjs is installed
        check_requirements("tensorflowjs")
        # On ARM64, pin numpy to fix a TF.js export bug
        if ARM64:
            check_requirements("numpy==1.23.5")
        import tensorflow as tf
        import tensorflowjs as tfjs  # noqa

        # Log export start with the tensorflowjs version
        LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")

        # Output directory: replace the suffix with '_web_model'
        f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
        # Path of the frozen *.pb graph
        f_pb = str(self.file.with_suffix(".pb"))  # *.pb path

        # Read the model's GraphDef into a fresh TF graph
        gd = tf.Graph().as_graph_def()  # TF GraphDef
        with open(f_pb, "rb") as file:
            gd.ParseFromString(file.read())

        # Comma-separated output node names
        outputs = ",".join(gd_outputs(gd))
        LOGGER.info(f"\n{prefix} output node names: {outputs}")

        # Pick quantization flags from the half/int8 options
        quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""

        # The converter cannot handle spaces in paths, so work in space-free temp paths
        with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
            # Build and run the tensorflowjs converter command
            cmd = (
                "tensorflowjs_converter "
                f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
            )
            LOGGER.info(f"{prefix} running '{cmd}'")
            subprocess.run(cmd, shell=True)

        # Warn if the output path contains spaces
        if " " in f:
            LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.")

        # Save metadata.yaml into the export directory
        yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
        # Return the export directory and None
        return f, None
    def _add_tflite_metadata(self, file):
        """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
        import flatbuffers  # FlatBuffers serialization library

        try:
            # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845
            from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema
            from tensorflow_lite_support.metadata.python import metadata
        except ImportError:  # ARM64 systems may be missing the 'tensorflow_lite_support' package
            from tflite_support import metadata
            from tflite_support import metadata_schema_py_generated as schema

        # Create the model metadata object
        model_meta = schema.ModelMetadataT()

        # Name, version, author, and license from the export metadata
        model_meta.name = self.metadata["description"]
        model_meta.version = self.metadata["version"]
        model_meta.author = self.metadata["author"]
        model_meta.license = self.metadata["license"]

        # Write the metadata dict to a temporary label file
        tmp_file = Path(file).parent / "temp_meta.txt"
        with open(tmp_file, "w") as f:
            f.write(str(self.metadata))

        # Associated file object pointing at the label file
        label_file = schema.AssociatedFileT()
        label_file.name = tmp_file.name
        label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS

        # Input tensor metadata (RGB image)
        input_meta = schema.TensorMetadataT()
        input_meta.name = "image"
        input_meta.description = "Input image to be detected."
        input_meta.content = schema.ContentT()
        input_meta.content.contentProperties = schema.ImagePropertiesT()
        input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB
        input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties

        # Output tensor metadata
        output1 = schema.TensorMetadataT()
        output1.name = "output"
        output1.description = "Coordinates of detected objects, class labels, and confidence score"
        output1.associatedFiles = [label_file]

        # Segmentation models get a second output (mask protos)
        if self.model.task == "segment":
            output2 = schema.TensorMetadataT()
            output2.name = "output"
            output2.description = "Mask protos"
            output2.associatedFiles = [label_file]

        # Subgraph metadata tying inputs and outputs together
        subgraph = schema.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [input_meta]
        subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
        model_meta.subgraphMetadata = [subgraph]

        # Build the FlatBuffer and pack the metadata with the TFLite identifier
        b = flatbuffers.Builder(0)
        b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        metadata_buf = b.Output()

        # Populate the model file with the metadata and the associated label file
        populator = metadata.MetadataPopulator.with_model_file(str(file))
        populator.load_metadata_buffer(metadata_buf)
        populator.load_associated_files([str(tmp_file)])
        populator.populate()

        # Remove the temporary label file
        tmp_file.unlink()
    # Append a callback to the list for a given event
    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    # Run all callbacks registered for a given event
    def run_callbacks(self, event: str):
        """Execute all callbacks for a given event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

class IOSDetectModel(torch.nn.Module):
    """Wraps an Ultralytics YOLO model for export to Apple iOS CoreML format."""

    def __init__(self, model, im):
        """Initialize IOSDetectModel with a YOLO model and an example image."""
        super().__init__()
        _, _, h, w = im.shape  # batch size, channels, height, width of the example image
        self.model = model  # the wrapped YOLO model
        self.nc = len(model.names)  # number of classes
        if w == h:
            self.normalize = 1.0 / w  # scalar normalization for square images
        else:
            self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])  # per-axis tensor otherwise

    def forward(self, x):
        """Normalize the detector's predictions for the input size."""
        xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
        return cls, xywh * self.normalize  # class scores and normalized box coordinates
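
The normalization above maps pixel-space xywh boxes into the 0-1 range CoreML's detection pipeline expects: a scalar 1/w suffices for square inputs, while rectangular inputs need the per-axis tensor. A tiny numeric sketch of that scaling (the 640x480 shape is hypothetical):

import torch

w, h = 640, 480
normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h])
xywh = torch.tensor([[320.0, 240.0, 64.0, 48.0]])  # one box in pixels
print(xywh * normalize)  # tensor([[0.5000, 0.5000, 0.1000, 0.1000]])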

.\yolov8\ultralytics\engine\model.py

# inspect: examine live objects and their signatures
import inspect
# Path: object-oriented handling of file and directory paths
from pathlib import Path
# List and Union for type annotations
from typing import List, Union

# numpy: multi-dimensional arrays and matrix operations
import numpy as np
# torch: building and running deep learning models
import torch

# TASK2DATA mapping plus config and save-dir helpers
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
# Results: container for model prediction and validation outputs
from ultralytics.engine.results import Results
# Ultralytics HUB web root and training-session class
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
# Model loading, task inference, the nn namespace, and the YAML model loader
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
# Assorted utilities: ARGV, ASSETS, DEFAULT_CFG_DICT, LOGGER, and friends
from ultralytics.utils import (
    ARGV,
    ASSETS,
    DEFAULT_CFG_DICT,
    LOGGER,
    RANK,
    SETTINGS,
    callbacks,
    checks,
    emojis,
    yaml_load,
)

# The Model base class for all YOLO models, subclassing nn.Module
class Model(nn.Module):
    """
    A base class for implementing YOLO models, unifying APIs across different model types.

    This class provides a common interface for various operations related to YOLO models, such as training,
    validation, prediction, exporting, and benchmarking. It handles different types of models, including those
    loaded from local files, Ultralytics HUB, or Triton Server.

    Attributes:
        callbacks (Dict): A dictionary of callback functions for various events during model operations.
        predictor (BasePredictor): The predictor object used for making predictions.
        model (nn.Module): The underlying PyTorch model.
        trainer (BaseTrainer): The trainer object used for training the model.
        ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
        cfg (str): The configuration of the model if loaded from a *.yaml file.
        ckpt_path (str): The path to the checkpoint file.
        overrides (Dict): A dictionary of overrides for model configuration.
        metrics (Dict): The latest training/validation metrics.
        session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
        task (str): The type of task the model is intended for.
        model_name (str): The name of the model.
    """

    def __init__(
        self,
        model: Union[str, Path] = "yolov8n.pt",  # path or name of the model weights
        task: str = None,  # task type (detect, classify, ...); inferred when None
        verbose: bool = False,  # print verbose output during setup
    ) -> None:
        """
        Initializes a new instance of the YOLO model class.

        This constructor sets up the model based on the provided model path or name. It handles various types of
        model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
        initializes several important attributes of the model and prepares it for operations like training,
        prediction, or export.

        Args:
            model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
                model name from Ultralytics HUB, or a Triton Server model.
            task (str | None): The task type associated with the YOLO model, specifying its application domain.
            verbose (bool): If True, enables verbose output during the model's initialization and subsequent
                operations.

        Raises:
            FileNotFoundError: If the specified model file does not exist or is inaccessible.
            ValueError: If the model file or configuration is invalid or unsupported.
            ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.

        Examples:
            >>> model = Model("yolov8n.pt")
            >>> model = Model("path/to/model.yaml", task="detect")
            >>> model = Model("hub_model", verbose=True)
        """
        super().__init__()
        self.callbacks = callbacks.get_default_callbacks()  # default callback functions
        self.predictor = None  # predictor object, created lazily
        self.model = None  # underlying model object
        self.trainer = None  # trainer object
        self.ckpt = None  # checkpoint dict when loaded from a *.pt file
        self.cfg = None  # config path when loaded from a *.yaml file
        self.ckpt_path = None  # checkpoint file path
        self.overrides = {}  # overrides passed to the trainer
        self.metrics = None  # validation/training metrics
        self.session = None  # HUB session object
        self.task = task  # task type of the YOLO model
        model = str(model).strip()

        # Check if it is an Ultralytics HUB model (from https://hub.ultralytics.com)
        if self.is_hub_model(model):
            # Fetch the model from HUB
            checks.check_requirements("hub-sdk>=0.0.8")
            self.session = HUBTrainingSession.create_session(model)
            model = self.session.model_file

        # Check if it is a Triton Server model
        elif self.is_triton_model(model):
            self.model_name = self.model = model
            return

        # Load or create a new YOLO model
        if Path(model).suffix in {".yaml", ".yml"}:
            self._new(model, task=task, verbose=verbose)  # build a new model from a YAML definition
        else:
            self._load(model, task=task)  # load an existing model

    def __call__(
        self,
        source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
        stream: bool = False,
        **kwargs,
    ) -> list:
        """
        Alias for the predict method, enabling the model instance to be callable for predictions.

        This method simplifies the process of making predictions by allowing the model instance to be called
        directly with the required arguments.

        Args:
            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
                the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
                tensor, or a list/tuple of these.
            stream (bool): If True, treat the input source as a continuous stream for predictions.
            **kwargs (Any): Additional keyword arguments to configure the prediction process.

        Returns:
            (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
                Results object.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> results = model('https://ultralytics.com/images/bus.jpg')
            >>> for r in results:
            ...     print(f"Detected {len(r)} objects in image")
        """
        # Alias for predict(): run the prediction and return the results list
        return self.predict(source, stream, **kwargs)

    @staticmethod
    def is_triton_model(model: str) -> bool:
        """
        Checks if the given model string is a Triton Server URL.

        This static method determines whether the provided model string represents a valid Triton Server URL by
        parsing its components using urllib.parse.urlsplit().

        Args:
            model (str): The model string to be checked.

        Returns:
            (bool): True if the model string is a valid Triton Server URL, False otherwise.

        Examples:
            >>> Model.is_triton_model('http://localhost:8000/v2/models/yolov8n')
            True
            >>> Model.is_triton_model('yolov8n.pt')
            False
        """
        # Parse the model string with urlsplit() and check its scheme, netloc, and path
        from urllib.parse import urlsplit

        url = urlsplit(model)
        return url.netloc and url.path and url.scheme in {"http", "grpc"}

    @staticmethod
    def is_hub_model(model: str) -> bool:
        """
        Check if the provided model is an Ultralytics HUB model.

        This static method determines whether the given model string represents a valid Ultralytics HUB model
        identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
        or a standalone model ID.

        Args:
            model (str): The model identifier to check. This can be a URL, an API key and model ID
                combination, or a standalone model ID.

        Returns:
            (bool): True if the model is a valid Ultralytics HUB model, False otherwise.

        Examples:
            >>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
            True
            >>> Model.is_hub_model("api_key_example_model_id")
            True
            >>> Model.is_hub_model("example_model_id")
            True
            >>> Model.is_hub_model("not_a_hub_model.pt")
            False
        """
        return any(
            (
                model.startswith(f"{HUB_WEB_ROOT}/models/"),  # Check if model starts with HUB_WEB_ROOT URL
                [len(x) for x in model.split("_")] == [42, 20],  # Check if model is in APIKEY_MODEL format
                len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"),  # Check if model is a standalone MODEL ID
            )
        )
    def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
        """
        Initializes a new model and infers the task type from the model definition.

        This method creates a new model instance based on the provided configuration file. It loads the model
        configuration, infers the task type if not specified, and initializes the model using the appropriate
        class from the task map.

        Args:
            cfg (str): Path to the model configuration file in YAML format.
            task (str | None): The specific task for the model. If None, it is inferred from the config.
            model (torch.nn.Module | None): A custom model instance. If provided, it is used instead of
                creating a new one.
            verbose (bool): If True, displays model information during loading.

        Raises:
            ValueError: If the configuration file is invalid or the task cannot be inferred.
            ImportError: If the required dependencies for the specified task are not installed.

        Examples:
            >>> model = Model()
            >>> model._new('yolov8n.yaml', task='detect', verbose=True)
        """
        # Load the YAML model configuration into a dict
        cfg_dict = yaml_model_load(cfg)
        # Remember the config path
        self.cfg = cfg
        # Use the given task, or infer it from the config dict
        self.task = task or guess_model_task(cfg_dict)
        # Build the model (or use the provided instance); verbose only on the main rank
        self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1)  # build model
        # Record model and task in the overrides so they survive a YAML export
        self.overrides["model"] = self.cfg
        self.overrides["task"] = self.task

        # Below: allow exporting from YAML files
        # Combine default args with the overrides into the model's args (model args take priority)
        self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
        # Store the task on the model itself
        self.model.task = self.task
        # Record the config file name as the model name
        self.model_name = cfg
    # Load model weights from a checkpoint file or initialize from a weights file
    def _load(self, weights: str, task=None) -> None:
        """
        Loads a model from a checkpoint file or initializes it from a weights file.

        This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
        up the model, task, and related attributes based on the loaded weights.

        Args:
            weights (str): Path to the model weights file to be loaded.
            task (str | None): The task associated with the model. If None, it will be inferred from the model.

        Raises:
            FileNotFoundError: If the specified weights file does not exist or is inaccessible.
            ValueError: If the weights file format is unsupported or invalid.

        Examples:
            >>> model = Model()
            >>> model._load('yolov8n.pt')
            >>> model._load('path/to/weights.pth', task='detect')
        """
        # If the path is a URL, download the file locally first
        if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
            weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"])  # download and return local file

        # Normalize the file name suffix, e.g. yolov8n -> yolov8n.pt
        weights = checks.check_model_file_from_stem(weights)  # add suffix, i.e. yolov8n -> yolov8n.pt

        # For .pt files, load the checkpoint and derive task and args from it
        if Path(weights).suffix == ".pt":
            self.model, self.ckpt = attempt_load_one_weight(weights)
            self.task = self.model.args["task"]
            self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
            self.ckpt_path = self.model.pt_path
        else:
            # For other formats, verify the file exists and set the task
            weights = checks.check_file(weights)  # runs in all cases, not redundant with above call
            self.model, self.ckpt = weights, None
            self.task = task or guess_model_task(weights)
            self.ckpt_path = weights

        # Record the model and task in the overrides
        self.overrides["model"] = weights
        self.overrides["task"] = self.task
        self.model_name = weights
    def _check_is_pytorch_model(self) -> None:
        """
        Checks if the model is a PyTorch model and raises a TypeError if it's not.

        This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
        certain operations that require a PyTorch model are only performed on compatible model types.

        Raises:
            TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
                information about supported model formats and operations.

        Examples:
            >>> model = Model("yolov8n.pt")
            >>> model._check_is_pytorch_model()  # No error raised
            >>> model = Model("yolov8n.onnx")
            >>> model._check_is_pytorch_model()  # Raises TypeError
        """
        # Is the model a '.pt' path string or an nn.Module instance?
        pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
        pt_module = isinstance(self.model, nn.Module)
        # If it is neither, raise a TypeError
        if not (pt_module or pt_str):
            raise TypeError(
                f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
                f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
                f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
                f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
                f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
            )

    def reset_weights(self) -> "Model":
        """
        Resets the model's weights to their initial state.

        This method iterates through all modules in the model and resets their parameters if they have a
        'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
        enabling them to be updated during training.

        Returns:
            (Model): The instance of the class with reset weights.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = Model('yolov8n.pt')
            >>> model.reset_weights()
        """
        # Ensure the model is a PyTorch model
        self._check_is_pytorch_model()
        # Reset parameters on every module that supports it
        for m in self.model.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()
        # Make all parameters trainable again
        for p in self.model.parameters():
            p.requires_grad = True
        # Return the instance with reset weights
        return self
    def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model":
        """
        Loads parameters from the specified weights file into the model.

        This method supports loading weights from a file or directly from a weights object. It matches parameters by
        name and shape and transfers them to the model.

        Args:
            weights (Union[str, Path]): Path to the weights file or a weights object.

        Returns:
            (Model): The instance of the class with loaded weights.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = Model()
            >>> model.load('yolov8n.pt')
            >>> model.load(Path('path/to/weights.pt'))
        """
        # Ensure the model is a PyTorch model
        self._check_is_pytorch_model()
        # If given a path, load the single weights file first
        if isinstance(weights, (str, Path)):
            weights, self.ckpt = attempt_load_one_weight(weights)
        # Transfer matching parameters into the model
        self.model.load(weights)
        # Return self to allow method chaining
        return self

    def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None:
        """
        Saves the current model state to a file.

        This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
        the date, Ultralytics version, license information, and a link to the documentation.

        Args:
            filename (Union[str, Path]): The name of the file to save the model to.
            use_dill (bool): Whether to try using dill for serialization if available.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = Model('yolov8n.pt')
            >>> model.save('my_model.pt')
        """
        # Ensure the model is a PyTorch model
        self._check_is_pytorch_model()
        # Imports needed for the saved metadata
        from copy import deepcopy
        from datetime import datetime
        from ultralytics import __version__

        # Metadata stored alongside the checkpoint
        updates = {
            "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 License (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        }
        # Save the checkpoint merged with the metadata
        torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill)
    def info(self, detailed: bool = False, verbose: bool = True):
        """
        Logs or returns model information.

        This method provides an overview or detailed information about the model, depending on the arguments
        passed. It can control the verbosity of the output and return the information as a list.

        Args:
            detailed (bool): If True, shows detailed information about the model layers and parameters.
            verbose (bool): If True, prints the information. If False, returns the information as a list.

        Returns:
            (List[str]): A list of strings containing various types of information about the model, including
                model summary, layer details, and parameter counts. Empty if verbose is True.

        Raises:
            TypeError: If the model is not a PyTorch model.

        Examples:
            >>> model = Model('yolov8n.pt')
            >>> model.info()  # Prints model summary
            >>> info_list = model.info(detailed=True, verbose=False)  # Returns detailed info as a list
        """
        # Ensure the model is a PyTorch model, else raise TypeError
        self._check_is_pytorch_model()
        # Delegate to the underlying model's info() with the given flags
        return self.model.info(detailed=detailed, verbose=verbose)

    def fuse(self):
        """
        Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.

        This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
        into a single layer. This fusion can significantly improve inference speed by reducing the number of
        operations and memory accesses required during forward passes.

        The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
        bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
        performs both convolution and normalization in one step.

        Raises:
            TypeError: If the model is not a PyTorch nn.Module.

        Examples:
            >>> model = Model("yolov8n.pt")
            >>> model.fuse()
            >>> # Model is now fused and ready for optimized inference
        """
        # Ensure the model is a PyTorch nn.Module, else raise TypeError
        self._check_is_pytorch_model()
        # Fuse Conv2d and BatchNorm2d layers in place
        self.model.fuse()

    def embed(
        self,
        source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
        stream: bool = False,
        **kwargs,
    ) -> list:
        """
        Generates image embeddings based on the provided source.

        This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
        source. It allows customization of the embedding process through various keyword arguments.

        Args:
            source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
                generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
            stream (bool): If True, predictions are streamed.
            **kwargs (Any): Additional keyword arguments for configuring the embedding process.

        Returns:
            (List[torch.Tensor]): A list containing the image embeddings.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> image = 'https://ultralytics.com/images/bus.jpg'
            >>> embeddings = model.embed(image)
            >>> print(embeddings[0].shape)
        """
        # Default to embedding the second-to-last layer if no indices are passed
        if not kwargs.get("embed"):
            kwargs["embed"] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
        # Delegate to predict() for prediction and embedding generation
        return self.predict(source, stream, **kwargs)

    def predict(
        self,
        source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
        stream: bool = False,
        predictor=None,
        **kwargs,
    ):
        """
        Perform prediction based on the provided source.

        Args:
            source (str | Path | int | list | tuple | np.ndarray | torch.Tensor): The source of the input data.
            stream (bool): If True, predictions are streamed.
            predictor (Optional): Custom predictor function.
            **kwargs (Any): Additional keyword arguments for prediction.

        Returns:
            (Any): The prediction result based on the input source.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> image = 'https://ultralytics.com/images/bus.jpg'
            >>> prediction = model.predict(image)
        """
        # 实现具体的预测逻辑,根据不同的输入源和参数进行预测
        # 这里未提供具体的实现细节,但是假设这个方法会根据输入源和参数返回预测结果
        pass

    def track(
        self,
        source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
        stream: bool = False,
        persist: bool = False,
        **kwargs,
    ) -> List[Results]:
        """
        Conducts object tracking on the specified input source using the registered trackers.

        This method performs object tracking using the model's predictors and optionally registered trackers. It handles
        various input sources such as file paths or video streams, and supports customization through keyword arguments.
        The method registers trackers if not already present and can persist them between calls.

        Args:
            source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
                tracking. Can be a file path, URL, or video stream.
            stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
            persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
            **kwargs (Any): Additional keyword arguments for configuring the tracking process.

        Returns:
            (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object.

        Raises:
            AttributeError: If the predictor does not have registered trackers.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> results = model.track(source='path/to/video.mp4', show=True)
            >>> for r in results:
            ...     print(r.boxes.id)  # print tracking IDs

        Notes:
            - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
            - The tracking mode is explicitly set in the keyword arguments.
            - Batch size is set to 1 for tracking in videos.
        """
        # Register trackers on the predictor if none are present yet
        if not hasattr(self.predictor, "trackers"):
            from ultralytics.trackers import register_tracker

            register_tracker(self, persist)

        # Default confidence threshold of 0.1 for ByteTrack-based tracking
        kwargs["conf"] = kwargs.get("conf") or 0.1
        # Batch size 1 for tracking in videos
        kwargs["batch"] = kwargs.get("batch") or 1
        # Explicitly set tracking mode
        kwargs["mode"] = "track"

        # Run prediction in tracking mode and return the list of results
        return self.predict(source=source, stream=stream, **kwargs)
    def val(
        self,
        validator=None,
        **kwargs,
    ):
        """
        Validates the model using a specified dataset and validation configuration.

        This method streamlines model validation, allowing customization through various settings. It supports
        validation with a custom validator or the default validation approach. The method combines default
        configurations, method-specific defaults, and user-provided arguments to configure the validation
        process.

        Args:
            validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator
                class for validating the model.
            **kwargs (Any): Arbitrary keyword arguments for customizing the validation process.

        Returns:
            (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> results = model.val(data='coco128.yaml', imgsz=640)
            >>> print(results.box.map)  # print mAP50-95
        """
        custom = {"rect": True}  # method defaults
        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # rightmost arguments take priority

        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
        validator(model=self.model)  # run validation on the model
        self.metrics = validator.metrics  # store the validator's metrics on this instance
        return validator.metrics

    def benchmark(
        self,
        **kwargs,
    ):
        """
        Benchmarks the model across various export formats to evaluate performance.

        This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
        It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
        configured using a combination of default configuration values, model-specific arguments, method-specific
        defaults, and any additional user-provided keyword arguments.

        Args:
            **kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
                default configurations, model-specific arguments, and method defaults. Common options include:
                - data (str): Path to the dataset for benchmarking.
                - imgsz (int | List[int]): Image size for benchmarking.
                - half (bool): Whether to use half-precision (FP16) mode.
                - int8 (bool): Whether to use int8 precision mode.
                - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
                - verbose (bool): Whether to print detailed benchmark information.

        Returns:
            (Dict): A dictionary containing the results of the benchmarking process, including metrics for
                different export formats.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> results = model.benchmark(data='coco8.yaml', imgsz=640, half=True)
            >>> print(results)
        """
        self._check_is_pytorch_model()
        # Importing benchmark function from ultralytics.utils.benchmarks module
        from ultralytics.utils.benchmarks import benchmark

        custom = {"verbose": False}  # method defaults
        # Combine default configurations, model-specific arguments, method defaults, and user-provided kwargs
        args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"}
        # Call the benchmark function with specified parameters
        return benchmark(
            model=self,
            data=kwargs.get("data"),  # if no 'data' argument passed set data=None for default datasets
            imgsz=args["imgsz"],
            half=args["half"],
            int8=args["int8"],
            device=args["device"],
            verbose=kwargs.get("verbose"),
        )
    def export(
        self,
        **kwargs,
    ) -> str:
        """
        Exports the model to a different format suitable for deployment.

        This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
        purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
        defaults, and any additional arguments provided.

        Args:
            **kwargs (Dict): Arbitrary keyword arguments to customize the export process. These are combined with
                the model's overrides and method defaults. Common arguments include:
                format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
                half (bool): Export model in half-precision.
                int8 (bool): Export model in int8 precision.
                device (str): Device to run the export on.
                workspace (int): Maximum memory workspace size for TensorRT engines.
                nms (bool): Add Non-Maximum Suppression (NMS) module to model.
                simplify (bool): Simplify ONNX model.

        Returns:
            (str): The path to the exported model file.

        Raises:
            AssertionError: If the model is not a PyTorch model.
            ValueError: If an unsupported export format is specified.
            RuntimeError: If the export process fails due to errors.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> model.export(format='onnx', dynamic=True, simplify=True)
            'path/to/exported/model.onnx'
        """
        # Ensure the model is a PyTorch model
        self._check_is_pytorch_model()
        # The Exporter class drives the export process
        from .exporter import Exporter

        # Method defaults for export
        custom = {
            "imgsz": self.model.args["imgsz"],
            "batch": 1,
            "data": None,
            "device": None,  # reset to avoid multi-GPU errors
            "verbose": False,
        }  # method defaults
        # Merge all arguments; rightmost take priority
        args = {**self.overrides, **custom, **kwargs, "mode": "export"}  # highest priority args on the right
        # Build an Exporter and run it on the model
        return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

    def train(
        self,
        trainer=None,
        **kwargs,
    ):
        """
        Placeholder for the training method of the model.
        This method is typically implemented to train the model on a dataset.

        Args:
            trainer (object): Trainer object for model training.
            **kwargs (Dict): Additional keyword arguments for training customization.

        Returns:
            None
        """
        pass

    def tune(
        self,
        use_ray=False,
        iterations=10,
        *args,
        **kwargs,
    ):
        """
        Placeholder for the tuning method of the model.
        This method is typically implemented to tune hyperparameters or architecture.

        Args:
            use_ray (bool): Flag indicating whether to use Ray for distributed tuning.
            iterations (int): Number of tuning iterations.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            None
        """
        pass
    ):
        """
        执行模型的超参数调优,支持使用 Ray Tune 进行调优。

        该方法支持两种超参数调优模式:使用 Ray Tune 或自定义调优方法。
        当启用 Ray Tune 时,它利用 ultralytics.utils.tuner 模块中的 'run_ray_tune' 函数。
        否则,它使用内部的 'Tuner' 类进行调优。该方法结合了默认值、重写值和自定义参数来配置调优过程。

        Args:
            use_ray (bool): 如果为 True,则使用 Ray Tune 进行超参数调优。默认为 False。
            iterations (int): 执行调优的迭代次数。默认为 10。
            *args (List): 可变长度的参数列表,用于传递额外的位置参数。
            **kwargs (Dict): 任意关键字参数。这些参数与模型的重写参数和默认参数合并。

        Returns:
            (Dict): 包含超参数搜索结果的字典。

        Raises:
            AssertionError: 如果模型不是 PyTorch 模型。

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> results = model.tune(use_ray=True, iterations=20)
            >>> print(results)
        """
        # 检查当前实例是否为 PyTorch 模型,否则抛出断言错误
        self._check_is_pytorch_model()
        
        # 根据 use_ray 参数选择调优方式
        if use_ray:
            # 如果 use_ray 为 True,则从 ultralytics.utils.tuner 导入 run_ray_tune 函数
            from ultralytics.utils.tuner import run_ray_tune
            # 调用 run_ray_tune 函数执行调优,传递模型实例、最大样本数、其他位置参数和关键字参数
            return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
        else:
            # 如果不使用 Ray Tune,则从当前目录中的 tuner 模块导入 Tuner 类
            from .tuner import Tuner
            # 准备用于调优的参数字典 args,包括默认值、重写值、自定义值和额外的关键字参数
            custom = {}  # 自定义方法默认值
            args = {**self.overrides, **custom, **kwargs, "mode": "train"}  # 最右边的参数具有最高优先级
            # 创建 Tuner 实例并调用,传递模型实例、迭代次数和回调函数列表
            return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
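
    # Usage sketch (illustrative): the internal Tuner path, without Ray. The 'data'
    # kwarg here is an assumption about a typical tuning configuration:
    #   >>> model = YOLO("yolov8n.pt")
    #   >>> results = model.tune(data="coco8.yaml", iterations=5)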
    def _apply(self, fn) -> "Model":
        """
        Applies a function to model tensors that are not parameters or registered buffers.

        This method extends the functionality of the parent class's _apply method by additionally resetting the
        predictor and updating the device in the model's overrides. It's typically used for operations like
        moving the model to a different device or changing its precision.

        Args:
            fn (Callable): A function to be applied to the model's tensors. This is typically a method like
                to(), cpu(), cuda(), half(), or float().

        Returns:
            (Model): The model instance with the function applied and updated attributes.

        Raises:
            AssertionError: If the model is not a PyTorch model.

        Examples:
            >>> model = Model("yolov8n.pt")
            >>> model = model._apply(lambda t: t.cuda())  # Move model to GPU
        """
        # Check that the current object is a PyTorch model
        self._check_is_pytorch_model()
        # Call the parent class's _apply method with the given function fn
        self = super()._apply(fn)  # noqa
        # Reset the predictor, since the device may have changed
        self.predictor = None
        # Record the model's device in the overrides dict
        self.overrides["device"] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
        return self

    @property
    def names(self) -> list:
        """
        Retrieves the class names associated with the loaded model.

        This property returns the class names if they are defined in the model. It checks the class names for validity
        using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
        initialized, it sets it up before retrieving the names.

        Returns:
            (List[str]): A list of class names associated with the model.

        Raises:
            AttributeError: If the model or predictor does not have a 'names' attribute.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> print(model.names)
            ['person', 'bicycle', 'car', ...]
        """
        from ultralytics.nn.autobackend import check_class_names

        # If the model object has a 'names' attribute, return the validated class names
        if hasattr(self.model, "names"):
            return check_class_names(self.model.names)
        # Initialize the predictor if it has not been set up yet
        if not self.predictor:  # export formats will not have predictor defined until predict() is called
            self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks)
            self.predictor.setup_model(model=self.model, verbose=False)
        # Return the class names from the predictor's model
        return self.predictor.model.names

    @property
    def device(self) -> torch.device:
        """
        Retrieves the device on which the model's parameters are allocated.

        This property retrieves and returns the device (CPU or GPU) where the model's parameters are currently stored.
        It checks if the model is an instance of nn.Module and returns the device of the first parameter found.

        Returns:
            (torch.device): The device (CPU/GPU) of the model.

        Raises:
            AttributeError: If the model is not a PyTorch nn.Module instance.

        Examples:
            >>> model = YOLO("yolov8n.pt")
            >>> print(model.device)
            device(type='cuda', index=0)  # if CUDA is available
            >>> model = model.to("cpu")
            >>> print(model.device)
            device(type='cpu')
        """
        return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None

    @property
    def transforms(self):
        """
        Retrieves the transformations applied to the input data of the loaded model.

        This property returns the transformation object if defined in the model. These transformations typically include
        preprocessing steps like resizing, normalization, and data augmentation applied to input data before feeding it
        into the model.

        Returns:
            (object | None): The transform object of the model if available, otherwise None.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> transforms = model.transforms
            >>> if transforms:
            ...     print(f"Model transforms: {transforms}")
            ... else:
            ...     print("No transforms defined for this model.")
        """
        return self.model.transforms if hasattr(self.model, "transforms") else None

    def add_callback(self, event: str, func) -> None:
        """
        Adds a callback function for a specified event.

        This method allows registering custom callback functions that are triggered on specific events during
        model operations such as training or inference. Callbacks provide a way to extend and customize the
        behavior of the model at various stages of its lifecycle.

        Args:
            event (str): The name of the event to attach the callback to. Must be a valid event name recognized
                by the Ultralytics framework.
            func (Callable): The callback function to be registered. This function will be called when the
                specified event occurs.

        Raises:
            ValueError: If the event name is not recognized or is invalid.

        Examples:
            >>> def on_train_start(trainer):
            ...     print("Training is starting!")
            >>> model = YOLO('yolov8n.pt')
            >>> model.add_callback("on_train_start", on_train_start)
            >>> model.train(data='coco128.yaml', epochs=1)
        """
        self.callbacks[event].append(func)
    def clear_callback(self, event: str) -> None:
        """
        Clears all callback functions registered for a specified event.

        This method removes all custom and default callback functions associated with the given event.
        It resets the callback list for the specified event to an empty list, effectively removing all
        registered callbacks for that event.

        Args:
            event (str): The name of the event to clear callbacks for. This should be a valid event name
                recognized by the Ultralytics callback system.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> model.add_callback('on_train_start', lambda: print('Training started'))
            >>> model.clear_callback('on_train_start')
            >>> # All callbacks for 'on_train_start' are now removed

        Notes:
            - This method affects both custom callbacks added by the user and default callbacks provided
              by the Ultralytics framework.
            - After calling this method, no callbacks will be executed for the specified event until new
              ones are added.
            - Use with caution, as it removes all callbacks, including critical ones that may be required
              for certain operations to function properly.
        """
        self.callbacks[event] = []

    def reset_callbacks(self) -> None:
        """
        Resets all callbacks to their default functions.

        This method restores the default callback functions for all events, removing any custom callbacks
        that were previously added. It iterates through all default callback events and replaces the
        current callbacks with the defaults.

        The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains
        predefined functions for various events in the model lifecycle, such as on_train_start,
        on_epoch_end, etc.

        This method is useful when you want to revert to the original set of callbacks after making custom
        modifications, ensuring consistent behavior across different runs or experiments.

        Examples:
            >>> model = YOLO('yolov8n.pt')
            >>> model.add_callback('on_train_start', custom_function)
            >>> model.reset_callbacks()
            # All callbacks are now reset to their default functions
        """
        for event in callbacks.default_callbacks.keys():
            self.callbacks[event] = [callbacks.default_callbacks[event][0]]

    @staticmethod
    def _reset_ckpt_args(args: dict) -> dict:
        """
        Resets specific arguments when loading a PyTorch model checkpoint.

        This static method filters the input arguments dictionary to retain only a specific set of keys that are
        considered important for model loading. It's used to ensure that only relevant arguments are preserved
        when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.

        Args:
            args (dict): A dictionary containing various model arguments and settings.

        Returns:
            (dict): A new dictionary containing only the specified include keys from the input arguments.

        Examples:
            >>> original_args = {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect', 'batch': 16, 'epochs': 100}
            >>> reset_args = Model._reset_ckpt_args(original_args)
            >>> print(reset_args)
            {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
        """
        include = {"imgsz", "data", "task", "single_cls"}  # only remember these arguments when loading a PyTorch model
        return {k: v for k, v in args.items() if k in include}

    # def __getattr__(self, attr):
    #    """Raises error if object has no requested attribute."""
    #    name = self.__class__.__name__
    #    raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

    def _smart_load(self, key: str):
        """
        Loads the appropriate module based on the model task.

        This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
        based on the current task of the model and the provided key. It uses the task_map attribute to determine
        the correct module to load.

        Args:
            key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.

        Returns:
            (object): The loaded module corresponding to the specified key and current task.

        Raises:
            NotImplementedError: If the specified key is not supported for the current task.

        Examples:
            >>> model = Model(task='detect')
            >>> predictor = model._smart_load('predictor')
            >>> trainer = model._smart_load('trainer')

        Notes:
            - This method is typically used internally by other methods of the Model class.
            - The task_map attribute should be properly initialized with the correct mappings for each task.
        """
        try:
            return self.task_map[self.task][key]
        except Exception as e:
            name = self.__class__.__name__
            mode = inspect.stack()[1][3]  # get the function name.
            raise NotImplementedError(
                emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.")
            ) from e

    @property
    def task_map(self) -> dict:
        """
        Provides a mapping from model tasks to corresponding classes for different modes.

        This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
        to a nested dictionary. The nested dictionary contains mappings for different operational modes
        (model, trainer, validator, predictor) to their respective class implementations.

        The mapping allows for dynamic loading of appropriate classes based on the model's task and the
        desired operational mode. This facilitates a flexible and extensible architecture for handling
        various tasks and modes within the Ultralytics framework.

        Returns:
            (Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
            nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
            'predictor', mapping to their respective class implementations.

        Examples:
            >>> model = Model()
            >>> task_map = model.task_map
            >>> detect_class_map = task_map['detect']
            >>> segment_class_map = task_map['segment']

        Note:
            The actual implementation of this method may vary depending on the specific tasks and
            classes supported by the Ultralytics framework. The docstring provides a general
            description of the expected behavior and structure.
        """
        # A concrete model class must supply its own task map
        raise NotImplementedError("Please provide task map for your model!")
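
    # Example (illustrative): a concrete subclass overrides task_map so that _smart_load
    # can resolve task-specific components. The class names below refer to the detect-task
    # classes in ultralytics.models.yolo.detect; treat this as a sketch rather than the
    # exact upstream implementation:
    #
    #     @property
    #     def task_map(self):
    #         from ultralytics.models.yolo.detect import DetectionPredictor, DetectionTrainer, DetectionValidator
    #
    #         return {
    #             "detect": {
    #                 "model": DetectionModel,
    #                 "trainer": DetectionTrainer,
    #                 "validator": DetectionValidator,
    #                 "predictor": DetectionPredictor,
    #             }
    #         }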

.\yolov8\ultralytics\engine\predictor.py

# Import required libraries and modules
import platform  # platform information
import re  # regular expressions
import threading  # threading support
from pathlib import Path  # filesystem path handling

import cv2  # OpenCV
import numpy as np  # NumPy
import torch  # PyTorch

# Import functions and classes from the Ultralytics library
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode

# Multi-line string containing the streaming warning message
STREAM_WARNING = """
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
"""
    for r in results:
        # 从结果列表中依次取出每个结果对象 r

        boxes = r.boxes  # 获取结果对象 r 中的 boxes 属性,用于包围框的输出
        masks = r.masks  # 获取结果对象 r 中的 masks 属性,用于分割掩模的输出
        probs = r.probs  # 获取结果对象 r 中的 probs 属性,用于分类输出的类别概率
    """
    BasePredictor.

    A base class for creating predictors.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes the BasePredictor class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        """
        # Get the configuration and initialize the predictor arguments
        self.args = get_cfg(cfg, overrides)
        # Resolve the directory where results will be saved
        self.save_dir = get_save_dir(self.args)
        # If no conf value is given in the config, fall back to the default of 0.25
        if self.args.conf is None:
            self.args.conf = 0.25  # default conf=0.25
        # Mark warmup as not yet completed
        self.done_warmup = False
        # If args.show is True, verify that image display is supported
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Variables initialized here and populated once setup is complete
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        # Use the provided callbacks, or fall back to the default callbacks
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        # Lock used for thread-safe inference
        self._lock = threading.Lock()
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        # If the input is not already a tensor, apply the pre-transforms
        if not_tensor:
            # Stack the pre-transformed images into a single array
            im = np.stack(self.pre_transform(im))
            # Convert BGR to RGB and BHWC to BCHW, i.e. (n, 3, h, w)
            im = im[..., ::-1].transpose((0, 3, 1, 2))
            # Make the array contiguous in memory
            im = np.ascontiguousarray(im)
            # Convert to a PyTorch tensor
            im = torch.from_numpy(im)

        # Move the image tensor to the target device
        im = im.to(self.device)
        # If the model runs in fp16, cast the input to half precision
        im = im.half() if self.model.fp16 else im.float()
        # If the input was not a tensor, rescale pixel values from 0-255 to 0.0-1.0
        if not_tensor:
            im /= 255
        return im
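
    # e.g. (illustrative): two 480x640 BGR frames become a (2, 3, h, w) float tensor
    # scaled to [0.0, 1.0]:
    #   >>> ims = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]
    #   >>> x = predictor.preprocess(ims)  # assumes predictor.model is already set up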
    def inference(self, im, *args, **kwargs):
        """Runs inference on a given image using the specified model and arguments."""
        # Decide whether visualization output is needed and, if so, create the save directory
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and (not self.source_type.tensor)
            else False
        )
        # Run the model, forwarding augment, visualize, embed, and any extra arguments
        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (list): A list of transformed images.
        """
        # Check whether all input images share the same shape
        same_shapes = len({x.shape for x in im}) == 1
        # Create a LetterBox transform so image sizes match what the model expects
        letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
        # Return the list of transformed images
        return [letterbox(image=x) for x in im]
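
    # Note (illustrative): LetterBox resizes while preserving aspect ratio and pads to a
    # stride-aligned shape, so a (720, 1280, 3) frame at imgsz=640 typically becomes
    # something like (384, 640, 3) rather than being stretched to a square.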

    def postprocess(self, preds, img, orig_imgs):
        """Post-processes predictions for an image and returns them."""
        # Currently just returns the predictions as-is; more elaborate post-processing can be added here later
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """Performs inference on an image or stream."""
        # Determine from the arguments whether to run streaming inference
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            # For non-streaming inference, collect the results into a single list
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Result into one
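
    # Usage sketch (illustrative; assumes a concrete predictor with a model configured):
    #   >>> results = predictor(source="bus.jpg")                 # list of Results
    #   >>> for r in predictor(source="video.mp4", stream=True):  # generator of Results
    #   ...     pass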

    def predict_cli(self, source=None, model=None):
        """
        Method used for Command Line Interface (CLI) prediction.

        This function is designed to run predictions using the CLI. It sets up the source and model, then processes
        the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
        generator without storing results.

        Note:
            Do not modify this function or remove the generator. The generator ensures that no outputs are
            accumulated in memory, which is critical for preventing memory issues during long-running predictions.
        """
        # Consume the streaming-inference generator one item at a time so that outputs never accumulate in memory
        gen = self.stream_inference(source, model)
        for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
            pass
    def setup_source(self, source):
        """
        Sets up source and inference mode.
        """
        # Check and resolve the input image size
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size

        # Set up data transforms based on the task (classification only)
        self.transforms = (
            getattr(
                self.model.model,
                "transforms",
                classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction),
            )
            if self.args.task == "classify"
            else None
        )

        # Load the inference dataset
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
        )

        # Record the source type
        self.source_type = self.dataset.source_type

        # Warn when not streaming but the source is a stream or screenshot, contains more
        # than 1000 images, or includes videos
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)  # emit the streaming warning

        # Initialize the video writers
        self.vid_writer = {}

    @smart_inference_mode()
    def setup_model(self, model, verbose=True):
        """
        Initialize YOLO model with given parameters and set it to evaluation mode.
        """
        # 使用给定参数初始化 YOLO 模型,并设置为评估模式
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )

        # 更新设备信息
        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        
        # 设置模型为评估模式
        self.model.eval()
    def write_results(self, i, p, im, s):
        """Write inference results to a file or directory."""
        string = ""  # 用于存储输出字符串

        # 如果图像是三维的,扩展成四维(针对批处理维度)
        if len(im.shape) == 3:
            im = im[None]

        # 如果数据源是流、图像或张量,则添加序号和帧数信息
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:
            string += f"{i}: "  # 输出结果序号
            frame = self.dataset.count
        else:
            # 从字符串 s[i] 中提取帧数信息
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # 如果未确定帧数,则默认为0

        # 设置保存结果的文本路径
        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))

        # 构建输出字符串,包括图像尺寸和推理结果的详细信息和速度
        string += "%gx%g " % im.shape[2:]
        result = self.results[i]
        result.save_dir = self.save_dir.__str__()  # 在其他位置可能会用到保存目录的字符串表示
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # 如果需要保存或展示结果图像
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # 如果需要保存为文本文件
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)

        # 如果需要保存裁剪的结果
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)

        # 如果需要展示结果
        if self.args.show:
            self.show(str(p))

        # 如果需要保存预测的图像
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        # 返回生成的字符串
        return string
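
    # e.g. (illustrative): for a stream source, the returned string might look like
    #   "0: 640x480 2 persons, 1 bus, 12.3ms"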
    def save_predicted_images(self, save_path="", frame=0):
        """Save video predictions as mp4 at specified path."""
        # Grab the plotted image to save
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            # Determine the frame rate from the dataset mode
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            # Build the path used for saving individual frames
            frames_path = f'{save_path.split(".", 1)[0]}_frames/'
            # If this path has no video writer yet, create a new video file
            if save_path not in self.vid_writer:  # new video
                # If individual frames are to be saved, create the frames directory
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                # Choose the file suffix and codec based on the operating system
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                # Create the video writer
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,  # requires an integer; floats produce errors with the MP4 codec
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Write the image to the video file
            self.vid_writer[save_path].write(im)
            # Also save the individual frame if requested
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{frame}.jpg", im)

        # Save a single image
        else:
            cv2.imwrite(save_path, im)

    def show(self, p=""):
        """Display an image in a window using OpenCV imshow()."""
        # Grab the image to display
        im = self.plotted_img
        # On Linux, create a new window if this window name is not already tracked
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resizing (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        # Display the image
        cv2.imshow(p, im)
        # Wait for a key press; the delay depends on whether the dataset mode is image or video
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 1 millisecond

    def run_callbacks(self, event: str):
        """Runs all registered callbacks for a specific event."""
        # Iterate over all callbacks registered for the event and invoke each in turn
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add callback."""
        # Append the new callback to the list of callbacks for the event
        self.callbacks[event].append(func)
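
# Usage sketch (illustrative): register a per-batch callback on a predictor instance.
# "on_predict_batch_end" is one of the predictor events defined in ultralytics.utils.callbacks.
#
#     def report_progress(predictor):
#         print(f"seen {predictor.seen} images so far")
#
#     predictor.add_callback("on_predict_batch_end", report_progress)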