Yolov8-源码解析-四十四-

Yolov8 源码解析(四十四)

.\yolov8\ultralytics\utils\triton.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# 引入必要的类型
from typing import List
# 解析 URL 的组件
from urllib.parse import urlsplit

# 引入 NumPy 库
import numpy as np


class TritonRemoteModel:
    """
    用于与远程 Triton 推理服务器模型进行交互的客户端类。

    Attributes:
        endpoint (str): Triton 服务器上模型的名称。
        url (str): Triton 服务器的 URL。
        triton_client: Triton 客户端(可以是 HTTP 或 gRPC)。
        InferInput: Triton 客户端的输入类。
        InferRequestedOutput: Triton 客户端的输出请求类。
        input_formats (List[str]): 模型输入的数据类型。
        np_input_formats (List[type]): 模型输入的 NumPy 数据类型。
        input_names (List[str]): 模型输入的名称列表。
        output_names (List[str]): 模型输出的名称列表。
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        初始化 TritonRemoteModel。

        参数可以单独提供,也可以从形如 <scheme>://<netloc>/<endpoint>/<task_name> 的 'url' 参数中解析。

        Args:
            url (str): Triton 服务器的 URL。
            endpoint (str): Triton 服务器上模型的名称。
            scheme (str): 通信协议('http' 或 'grpc')。
        """
        if not endpoint and not scheme:  # 从 URL 字符串中解析所有参数
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # 根据通信协议选择 Triton 客户端
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # 按字母顺序对输出名称进行排序,例如 'output0', 'output1' 等。
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # 定义模型属性
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
    # 定义一个特殊方法 __call__,允许将实例对象像函数一样调用,接受多个 numpy 数组作为输入,并返回多个 numpy 数组作为输出
    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Args:
            *inputs (List[np.ndarray]): Input data to the model.

        Returns:
            (List[np.ndarray]): Model outputs.
        """
        # 初始化一个空列表,用于存储推理输入
        infer_inputs = []
        # 获取输入数组的数据类型,假设所有输入数组的数据类型相同
        input_format = inputs[0].dtype
        # 遍历所有输入数组
        for i, x in enumerate(inputs):
            # 如果当前输入数组的数据类型与预期的输入数据类型不匹配,将其转换为预期的数据类型
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            # 创建一个推理输入对象,指定输入名称、形状和数据类型
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            # 将 numpy 数组的数据复制到推理输入对象中
            infer_input.set_data_from_numpy(x)
            # 将推理输入对象添加到推理输入列表中
            infer_inputs.append(infer_input)

        # 根据输出名称列表创建推理输出对象列表
        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        # 调用 Triton 客户端进行推理,传入模型名称、输入列表和输出列表,并获取推理结果
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        # 将每个输出结果转换回原始输入数据类型,并存储在列表中返回
        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]

.\yolov8\ultralytics\utils\tuner.py

# 导入 subprocess 模块,用于执行外部命令
import subprocess

# 从 ultralytics 的 cfg 模块中导入 TASK2DATA、TASK2METRIC、get_save_dir 函数
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
# 从 ultralytics 的 utils 模块中导入 DEFAULT_CFG、DEFAULT_CFG_DICT、LOGGER、NUM_THREADS、checks 函数
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks


# 定义函数 run_ray_tune,用于运行 Ray Tune 进行超参数调优
def run_ray_tune(
    model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):
    """
    Runs hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. Defaults to None.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.
        max_samples (int, optional): The maximum number of trials to run. Defaults to 10.
        train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.

    Returns:
        (dict): A dictionary containing the results of the hyperparameter search.

    Example:
        ```python
        from ultralytics import YOLO

        # Load a YOLOv8n model
        model = YOLO('yolov8n.pt')

        # Start tuning hyperparameters for YOLOv8n training on the COCO8 dataset
        result_grid = model.tune(data='coco8.yaml', use_ray=True)
        ```
    """

    # 在日志中输出一条信息,引导用户了解 RayTune 的文档链接
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    
    # 如果 train_args 为 None,则将其设为一个空字典
    if train_args is None:
        train_args = {}

    try:
        # 尝试安装 Ray Tune 相关依赖
        subprocess.run("pip install ray[tune]".split(), check=True)  # do not add single quotes here

        # 导入必要的 Ray Tune 模块
        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        # 若导入失败,则抛出 ModuleNotFoundError 异常
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')

    try:
        # 尝试导入 wandb 模块,并确保其有 __version__ 属性
        import wandb
        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        # 若导入失败或缺少 __version__ 属性,则将 wandb 设为 False
        wandb = False

    # 使用 checks 模块检查 ray 的版本是否符合要求
    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    # 默认的超参数搜索空间,包含各种超参数的取值范围或概率分布
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-1),  # 初始学习率范围
        "lrf": tune.uniform(0.01, 1.0),  # 最终OneCycleLR学习率 (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD动量/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # 优化器权重衰减
        "warmup_epochs": tune.uniform(0.0, 5.0),  # 预热 epochs 数量(允许小数)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # 预热初始动量
        "box": tune.uniform(0.02, 0.2),  # 盒损失增益
        "cls": tune.uniform(0.2, 4.0),  # 分类损失增益(与像素缩放比例相关)
        "hsv_h": tune.uniform(0.0, 0.1),  # 图像HSV-Hue增强(比例)
        "hsv_s": tune.uniform(0.0, 0.9),  # 图像HSV-Saturation增强(比例)
        "hsv_v": tune.uniform(0.0, 0.9),  # 图像HSV-Value增强(比例)
        "degrees": tune.uniform(0.0, 45.0),  # 图像旋转角度范围(+/- 度)
        "translate": tune.uniform(0.0, 0.9),  # 图像平移范围(+/- 比例)
        "scale": tune.uniform(0.0, 0.9),  # 图像缩放范围(+/- 增益)
        "shear": tune.uniform(0.0, 10.0),  # 图像剪切范围(+/- 度)
        "perspective": tune.uniform(0.0, 0.001),  # 图像透视变换范围(+/- 比例),范围 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # 图像上下翻转概率
        "fliplr": tune.uniform(0.0, 1.0),  # 图像左右翻转概率
        "bgr": tune.uniform(0.0, 1.0),  # 图像通道BGR转换概率
        "mosaic": tune.uniform(0.0, 1.0),  # 图像混合(mosaic)概率
        "mixup": tune.uniform(0.0, 1.0),  # 图像混合(mixup)概率
        "copy_paste": tune.uniform(0.0, 1.0),  # 分割复制粘贴概率
    }

    # 将模型放入Ray存储
    task = model.task
    model_in_store = ray.put(model)

    def _tune(config):
        """
        使用指定的超参数和额外的参数训练YOLO模型。

        Args:
            config (dict): 用于训练的超参数字典。

        Returns:
            dict: 训练结果字典。
        """
        model_to_train = ray.get(model_in_store)  # 从Ray存储中获取要调整的模型
        model_to_train.reset_callbacks()  # 重置回调函数
        config.update(train_args)  # 更新训练参数
        results = model_to_train.train(**config)  # 使用给定的超参数进行训练
        return results.results_dict  # 返回训练结果字典

    # 获取搜索空间
    if not space:
        space = default_space  # 如果未提供搜索空间,则使用默认搜索空间
        LOGGER.warning("WARNING ⚠️ 没有提供搜索空间,使用默认搜索空间。")

    # 获取数据集
    data = train_args.get("data", TASK2DATA[task])  # 获取数据集
    space["data"] = data  # 将数据集添加到搜索空间
    if "data" not in train_args:
        LOGGER.warning(f'WARNING ⚠️ 没有提供数据集,使用默认数据集 "data={data}".')

    # 定义带有分配资源的可训练函数
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
    # 定义 ASHA 调度器用于超参数搜索
    asha_scheduler = ASHAScheduler(
        time_attr="epoch",  # 指定时间属性为 epoch
        metric=TASK2METRIC[task],  # 设置优化的指标为任务对应的指标
        mode="max",  # 设定优化模式为最大化
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,  # 设置最大的训练轮数
        grace_period=grace_period,  # 设置优雅期间
        reduction_factor=3,  # 设置收敛因子
    )

    # 定义用于超参数搜索的回调函数
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

    # 创建 Ray Tune 的超参数搜索调谐器
    tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve()  # 获取保存调优结果的绝对路径,确保目录存在
    tune_dir.mkdir(parents=True, exist_ok=True)  # 创建保存调优结果的目录,如果不存在则创建
    tuner = tune.Tuner(
        trainable_with_resources,  # 可训练函数
        param_space=space,  # 参数空间
        tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),  # 调优配置
        run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),  # 运行配置
    )

    # 运行超参数搜索
    tuner.fit()

    # 返回超参数搜索的结果
    return tuner.get_results()

.\yolov8\ultralytics\utils\__init__.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import importlib.metadata
import inspect
import logging.config
import os
import platform
import re
import subprocess
import sys
import threading
import time
import urllib
import uuid
from pathlib import Path
from types import SimpleNamespace
from typing import Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from tqdm import tqdm as tqdm_original

from ultralytics import __version__

# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv("RANK", -1))  # 获取环境变量 RANK 的整数值,默认为 -1
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # 获取环境变量 LOCAL_RANK 的整数值,默认为 -1,用于 PyTorch Elastic运行

# Other Constants
ARGV = sys.argv or ["", ""]  # 获取命令行参数列表,若为空则初始化为包含两个空字符串的列表
FILE = Path(__file__).resolve()  # 获取当前脚本文件的绝对路径
ROOT = FILE.parents[1]  # 获取当前脚本文件的父目录的父目录路径,即 YOLO 的根目录
ASSETS = ROOT / "assets"  # 默认图像文件目录的路径
DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"  # 默认配置文件路径
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # YOLO 多进程线程数,至少为 1,最多为 CPU 核心数减 1
AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true"  # 全局自动安装模式,默认为 True
VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true"  # 全局详细模式,默认为 True
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None  # tqdm 进度条显示格式,如果详细模式开启则使用指定格式,否则为 None
LOGGING_NAME = "ultralytics"  # 日志记录器名称
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"])  # 操作系统类型的布尔值
ARM64 = platform.machine() in {"arm64", "aarch64"}  # ARM64 架构的布尔值
PYTHON_VERSION = platform.python_version()  # Python 版本号
TORCH_VERSION = torch.__version__  # PyTorch 版本号
TORCHVISION_VERSION = importlib.metadata.version("torchvision")  # torchvision 的版本号,比直接导入速度更快
HELP_MSG = """
    Usage examples for running YOLOv8:

    1. Install the ultralytics package:

        pip install ultralytics

    2. Use the Python SDK:

        from ultralytics import YOLO

        # Load a model
        model = YOLO('yolov8n.yaml')  # 从头开始构建一个新的模型
        model = YOLO("yolov8n.pt")  # 加载预训练模型(推荐用于训练)

        # Use the model
        results = model.train(data="coco8.yaml", epochs=3)  # 训练模型
        results = model.val()  # 在验证集上评估模型性能
        results = model('https://ultralytics.com/images/bus.jpg')  # 对图像进行预测
        success = model.export(format='onnx')  # 将模型导出为 ONNX 格式
"""
    # 使用命令行界面 (CLI):
    
    YOLOv8 的 'yolo' CLI 命令遵循以下语法:
    
        yolo TASK MODE ARGS
    
        其中   TASK (可选) 可以是 [detect, segment, classify] 中的一个
              MODE (必需) 可以是 [train, val, predict, export] 中的一个
              ARGS (可选) 是任意数量的自定义 'arg=value' 对,如 'imgsz=320',用于覆盖默认设置。
                  可以在 https://docs.ultralytics.com/usage/cfg 或通过 'yolo cfg' 查看所有 ARGS。
    
    - 训练一个检测模型,使用 coco8.yaml 数据集,模型为 yolov8n.pt,训练 10 个 epochs,初始学习率为 0.01
        yolo detect train data=coco8.yaml model=yolov8n.pt epochs=10 lr0=0.01
    
    - 使用预训练的分割模型预测 YouTube 视频,图像尺寸为 320:
        yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320
    
    - 在批量大小为 1 和图像尺寸为 640 的情况下,验证预训练的检测模型:
        yolo detect val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=640
    
    - 将 YOLOv8n 分类模型导出为 ONNX 格式,图像尺寸为 224x128 (不需要 TASK 参数)
        yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
    
    - 运行特殊命令:
        yolo help
        yolo checks
        yolo version
        yolo settings
        yolo copy-cfg
        yolo cfg
    
    文档链接:https://docs.ultralytics.com
    社区链接:https://community.ultralytics.com
    GitHub 仓库链接:https://github.com/ultralytics/ultralytics
# 设置和环境变量

# 设置 Torch 的打印选项,包括行宽、精度和默认配置文件
torch.set_printoptions(linewidth=320, precision=4, profile="default")

# 设置 NumPy 的打印选项,包括行宽和浮点数格式
np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format})  # format short g, %precision=5

# 设置 OpenCV 的线程数为 0,以防止与 PyTorch DataLoader 的多线程不兼容
cv2.setNumThreads(0)

# 设置 NumExpr 的最大线程数为 NUM_THREADS
os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS)

# 设置 CUBLAS 的工作空间配置为 ":4096:8",用于确定性训练
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# 设置 TensorFlow 的最小日志级别为 "3",以在 Colab 中抑制冗长的 TF 编译器警告
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# 设置 Torch 的 C++ 日志级别为 "ERROR",以抑制 "NNPACK.cpp could not initialize NNPACK" 警告
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"

# 设置 Kineto 的日志级别为 "5",以在计算 FLOPs 时抑制冗长的 PyTorch 分析器输出
os.environ["KINETO_LOG_LEVEL"] = "5"


class TQDM(tqdm_original):
    """
    自定义的 Ultralytics tqdm 类,具有不同的默认参数设置。

    Args:
        *args (list): 传递给原始 tqdm 的位置参数。
        **kwargs (any): 关键字参数,应用自定义默认值。
    """

    def __init__(self, *args, **kwargs):
        """
        初始化自定义的 Ultralytics tqdm 类,具有不同的默认参数设置。

        注意,这些参数在调用 TQDM 时仍然可以被覆盖。
        """
        kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)  # 逻辑 'and' 操作,使用默认值(如果传递)。
        kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)  # 如果传递,则覆盖默认值。
        super().__init__(*args, **kwargs)


class SimpleClass:
    """
    Ultralytics SimpleClass 是一个基类,提供更简单的调试和使用方法,包括有用的字符串表示、错误报告和属性访问方法。
    """

    def __str__(self):
        """返回对象的人类可读字符串表示形式。"""
        attr = []
        for a in dir(self):
            v = getattr(self, a)
            if not callable(v) and not a.startswith("_"):
                if isinstance(v, SimpleClass):
                    # 对于子类,仅显示模块和类名
                    s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
                else:
                    s = f"{a}: {repr(v)}"
                attr.append(s)
        return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)

    def __repr__(self):
        """返回对象的机器可读字符串表示形式。"""
        return self.__str__()

    def __getattr__(self, attr):
        """自定义属性访问错误消息,提供有用的信息。"""
        name = self.__class__.__name__
        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")


class IterableSimpleNamespace(SimpleNamespace):
    """
    Ultralytics IterableSimpleNamespace 是 SimpleNamespace 的扩展类,添加了可迭代功能并支持与 dict() 和 for 循环一起使用。
    """
    # 返回一个迭代器,迭代命名空间的属性键值对
    def __iter__(self):
        """Return an iterator of key-value pairs from the namespace's attributes."""
        return iter(vars(self).items())

    # 返回对象的可读字符串表示,每行显示属性名=属性值
    def __str__(self):
        """Return a human-readable string representation of the object."""
        return "\n".join(f"{k}={v}" for k, v in vars(self).items())

    # 自定义属性访问错误消息,提供帮助信息
    def __getattr__(self, attr):
        """Custom attribute access error message with helpful information."""
        name = self.__class__.__name__
        raise AttributeError(
            f"""
            '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
            'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
            {DEFAULT_CFG_PATH} with the latest version from
            https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
            """
        )

    # 返回指定键的值,如果不存在则返回默认值
    def get(self, key, default=None):
        """Return the value of the specified key if it exists; otherwise, return the default value."""
        return getattr(self, key, default)
# 定义一个函数 plt_settings,用作装饰器,临时设置绘图函数的 rc 参数和后端

def plt_settings(rcparams=None, backend="Agg"):
    """
    Decorator to temporarily set rc parameters and the backend for a plotting function.

    Example:
        decorator: @plt_settings({"font.size": 12})
        context manager: with plt_settings({"font.size": 12}):

    Args:
        rcparams (dict): Dictionary of rc parameters to set.
        backend (str, optional): Name of the backend to use. Defaults to 'Agg'.

    Returns:
        (Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be
            applied to any function that needs to have specific matplotlib rc parameters and backend for its execution.
    """

    # 如果未提供 rcparams,则使用默认字典设置 {"font.size": 11}
    if rcparams is None:
        rcparams = {"font.size": 11}

    # 定义装饰器函数 decorator
    def decorator(func):
        """Decorator to apply temporary rc parameters and backend to a function."""

        # 定义 wrapper 函数,用于设置 rc 参数和后端,调用原始函数,并恢复设置
        def wrapper(*args, **kwargs):
            """Sets rc parameters and backend, calls the original function, and restores the settings."""
            # 获取当前的后端
            original_backend = plt.get_backend()
            # 如果指定的后端与当前不同,则关闭所有图形,并切换到指定的后端
            if backend.lower() != original_backend.lower():
                plt.close("all")  # auto-close()ing of figures upon backend switching is deprecated since 3.8
                plt.switch_backend(backend)

            # 使用指定的 rc 参数上下文管理器
            with plt.rc_context(rcparams):
                result = func(*args, **kwargs)

            # 如果使用了不同的后端,则关闭所有图形,并恢复原始后端
            if backend != original_backend:
                plt.close("all")
                plt.switch_backend(original_backend)

            return result

        return wrapper

    return decorator


# 定义函数 set_logging,为给定名称设置日志记录,支持 UTF-8 编码,并确保在不同环境中的兼容性
def set_logging(name="LOGGING_NAME", verbose=True):
    """Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across different
    environments.
    """
    # 根据 verbose 参数设置日志级别,多 GPU 训练中考虑到 RANK 的情况
    level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR  # rank in world for Multi-GPU trainings

    # 配置控制台(stdout)的编码为 UTF-8,以确保兼容性
    formatter = logging.Formatter("%(message)s")  # Default formatter
    # 如果在 Windows 环境下,并且 sys.stdout 具有 "encoding" 属性,并且编码不是 "utf-8"
    if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8":

        # 定义一个定制的日志格式化器 CustomFormatter
        class CustomFormatter(logging.Formatter):
            def format(self, record):
                """Sets up logging with UTF-8 encoding and configurable verbosity."""
                # 返回格式化后的日志记录,包括表情符号处理
                return emojis(super().format(record))

        try:
            # 尝试重新配置 stdout 使用 UTF-8 编码(如果支持)
            if hasattr(sys.stdout, "reconfigure"):
                sys.stdout.reconfigure(encoding="utf-8")
            # 对于不支持 reconfigure 的环境,用 TextIOWrapper 包装 stdout
            elif hasattr(sys.stdout, "buffer"):
                import io
                sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
            else:
                # 创建自定义格式化器以应对非 UTF-8 环境
                formatter = CustomFormatter("%(message)s")
        except Exception as e:
            # 如果出现异常,创建适应非 UTF-8 环境的自定义格式化器
            print(f"Creating custom formatter for non UTF-8 environments due to {e}")
            formatter = CustomFormatter("%(message)s")

    # 创建并配置流处理器 StreamHandler,使用适当的格式化器和日志级别
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    stream_handler.setLevel(level)

    # 设置日志记录器 logger
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(stream_handler)
    # 禁止日志传播到父记录器
    logger.propagate = False

    # 返回配置好的 logger 对象
    return logger
# 设置日志记录器
LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE)  # 在全局定义日志记录器,用于 train.py、val.py、predict.py 等
# 设置 "sentry_sdk" 和 "urllib3.connectionpool" 的日志级别为 CRITICAL + 1
for logger in "sentry_sdk", "urllib3.connectionpool":
    logging.getLogger(logger).setLevel(logging.CRITICAL + 1)


def emojis(string=""):
    """返回与平台相关的安全的字符串版本,支持表情符号。"""
    return string.encode().decode("ascii", "ignore") if WINDOWS else string


class ThreadingLocked:
    """
    用于确保函数或方法的线程安全执行的装饰器类。可以作为装饰器使用,以确保如果从多个线程调用装饰的函数,
    则只有一个线程能够执行该函数。

    Attributes:
        lock (threading.Lock): 管理对装饰函数访问的锁对象。

    Example:
        ```py
        from ultralytics.utils import ThreadingLocked

        @ThreadingLocked()
        def my_function():
            # 在此处编写代码
        ```
    """

    def __init__(self):
        """初始化装饰器类,用于函数或方法的线程安全执行。"""
        self.lock = threading.Lock()

    def __call__(self, f):
        """执行函数或方法的线程安全装饰器。"""
        from functools import wraps

        @wraps(f)
        def decorated(*args, **kwargs):
            """应用线程安全性到装饰的函数或方法。"""
            with self.lock:
                return f(*args, **kwargs)

        return decorated


def yaml_save(file="data.yaml", data=None, header=""):
    """
    将数据保存为 YAML 格式到文件中。

    Args:
        file (str, optional): 文件名。默认为 'data.yaml'。
        data (dict): 要保存的数据,以 YAML 格式。
        header (str, optional): 要添加的 YAML 头部。

    Returns:
        (None): 数据保存到指定文件中。
    """
    if data is None:
        data = {}
    file = Path(file)
    if not file.parent.exists():
        # 如果父目录不存在,则创建父目录
        file.parent.mkdir(parents=True, exist_ok=True)

    # 将路径对象转换为字符串
    valid_types = int, float, str, bool, list, tuple, dict, type(None)
    for k, v in data.items():
        if not isinstance(v, valid_types):
            data[k] = str(v)

    # 将数据以 YAML 格式写入文件
    with open(file, "w", errors="ignore", encoding="utf-8") as f:
        if header:
            f.write(header)
        yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)


def yaml_load(file="data.yaml", append_filename=False):
    """
    从文件中加载 YAML 格式的数据。

    Args:
        file (str, optional): 文件名。默认为 'data.yaml'。
        append_filename (bool): 是否将 YAML 文件名添加到 YAML 字典中。默认为 False。

    Returns:
        (dict): YAML 格式的数据和文件名。
    """
    assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()"
    # 使用指定的编码打开文件,忽略解码错误,返回文件对象 f
    with open(file, errors="ignore", encoding="utf-8") as f:
        s = f.read()  # 将文件内容读取为字符串 s

        # 如果字符串 s 中存在不可打印字符,则通过正则表达式去除特殊字符
        if not s.isprintable():
            s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)

        # 使用 yaml.safe_load() 加载字符串 s,转换为 Python 字典;若文件为空则返回空字典
        data = yaml.safe_load(s) or {}
        
        # 如果 append_filename 为真,则将文件名以字符串形式添加到字典 data 中
        if append_filename:
            data["yaml_file"] = str(file)
        
        # 返回加载后的数据字典 data
        return data
# 漂亮地打印一个 YAML 文件或 YAML 格式的字典
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
    # 如果 yaml_file 是字符串或 Path 对象,使用 yaml_load 加载 YAML 文件
    yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
    # 将 yaml_dict 转换成 YAML 格式的字符串,不排序键,允许 Unicode,宽度为无限大
    dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=float("inf"))
    # 记录打印信息到日志,包含文件名(加粗黑色),以及 YAML 格式的字符串
    LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")


# 默认配置字典,从 DEFAULT_CFG_PATH 加载 YAML 文件得到
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
# 将值为字符串且为 "None" 的项改为 None
for k, v in DEFAULT_CFG_DICT.items():
    if isinstance(v, str) and v.lower() == "none":
        DEFAULT_CFG_DICT[k] = None
# 获取默认配置字典的键集合
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
# 用 DEFAULT_CFG_DICT 创建一个 IterableSimpleNamespace 对象作为默认配置
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)


def read_device_model() -> str:
    """
    从系统中读取设备型号信息,并缓存以便快速访问。被 is_jetson() 和 is_raspberrypi() 使用。

    Returns:
        (str): 如果成功读取,返回设备型号文件内容,否则返回空字符串。
    """
    # 尝试打开 "/proc/device-tree/model" 文件,读取并返回其内容
    with contextlib.suppress(Exception):
        with open("/proc/device-tree/model") as f:
            return f.read()
    # 如果发生异常(如文件不存在),返回空字符串
    return ""


def is_ubuntu() -> bool:
    """
    检查当前操作系统是否为 Ubuntu。

    Returns:
        (bool): 如果操作系统是 Ubuntu,则返回 True,否则返回 False。
    """
    # 尝试打开 "/etc/os-release" 文件,检查是否包含 "ID=ubuntu" 的行
    with contextlib.suppress(FileNotFoundError):
        with open("/etc/os-release") as f:
            return "ID=ubuntu" in f.read()
    # 如果文件未找到或未包含 "ID=ubuntu" 行,则返回 False
    return False


def is_colab():
    """
    检查当前脚本是否运行在 Google Colab 笔记本环境中。

    Returns:
        (bool): 如果运行在 Colab 笔记本中,则返回 True,否则返回 False。
    """
    # 检查环境变量中是否包含 "COLAB_RELEASE_TAG" 或 "COLAB_BACKEND_VERSION"
    return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ


def is_kaggle():
    """
    检查当前脚本是否运行在 Kaggle 内核中。

    Returns:
        (bool): 如果运行在 Kaggle 内核中,则返回 True,否则返回 False。
    """
    # 检查当前工作目录是否为 "/kaggle/working",并且检查 KAGGLE_URL_BASE 是否为 "https://www.kaggle.com"
    return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"


def is_jupyter():
    """
    检查当前脚本是否运行在 Jupyter Notebook 中。在 Colab、Jupyterlab、Kaggle、Paperspace 环境中验证通过。

    Returns:
        (bool): 如果运行在 Jupyter Notebook 中,则返回 True,否则返回 False。
    """
    # 尝试导入 IPython 中的 get_ipython 函数,如果导入成功并返回不为 None,则说明在 Jupyter 环境中
    with contextlib.suppress(Exception):
        from IPython import get_ipython

        return get_ipython() is not None
    # 如果导入过程中发生异常,则返回 False
    return False


def is_docker() -> bool:
    """
    判断当前脚本是否运行在 Docker 容器中。

    Returns:
        (bool): 如果运行在 Docker 容器中,则返回 True,否则返回 False。
    """
    # 尝试打开 "/proc/self/cgroup" 文件,检查是否包含 "docker" 字符串
    with contextlib.suppress(Exception):
        with open("/proc/self/cgroup") as f:
            return "docker" in f.read()
    # 如果发生异常(如文件不存在),返回 False
    return False


def is_raspberrypi() -> bool:
    """
    判断当前设备是否为 Raspberry Pi。

    Returns:
        (bool): 如果设备为 Raspberry Pi,则返回 True,否则返回 False。
    """
    # 检查当前 Python 环境是否运行在树莓派上,通过检查设备模型信息
    # 返回值为布尔类型:如果在树莓派上运行则返回 True,否则返回 False
    """
    检查 Python 环境的设备模型信息中是否包含"Raspberry Pi"
    """
    return "Raspberry Pi" in PROC_DEVICE_MODEL
def is_jetson() -> bool:
    """
    Determines if the Python environment is running on a Jetson Nano or Jetson Orin device by checking the device model
    information.

    Returns:
        (bool): True if running on a Jetson Nano or Jetson Orin, False otherwise.
    """
    return "NVIDIA" in PROC_DEVICE_MODEL  # 检查 PROC_DEVICE_MODEL 是否包含 "NVIDIA",表示运行在 Jetson Nano 或 Jetson Orin


def is_online() -> bool:
    """
    Check internet connectivity by attempting to connect to a known online host.

    Returns:
        (bool): True if connection is successful, False otherwise.
    """
    with contextlib.suppress(Exception):
        assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true"  # 检查环境变量 YOLO_OFFLINE 是否为 "True"
        import socket

        for dns in ("1.1.1.1", "8.8.8.8"):  # 检查 Cloudflare 和 Google DNS
            socket.create_connection(address=(dns, 80), timeout=2.0).close()
            return True
    return False


def is_pip_package(filepath: str = __name__) -> bool:
    """
    Determines if the file at the given filepath is part of a pip package.

    Args:
        filepath (str): The filepath to check.

    Returns:
        (bool): True if the file is part of a pip package, False otherwise.
    """
    import importlib.util

    # 获取模块的规范
    spec = importlib.util.find_spec(filepath)

    # 返回规范不为 None 且 origin 不为 None(表示是一个包)
    return spec is not None and spec.origin is not None


def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
    """
    Check if a directory is writeable.

    Args:
        dir_path (str | Path): The path to the directory.

    Returns:
        (bool): True if the directory is writeable, False otherwise.
    """
    return os.access(str(dir_path), os.W_OK)


def is_pytest_running():
    """
    Determines whether pytest is currently running or not.

    Returns:
        (bool): True if pytest is running, False otherwise.
    """
    return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)


def is_github_action_running() -> bool:
    """
    Determine if the current environment is a GitHub Actions runner.

    Returns:
        (bool): True if the current environment is a GitHub Actions runner, False otherwise.
    """
    return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ


def get_git_dir():
    """
    Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
    the current file is not part of a git repository, returns None.

    Returns:
        (Path | None): Git root directory if found or None if not found.
    """
    for d in Path(__file__).parents:
        if (d / ".git").is_dir():
            return d


def is_git_dir():
    """
    Determines whether the current file is part of a git repository. If the current file is not part of a git
    repository, returns False.

    Returns:
        (bool): True if the current file is part of a git repository, False otherwise.
    """
    for d in Path(__file__).parents:
        if (d / ".git").is_dir():
            return True
    return False
    # 检查全局变量 GIT_DIR 是否为 None
    return GIT_DIR is not None
def get_git_origin_url():
    """
    Retrieves the origin URL of a git repository.

    Returns:
        (str | None): The origin URL of the git repository or None if not git directory.
    """
    # 检查当前是否在 Git 仓库中
    if IS_GIT_DIR:
        # 使用 subprocess 模块调用 git 命令获取远程仓库的 URL
        with contextlib.suppress(subprocess.CalledProcessError):
            origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
            # 将获取到的字节流解码成字符串,并去除首尾的空白字符
            return origin.decode().strip()


def get_git_branch():
    """
    Returns the current git branch name. If not in a git repository, returns None.

    Returns:
        (str | None): The current git branch name or None if not a git directory.
    """
    # 检查当前是否在 Git 仓库中
    if IS_GIT_DIR:
        # 使用 subprocess 模块调用 git 命令获取当前分支名
        with contextlib.suppress(subprocess.CalledProcessError):
            origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
            # 将获取到的字节流解码成字符串,并去除首尾的空白字符
            return origin.decode().strip()


def get_default_args(func):
    """
    Returns a dictionary of default arguments for a function.

    Args:
        func (callable): The function to inspect.

    Returns:
        (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.
    """
    # 使用 inspect 模块获取函数的签名信息
    signature = inspect.signature(func)
    # 构建并返回参数名与默认参数值组成的字典
    return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}


def get_ubuntu_version():
    """
    Retrieve the Ubuntu version if the OS is Ubuntu.

    Returns:
        (str): Ubuntu version or None if not an Ubuntu OS.
    """
    # 检查当前操作系统是否为 Ubuntu
    if is_ubuntu():
        # 尝试打开 /etc/os-release 文件并匹配版本号信息
        with contextlib.suppress(FileNotFoundError, AttributeError):
            with open("/etc/os-release") as f:
                # 使用正则表达式搜索并提取版本号
                return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]


def get_user_config_dir(sub_dir="Ultralytics"):
    """
    Return the appropriate config directory based on the environment operating system.

    Args:
        sub_dir (str): The name of the subdirectory to create.

    Returns:
        (Path): The path to the user config directory.
    """
    # 根据不同的操作系统返回相应的用户配置目录
    if WINDOWS:
        path = Path.home() / "AppData" / "Roaming" / sub_dir
    elif MACOS:  # macOS
        path = Path.home() / "Library" / "Application Support" / sub_dir
    elif LINUX:
        path = Path.home() / ".config" / sub_dir
    else:
        # 如果不支持当前操作系统,抛出 ValueError 异常
        raise ValueError(f"Unsupported operating system: {platform.system()}")

    # 对于 GCP 和 AWS Lambda,只有 /tmp 目录可写入
    if not is_dir_writeable(path.parent):
        # 如果父目录不可写,输出警告信息并使用备选路径 /tmp 或当前工作目录
        LOGGER.warning(
            f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD."
            "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path."
        )
        path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir

    # 如果路径不存在,则创建相应的子目录
    path.mkdir(parents=True, exist_ok=True)

    return path


# Define constants (required below)
PROC_DEVICE_MODEL = read_device_model()  # is_jetson() and is_raspberrypi() depend on this constant
# 检查当前是否在线
ONLINE = is_online()

# 检查当前是否在Google Colab环境中
IS_COLAB = is_colab()

# 检查当前是否在Docker容器中
IS_DOCKER = is_docker()

# 检查当前是否在NVIDIA Jetson设备上
IS_JETSON = is_jetson()

# 检查当前是否在Jupyter环境中
IS_JUPYTER = is_jupyter()

# 检查当前是否在Kaggle环境中
IS_KAGGLE = is_kaggle()

# 检查当前代码是否安装为pip包
IS_PIP_PACKAGE = is_pip_package()

# 检查当前是否在树莓派环境中
IS_RASPBERRYPI = is_raspberrypi()

# 获取当前Git仓库的目录
GIT_DIR = get_git_dir()

# 检查当前目录是否是Git仓库
IS_GIT_DIR = is_git_dir()

# 获取用户配置目录,默认为环境变量YOLO_CONFIG_DIR,或者使用系统默认配置目录
USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir())  # Ultralytics settings dir

# 设置配置文件路径为用户配置目录下的settings.yaml文件
SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"
    Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager.

    Examples:
        As a decorator:
        >>> @TryExcept(msg="Error occurred in func", verbose=True)
        >>> def func():
        >>>    # Function logic here
        >>>     pass

        As a context manager:
        >>> with TryExcept(msg="Error occurred in block", verbose=True):
        >>>     # Code block here
        >>>     pass
    """
    # 定义 TryExcept 类,用于处理异常,可以作为装饰器或上下文管理器使用

    def __init__(self, msg="", verbose=True):
        """Initialize TryExcept class with optional message and verbosity settings."""
        self.msg = msg
        self.verbose = verbose
        # 初始化 TryExcept 类,可以设置错误消息和详细输出选项

    def __enter__(self):
        """Executes when entering TryExcept context, initializes instance."""
        # 进入上下文管理器时执行的方法,初始化实例
        pass

    def __exit__(self, exc_type, value, traceback):
        """Defines behavior when exiting a 'with' block, prints error message if necessary."""
        # 定义退出上下文管理器时的行为,如果需要,打印错误消息
        if self.verbose and value:
            print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
            # 如果设置了详细输出并且出现了异常,打印错误消息
        return True
        # 返回 True 表示已经处理了异常,不会向上层抛出异常
class Retry(contextlib.ContextDecorator):
    """
    Retry class for function execution with exponential backoff.

    Can be used as a decorator to retry a function on exceptions, up to a specified number of times with an
    exponentially increasing delay between retries.

    Examples:
        Example usage as a decorator:
        >>> @Retry(times=3, delay=2)
        >>> def test_func():
        >>>     # Replace with function logic that may raise exceptions
        >>>     return True
    """

    def __init__(self, times=3, delay=2):
        """Initialize Retry class with specified number of retries and delay."""
        self.times = times  # 设置重试次数
        self.delay = delay  # 设置初始重试延迟时间
        self._attempts = 0  # 记录当前重试次数

    def __call__(self, func):
        """Decorator implementation for Retry with exponential backoff."""

        def wrapped_func(*args, **kwargs):
            """Applies retries to the decorated function or method."""
            self._attempts = 0  # 重试次数初始化为0
            while self._attempts < self.times:  # 循环直到达到最大重试次数
                try:
                    return func(*args, **kwargs)  # 调用被装饰的函数或方法
                except Exception as e:
                    self._attempts += 1  # 增加重试次数计数
                    print(f"Retry {self._attempts}/{self.times} failed: {e}")  # 打印重试失败信息
                    if self._attempts >= self.times:  # 如果达到最大重试次数
                        raise e  # 抛出异常
                    time.sleep(self.delay * (2**self._attempts))  # 按指数增加的延迟时间

        return wrapped_func


def threaded(func):
    """
    Multi-threads a target function by default and returns the thread or function result.

    Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed.
    """

    def wrapper(*args, **kwargs):
        """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
        if kwargs.pop("threaded", True):  # 如果未指定 'threaded' 参数或为 True
            thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)  # 创建线程对象
            thread.start()  # 启动线程
            return thread  # 返回线程对象
        else:
            return func(*args, **kwargs)  # 直接调用函数或方法

    return wrapper


def set_sentry():
    """
    Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and
    sync=True in settings. Run 'yolo settings' to see and update settings YAML file.

    Conditions required to send errors (ALL conditions must be met or no errors will be reported):
        - sentry_sdk package is installed
        - sync=True in YOLO settings
        - pytest is not running
        - running in a pip package installation
        - running in a non-git directory
        - running with rank -1 or 0
        - online environment
        - CLI used to run package (checked with 'yolo' as the name of the main CLI command)

    The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError
    exceptions and to exclude events with 'out of memory' in their exception message.
    """
    # 设置 Sentry 事件的自定义标签和用户信息
    def before_send(event, hint):
        """
        根据特定的异常类型和消息修改事件,然后发送给 Sentry。

        Args:
            event (dict): 包含错误信息的事件字典。
            hint (dict): 包含额外错误信息的字典。

        Returns:
            dict: 修改后的事件字典,如果不发送事件到 Sentry 则返回 None。
        """
        # 如果 hint 中包含异常信息
        if "exc_info" in hint:
            exc_type, exc_value, tb = hint["exc_info"]
            # 如果异常类型是 KeyboardInterrupt、FileNotFoundError,或者异常消息包含"out of memory"
            if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
                return None  # 不发送事件

        # 设置事件的标签
        event["tags"] = {
            "sys_argv": ARGV[0],  # 系统参数列表中的第一个参数
            "sys_argv_name": Path(ARGV[0]).name,  # 第一个参数的文件名部分
            "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",  # 安装来源是 git、pip 还是其他
            "os": ENVIRONMENT,  # 环境变量中的操作系统信息
        }
        return event  # 返回修改后的事件字典

    # 如果满足一系列条件,则配置 Sentry
    if (
        SETTINGS["sync"]  # 同步设置为 True
        and RANK in {-1, 0}  # 运行时的进程等级为 -1 或 0
        and Path(ARGV[0]).name == "yolo"  # 第一个参数的文件名为 "yolo"
        and not TESTS_RUNNING  # 没有正在运行的测试
        and ONLINE  # 处于联机状态
        and IS_PIP_PACKAGE  # 安装方式为 pip
        and not IS_GIT_DIR  # 不是从 git 安装
    ):
        # 如果 sentry_sdk 包未安装,则返回
        try:
            import sentry_sdk  # 导入 sentry_sdk 包
        except ImportError:
            return

        # 初始化 Sentry SDK
        sentry_sdk.init(
            dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",  # Sentry 项目的 DSN
            debug=False,  # 调试模式设为 False
            traces_sample_rate=1.0,  # 所有跟踪数据采样率设为 100%
            release=__version__,  # 使用当前应用版本号
            environment="production",  # 环境设置为生产环境
            before_send=before_send,  # 设置发送前的处理函数为 before_send
            ignore_errors=[KeyboardInterrupt, FileNotFoundError],  # 忽略的错误类型列表
        )
        # 设置 Sentry 用户信息,使用 SHA-256 匿名化的 UUID 哈希
        sentry_sdk.set_user({"id": SETTINGS["uuid"]})
class SettingsManager(dict):
    """
    Manages Ultralytics settings stored in a YAML file.

    Args:
        file (str | Path): Path to the Ultralytics settings YAML file. Default is USER_CONFIG_DIR / 'settings.yaml'.
        version (str): Settings version. In case of local version mismatch, new default settings will be saved.
    """
    def __init__(self, file=SETTINGS_YAML, version="0.0.4"):
        """
        Initialize the SettingsManager with default settings, load and validate current settings from the YAML
        file.
        """
        import copy  # 导入用于深拷贝对象的模块
        import hashlib  # 导入用于哈希计算的模块

        from ultralytics.utils.checks import check_version  # 导入版本检查函数
        from ultralytics.utils.torch_utils import torch_distributed_zero_first  # 导入分布式训练相关的函数

        root = GIT_DIR or Path()  # 设置根目录为环境变量GIT_DIR的值或当前路径
        datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve()  # 设置数据集根目录路径

        self.file = Path(file)  # 将传入的文件路径转换为Path对象
        self.version = version  # 设置版本号
        self.defaults = {  # 设置默认配置字典
            "settings_version": version,
            "datasets_dir": str(datasets_root / "datasets"),
            "weights_dir": str(root / "weights"),
            "runs_dir": str(root / "runs"),
            "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),  # 计算当前机器的唯一标识符
            "sync": True,
            "api_key": "",
            "openai_api_key": "",
            "clearml": True,  # 各种集成配置
            "comet": True,
            "dvc": True,
            "hub": True,
            "mlflow": True,
            "neptune": True,
            "raytune": True,
            "tensorboard": True,
            "wandb": True,
        }
        self.help_msg = (
            f"\nView settings with 'yolo settings' or at '{self.file}'"
            "\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. "
            "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings."
        )

        super().__init__(copy.deepcopy(self.defaults))  # 调用父类构造函数并使用深拷贝的默认配置

        with torch_distributed_zero_first(RANK):  # 在分布式环境中,仅主节点执行以下操作
            if not self.file.exists():  # 如果配置文件不存在,则保存默认配置
                self.save()

            self.load()  # 载入配置文件中的设置
            correct_keys = self.keys() == self.defaults.keys()  # 检查载入的配置键与默认配置键是否一致
            correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))  # 检查各项设置的类型是否与默认设置一致
            correct_version = check_version(self["settings_version"], self.version)  # 检查配置文件中的版本与当前版本是否一致
            if not (correct_keys and correct_types and correct_version):
                LOGGER.warning(
                    "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "
                    f"with your settings or a recent ultralytics package update. {self.help_msg}"
                )
                self.reset()  # 将设置重置为默认值

            if self.get("datasets_dir") == self.get("runs_dir"):  # 如果数据集目录与运行目录相同,则发出警告
                LOGGER.warning(
                    f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' "
                    f"must be different than 'runs_dir: {self.get('runs_dir')}'. "
                    f"Please change one to avoid possible issues during training. {self.help_msg}"
                )

    def load(self):
        """Loads settings from the YAML file."""  # 从YAML文件加载设置
        super().update(yaml_load(self.file))  # 调用父类方法,使用yaml_load函数加载配置文件内容
    def save(self):
        """将当前设置保存到YAML文件中。"""
        # 调用yaml_save函数,将当前对象转换为字典并保存到文件中
        yaml_save(self.file, dict(self))

    def update(self, *args, **kwargs):
        """更新当前设置中的一个设置值。"""
        # 遍历关键字参数kwargs,检查每个设置项的有效性
        for k, v in kwargs.items():
            # 如果设置项不在默认设置中,则引发KeyError异常
            if k not in self.defaults:
                raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}")
            # 获取默认设置项k的类型
            t = type(self.defaults[k])
            # 如果传入的值v不是预期的类型t,则引发TypeError异常
            if not isinstance(v, t):
                raise TypeError(f"Ultralytics setting '{k}' must be of type '{t}', not '{type(v)}'. {self.help_msg}")
        # 调用父类的update方法,更新设置项
        super().update(*args, **kwargs)
        # 更新后立即保存设置到文件
        self.save()

    def reset(self):
        """将设置重置为默认值并保存。"""
        # 清空当前设置
        self.clear()
        # 使用默认设置更新当前设置
        self.update(self.defaults)
        # 保存更新后的设置到文件
        self.save()
# 发出弃用警告的函数,用于提示已弃用的参数,建议使用更新的参数
def deprecation_warn(arg, new_arg):
    # 使用 LOGGER 发出警告消息,指出已弃用的参数和建议的新参数
    LOGGER.warning(
        f"WARNING ⚠️ '{arg}' is deprecated and will be removed in in the future. " f"Please use '{new_arg}' instead."
    )


# 清理 URL,去除授权信息,例如 https://url.com/file.txt?auth -> https://url.com/file.txt
def clean_url(url):
    # 使用 Path 对象将 URL 转换成 POSIX 格式的字符串,并替换 Windows 下的 "://" 为 "://"
    url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows
    # 对 URL 解码并按 "?" 进行分割,保留问号前的部分作为最终的清理后的 URL
    return urllib.parse.unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth


# 将 URL 转换为文件名,例如 https://url.com/file.txt?auth -> file.txt
def url2file(url):
    # 清理 URL 后使用 Path 对象获取文件名部分作为结果
    return Path(clean_url(url)).name


# 在 utils 初始化过程中运行以下代码 ------------------------------------------------------------------------------------

# 检查首次安装步骤
PREFIX = colorstr("Ultralytics: ")  # 设置日志前缀
SETTINGS = SettingsManager()  # 初始化设置管理器
DATASETS_DIR = Path(SETTINGS["datasets_dir"])  # 全局数据集目录
WEIGHTS_DIR = Path(SETTINGS["weights_dir"])  # 全局权重目录
RUNS_DIR = Path(SETTINGS["runs_dir"])  # 全局运行目录
# 确定当前环境,根据不同情况设置 ENVIRONMENT 变量
ENVIRONMENT = (
    "Colab"
    if IS_COLAB
    else "Kaggle"
    if IS_KAGGLE
    else "Jupyter"
    if IS_JUPYTER
    else "Docker"
    if IS_DOCKER
    else platform.system()
)
TESTS_RUNNING = is_pytest_running() or is_github_action_running()  # 检查是否正在运行测试
set_sentry()  # 初始化 Sentry 错误监控

# 应用 Monkey Patch
from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save

torch.load = torch_load  # 覆盖默认的 torch.load 函数
torch.save = torch_save  # 覆盖默认的 torch.save 函数
if WINDOWS:
    # 对于 Windows 平台,应用 cv2 的补丁以支持图像路径中的非 ASCII 和非 UTF 字符
    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow

.\yolov8\ultralytics\__init__.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# 定义模块版本号
__version__ = "8.2.69"

# 导入操作系统模块
import os

# 设置环境变量(放置在导入语句之前)
# 设置 OpenMP 线程数为 1,以减少训练过程中的 CPU 使用率
os.environ["OMP_NUM_THREADS"] = "1"

# 从 ultralytics.data.explorer.explorer 模块中导入 Explorer 类
from ultralytics.data.explorer.explorer import Explorer
# 从 ultralytics.models 模块中导入 NAS、RTDETR、SAM、YOLO、FastSAM、YOLOWorld 类
from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
# 从 ultralytics.utils 模块中导入 ASSETS 和 SETTINGS
from ultralytics.utils import ASSETS, SETTINGS
# 从 ultralytics.utils.checks 模块中导入 check_yolo 函数,并将其命名为 checks
from ultralytics.utils.checks import check_yolo as checks
# 从 ultralytics.utils.downloads 模块中导入 download 函数
from ultralytics.utils.downloads import download

# 将 SETTINGS 赋值给 settings 变量
settings = SETTINGS

# 定义 __all__ 变量,包含了可以通过 `from package import *` 导入的名字
__all__ = (
    "__version__",
    "ASSETS",
    "YOLO",
    "YOLOWorld",
    "NAS",
    "SAM",
    "FastSAM",
    "RTDETR",
    "checks",
    "download",
    "settings",
    "Explorer",
)
posted @ 2024-09-05 12:04  绝不原创的飞龙  阅读(11)  评论(0编辑  收藏  举报