Yolov8-Source-Code-Analysis-26-

YOLOv8 Source Code Analysis (Part 26)

.\yolov8\tests\test_engine.py

# Import required modules and libraries
import sys  # system module
from unittest import mock  # mocking utilities

# Import project modules and classes
from tests import MODEL  # MODEL constant from the tests package
from ultralytics import YOLO  # YOLO model class
from ultralytics.cfg import get_cfg  # configuration loader
from ultralytics.engine.exporter import Exporter  # model export engine
from ultralytics.models.yolo import classify, detect, segment  # task modules: classify, detect, segment
from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR  # asset path, default config, weights directory


def test_func(*args):  # callback used by the tests below
    """Test function callback for evaluating YOLO model performance metrics."""
    print("callback test passed")  # confirm the callback ran


def test_export():
    """Tests the model exporting function by adding a callback and asserting its execution."""
    exporter = Exporter()  # create an Exporter instance
    exporter.add_callback("on_export_start", test_func)  # register the callback for the export-start event
    assert test_func in exporter.callbacks["on_export_start"], "callback test failed"  # callback was registered
    f = exporter(model=YOLO("yolov8n.yaml").model)  # export the model
    YOLO(f)(ASSETS)  # run inference with the exported model
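
The same `add_callback` pattern recurs on the Exporter, Trainer, Validator and Predictor objects throughout this file: callbacks are stored per event name and every registered function is invoked when that event fires. A minimal sketch of the idea (illustrative only; `CallbackHost` is a made-up name, and only `add_callback()` plus the `callbacks` dict mirror what the tests exercise):

from collections import defaultdict


class CallbackHost:
    """Toy event/callback registry mirroring the add_callback pattern used above."""

    def __init__(self):
        self.callbacks = defaultdict(list)  # event name -> list of registered functions

    def add_callback(self, event, func):
        self.callbacks[event].append(func)

    def run_callbacks(self, event, *args):
        for func in self.callbacks[event]:
            func(*args)


host = CallbackHost()
host.add_callback("on_export_start", test_func)
host.run_callbacks("on_export_start")  # prints "callback test passed"
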


def test_detect():
    """Test YOLO object detection training, validation, and prediction functionality."""
    overrides = {"data": "coco8.yaml", "model": "yolov8n.yaml", "imgsz": 32, "epochs": 1, "save": False}  # argument overrides
    cfg = get_cfg(DEFAULT_CFG)  # load the default configuration
    cfg.data = "coco8.yaml"  # set the dataset file
    cfg.imgsz = 32  # set the image size

    # Trainer
    trainer = detect.DetectionTrainer(overrides=overrides)  # create a detection trainer
    trainer.add_callback("on_train_start", test_func)  # register the callback for the train-start event
    assert test_func in trainer.callbacks["on_train_start"], "callback test failed"  # callback was registered
    trainer.train()  # run training

    # Validator
    val = detect.DetectionValidator(args=cfg)  # create a detection validator
    val.add_callback("on_val_start", test_func)  # register the callback for the validation-start event
    assert test_func in val.callbacks["on_val_start"], "callback test failed"  # callback was registered
    val(model=trainer.best)  # validate with the best checkpoint

    # Predictor
    pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]})  # create a detection predictor
    pred.add_callback("on_predict_start", test_func)  # register the callback for the predict-start event
    assert test_func in pred.callbacks["on_predict_start"], "callback test failed"  # callback was registered
    # Confirm that an empty sys.argv causes no problems
    with mock.patch.object(sys, "argv", []):
        result = pred(source=ASSETS, model=MODEL)  # run prediction
        assert len(result), "predictor test failed"  # results must not be empty

    overrides["resume"] = trainer.last  # resume from the trainer's last checkpoint
    trainer = detect.DetectionTrainer(overrides=overrides)  # create a new detection trainer
    try:
        trainer.train()  # run training
    except Exception as e:
        print(f"Expected exception caught: {e}")  # catch and print the expected exception
        return

    Exception("Resume test failed!")  # 报告恢复测试失败


def test_segment():
    """Tests image segmentation training, validation, and prediction pipelines using YOLO models."""
    overrides = {"data": "coco8-seg.yaml", "model": "yolov8n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False}  # argument overrides
    cfg = get_cfg(DEFAULT_CFG)  # load the default configuration
    cfg.data = "coco8-seg.yaml"  # set the dataset file
    cfg.imgsz = 32  # set the image size
    # YOLO(CFG_SEG).train(**overrides)  # works

    # Trainer
    trainer = segment.SegmentationTrainer(overrides=overrides)  # create a segmentation trainer
    trainer.add_callback("on_train_start", test_func)  # register the callback for the train-start event
    assert test_func in trainer.callbacks["on_train_start"], "callback test failed"  # callback was registered
    trainer.train()  # run training

    # Validator
    val = segment.SegmentationValidator(args=cfg)  # create a segmentation validator
    # Register test_func on the "on_val_start" event
    val.add_callback("on_val_start", test_func)
    # Confirm test_func was added to the validator's "on_val_start" callbacks
    assert test_func in val.callbacks["on_val_start"], "callback test failed"
    # Validate the best checkpoint
    val(model=trainer.best)  # validate best.pt

    # Create a SegmentationPredictor with imgsz overridden to [64, 64]
    pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]})
    # Register test_func on the "on_predict_start" event
    pred.add_callback("on_predict_start", test_func)
    # Confirm test_func was added to the predictor's "on_predict_start" callbacks
    assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
    # Predict on ASSETS with the WEIGHTS_DIR / "yolov8n-seg.pt" checkpoint
    result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolov8n-seg.pt")
    # Results must not be empty
    assert len(result), "predictor test failed"

    # Test resume
    overrides["resume"] = trainer.last  # resume from the trainer's last checkpoint
    trainer = segment.SegmentationTrainer(overrides=overrides)  # create a new segmentation trainer
    try:
        trainer.train()  # attempt training
    except Exception as e:
        # Catch and print the expected exception
        print(f"Expected exception caught: {e}")
        return

    # Resuming should have raised above; fail the test otherwise
    raise Exception("Resume test failed!")

def test_classify():
    """Test image classification including training, validation, and prediction phases."""
    # Argument overrides for this test
    overrides = {"data": "imagenet10", "model": "yolov8n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False}
    # Load the default configuration
    cfg = get_cfg(DEFAULT_CFG)
    # Point the config at the imagenet10 dataset
    cfg.data = "imagenet10"
    # Set the image size to 32
    cfg.imgsz = 32

    # YOLO(CFG_SEG).train(**overrides)  # works

    # Create a classification trainer with the overrides applied
    trainer = classify.ClassificationTrainer(overrides=overrides)
    # Register test_func to run when training starts
    trainer.add_callback("on_train_start", test_func)
    # Confirm test_func was added to the trainer's on_train_start callbacks
    assert test_func in trainer.callbacks["on_train_start"], "callback test failed"
    # Run training
    trainer.train()

    # Create a classification validator from the config
    val = classify.ClassificationValidator(args=cfg)
    # Register test_func to run when validation starts
    val.add_callback("on_val_start", test_func)
    # Confirm test_func was added to the validator's on_val_start callbacks
    assert test_func in val.callbacks["on_val_start"], "callback test failed"
    # Validate using the trainer's best checkpoint
    val(model=trainer.best)

    # Create a classification predictor with imgsz overridden to [64, 64]
    pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]})
    # Register test_func to run when prediction starts
    pred.add_callback("on_predict_start", test_func)
    # Confirm test_func was added to the predictor's on_predict_start callbacks
    assert test_func in pred.callbacks["on_predict_start"], "callback test failed"
    # Predict on ASSETS with the trainer's best checkpoint
    result = pred(source=ASSETS, model=trainer.best)
    # The prediction results must not be empty
    assert len(result), "predictor test failed"
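
The commented-out `YOLO(CFG_SEG).train(**overrides)  # works` lines above hint that the same train/validate/predict cycle can also be driven through the high-level YOLO facade instead of constructing trainers, validators and predictors by hand. A rough equivalent for the classification case (a sketch, not part of the test file):

from ultralytics import YOLO
from ultralytics.utils import ASSETS

model = YOLO("yolov8n-cls.yaml")                                # build the model from its YAML config
model.train(data="imagenet10", imgsz=32, epochs=1, save=False)  # a ClassificationTrainer is created internally
model.val(data="imagenet10", imgsz=32)                          # a ClassificationValidator is created internally
model.predict(source=ASSETS, imgsz=64)                          # a ClassificationPredictor is created internally
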

.\yolov8\tests\test_explorer.py

# Import required libraries: PIL for image handling and pytest as the test framework
import PIL
import pytest

# Import the Explorer class and the ASSETS path from the ultralytics package
from ultralytics import Explorer
from ultralytics.utils import ASSETS

# Mark this test as slow with @pytest.mark.slow
@pytest.mark.slow
def test_similarity():
    """Test the correctness and return length of similarity calculations and SQL queries in Explorer."""
    # Create an Explorer instance using the 'coco8.yaml' dataset config
    exp = Explorer(data="coco8.yaml")
    # Build the embeddings table
    exp.create_embeddings_table()
    # Get items similar to the image at index 1
    similar = exp.get_similar(idx=1)
    # Expect 4 similar items
    assert len(similar) == 4
    # Get items similar to the image file 'bus.jpg'
    similar = exp.get_similar(img=ASSETS / "bus.jpg")
    # Expect 4 similar items
    assert len(similar) == 4
    # Get items similar to indices [1, 2], limited to 2 results
    similar = exp.get_similar(idx=[1, 2], limit=2)
    # Expect 2 similar items
    assert len(similar) == 2
    # Compute the similarity index
    sim_idx = exp.similarity_index()
    # Expect the similarity index to have length 4
    assert len(sim_idx) == 4
    # Run a SQL query filtering on labels containing 'zebra'
    sql = exp.sql_query("WHERE labels LIKE '%zebra%'")
    # Expect exactly 1 result
    assert len(sql) == 1


@pytest.mark.slow
def test_det():
    """Test detection functionality and verify that the embeddings table contains bounding boxes."""
    # Create an Explorer using 'coco8.yaml' and the 'yolov8n.pt' model
    exp = Explorer(data="coco8.yaml", model="yolov8n.pt")
    # Force (re)creation of the embeddings table
    exp.create_embeddings_table(force=True)
    # The 'bboxes' column of the table head must be non-empty
    assert len(exp.table.head()["bboxes"]) > 0
    # Get items similar to indices [1, 2], limited to 10 results
    similar = exp.get_similar(idx=[1, 2], limit=10)
    # Expect at least one similar item
    assert len(similar) > 0
    # Plot the similar items; the return value should be a PIL image
    similar = exp.plot_similar(idx=[1, 2], limit=10)
    # Verify the result is a PIL.Image.Image
    assert isinstance(similar, PIL.Image.Image)


@pytest.mark.slow
def test_seg():
    """Test segmentation functionality and ensure the embeddings table contains segmentation masks."""
    # Create an Explorer using 'coco8-seg.yaml' and the 'yolov8n-seg.pt' model
    exp = Explorer(data="coco8-seg.yaml", model="yolov8n-seg.pt")
    # Force (re)creation of the embeddings table
    exp.create_embeddings_table(force=True)
    # The 'masks' column of the table head must be non-empty
    assert len(exp.table.head()["masks"]) > 0
    # Get items similar to indices [1, 2], limited to 10 results
    similar = exp.get_similar(idx=[1, 2], limit=10)
    # Expect at least one similar item
    assert len(similar) > 0
    # Plot the similar items; the return value should be a PIL image
    similar = exp.plot_similar(idx=[1, 2], limit=10)
    # Verify the result is a PIL.Image.Image
    assert isinstance(similar, PIL.Image.Image)


@pytest.mark.slow
def test_pose():
    """Test pose estimation functionality and verify that the embeddings table contains keypoints."""
    # Create an Explorer using 'coco8-pose.yaml' and the 'yolov8n-pose.pt' model
    exp = Explorer(data="coco8-pose.yaml", model="yolov8n-pose.pt")
    # Force (re)creation of the embeddings table
    exp.create_embeddings_table(force=True)
    # The 'keypoints' column of the table head must be non-empty
    assert len(exp.table.head()["keypoints"]) > 0
    # Get items similar to indices [1, 2], limited to 10 results
    similar = exp.get_similar(idx=[1, 2], limit=10)
    # Expect at least one similar item
    assert len(similar) > 0
    # Plot the similar items; the return value should be a PIL image
    similar = exp.plot_similar(idx=[1, 2], limit=10)
    # Verify the result is a PIL.Image.Image
    assert isinstance(similar, PIL.Image.Image)

.\yolov8\tests\test_exports.py

# Import required libraries and modules
import shutil  # file operations: copy, move, delete files and directories
import uuid  # generate unique UUIDs
from itertools import product  # cartesian product of iterables
from pathlib import Path  # object-oriented filesystem paths

import pytest  # test framework

# Import test constants and project modules
from tests import MODEL, SOURCE
from ultralytics import YOLO  # YOLO model class
from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS  # task configuration mappings
from ultralytics.utils import (
    IS_RASPBERRYPI,  # whether running on a Raspberry Pi
    LINUX,  # whether running on Linux
    MACOS,  # whether running on macOS
    WINDOWS,  # whether running on Windows
    checks,  # system and Python version checks
)
from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13  # Torch version checks


# Test exporting YOLO models to ONNX format with various configurations and parameters
def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
    # Select the model for the task and export it to ONNX
    file = YOLO(TASK2MODEL[task]).export(
        format="onnx",
        imgsz=32,
        dynamic=dynamic,
        int8=int8,
        half=half,
        batch=batch,
        simplify=simplify,
    )
    # Run inference with the exported model, repeating the source to fill the batch
    YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32)  # exported model inference
    # Remove the generated file
    Path(file).unlink()  # cleanup


@pytest.mark.slow
@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
# Test exporting YOLO models to TorchScript format across configuration combinations
def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
    # Select the model for the task and export it to TorchScript
    file = YOLO(TASK2MODEL[task]).export(
        format="torchscript",
        imgsz=32,
        dynamic=dynamic,
        int8=int8,
        half=half,
        batch=batch,
    )
    # Run inference with the exported model, repeating the source to fill the batch
    YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32)  # exported model inference at batch=3
    # Remove the generated file
    Path(file).unlink()  # cleanup


@pytest.mark.slow
# Test exporting YOLO models to CoreML format on macOS with various parameter configurations
@pytest.mark.skipif(not MACOS, reason="CoreML inference only supported on macOS")
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
@pytest.mark.parametrize(
    "task, dynamic, int8, half, batch",
    [  # generate all combinations but exclude those where both int8 and half are True
        (task, dynamic, int8, half, batch)
        for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
        if not (int8 and half)  # exclude cases where both int8 and half are True
    ],
)
def test_export_coreml_matrix(task, dynamic, int8, half, batch):
    # Select the model for the task and export it to CoreML
    file = YOLO(TASK2MODEL[task]).export(
        format="coreml",
        imgsz=32,
        dynamic=dynamic,
        int8=int8,
        half=half,
        batch=batch,
    )
    # Run inference with the exported model, repeating the source to fill the batch
    YOLO(file)([SOURCE] * batch, imgsz=32)  # exported model inference at batch=3
    # Remove the generated directory
    shutil.rmtree(file)  # cleanup
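
The list comprehension over `product()` in the parametrize decorators above builds the full test matrix and then drops the invalid combination where `int8` and `half` are both True. A quick illustration of what that filter yields for a single hypothetical task:

from itertools import product

combos = [
    (task, dynamic, int8, half, batch)
    for task, dynamic, int8, half, batch in product(["detect"], [False], [True, False], [True, False], [1])
    if not (int8 and half)
]
print(combos)
# [('detect', False, True, False, 1), ('detect', False, False, True, 1), ('detect', False, False, False, 1)]
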


@pytest.mark.slow
# Test exporting YOLO models to TFLite format on Linux with Python>=3.10
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
@pytest.mark.parametrize(
    "task, dynamic, int8, half, batch",
    [  # generate all combinations but exclude those where both int8 and half are True
        (task, dynamic, int8, half, batch)
        for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
        if not (int8 and half)  # exclude cases where both int8 and half are True
    ],
)
# Test exporting YOLO models to TFLite format across export configurations
def test_export_tflite_matrix(task, dynamic, int8, half, batch):
    # Select the model for the task and export it to TFLite
    file = YOLO(TASK2MODEL[task]).export(
        format="tflite",
        imgsz=32,
        dynamic=dynamic,
        int8=int8,
        half=half,
        batch=batch,
    )
    # Run inference with the exported model, repeating the source to fill the batch
    YOLO(file)([SOURCE] * batch, imgsz=32)  # exported model inference at batch=3
    # Remove the generated .tflite file
    Path(file).unlink()  # cleanup
# Skip if TORCH_1_9 is False: CoreML>=7.2 is not supported with PyTorch<=1.8
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
# Skip on Windows: CoreML is not supported there
@pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows")  # RuntimeError: BlobWriter not loaded
# Skip on Raspberry Pi: CoreML is not supported there
@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi")
# Skip on Python 3.12: CoreML is not supported there
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
def test_export_coreml():
    """Test YOLO exports to CoreML format, optimized for macOS only."""
    if MACOS:
        # On macOS, export the YOLO model to CoreML at the given imgsz
        file = YOLO(MODEL).export(format="coreml", imgsz=32)
        # Run prediction with the exported CoreML model (macOS only, nms=False models)
        YOLO(file)(SOURCE, imgsz=32)  # model prediction only supported on macOS for nms=False models
    else:
        # On other systems, only export to CoreML with nms=True at the given imgsz
        YOLO(MODEL).export(format="coreml", nms=True, imgsz=32)


# Skip if Python < 3.10: TFLite export requires Python>=3.10
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
# Skip unless on Linux: TensorFlow installs conflict on Windows and macOS
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
def test_export_tflite():
    """Test YOLO exports to TFLite format under specific OS and Python version conditions."""
    # Create the YOLO model
    model = YOLO(MODEL)
    # Export the model to TFLite at the given imgsz
    file = model.export(format="tflite", imgsz=32)
    # Run prediction with the exported TFLite model
    YOLO(file)(SOURCE, imgsz=32)


# Unconditionally skipped
@pytest.mark.skipif(True, reason="Test disabled")
# Skip unless on Linux: TensorFlow installs conflict on Windows and macOS
@pytest.mark.skipif(not LINUX, reason="TF suffers from install conflicts on Windows and macOS")
def test_export_pb():
    """Test YOLO exports to TensorFlow's Protobuf (*.pb) format."""
    # Create the YOLO model
    model = YOLO(MODEL)
    # Export the model to TensorFlow Protobuf at the given imgsz
    file = model.export(format="pb", imgsz=32)
    # Run prediction with the exported Protobuf model
    YOLO(file)(SOURCE, imgsz=32)


# Unconditionally skipped due to conflicting protobuf requirements
@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirements conflict.")
def test_export_paddle():
    """Test YOLO exports to Paddle format, noting protobuf conflicts with ONNX."""
    # Export the model to Paddle format at the given imgsz
    YOLO(MODEL).export(format="paddle", imgsz=32)


# Marked as a slow test
@pytest.mark.slow
def test_export_ncnn():
    """Test YOLO exports to NCNN format."""
    # Export the model to NCNN format at the given imgsz
    file = YOLO(MODEL).export(format="ncnn", imgsz=32)
    # Run prediction with the exported NCNN model
    YOLO(file)(SOURCE, imgsz=32)  # exported model inference

.\yolov8\tests\test_integrations.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# Import required libraries and modules
import contextlib
import os
import subprocess
import time
from pathlib import Path

import pytest

# Import constants and helpers from the tests package and the project
from tests import MODEL, SOURCE, TMP
from ultralytics import YOLO, download
from ultralytics.utils import DATASETS_DIR, SETTINGS
from ultralytics.utils.checks import check_requirements

# Skip this test when the requirement is not installed
@pytest.mark.skipif(not check_requirements("ray", install=False), reason="ray[tune] not installed")
def test_model_ray_tune():
    """Tune YOLO model using Ray for hyperparameter optimization."""
    # Run hyperparameter tuning through the YOLO class
    YOLO("yolov8n-cls.yaml").tune(
        use_ray=True, data="imagenet10", grace_period=1, iterations=1, imgsz=32, epochs=1, plots=False, device="cpu"
    )

# Skip this test when the requirement is not installed
@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow():
    """Test training with MLflow tracking enabled (see https://mlflow.org/ for details)."""
    # Enable MLflow tracking
    SETTINGS["mlflow"] = True
    # Train the model through the YOLO class
    YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=3, plots=False, device="cpu")

# Skip this test: it fails in scheduled CI
@pytest.mark.skipif(True, reason="Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868")
@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed")
def test_mlflow_keep_run_active():
    """Ensure MLflow run status matches MLFLOW_KEEP_RUN_ACTIVE environment variable settings."""
    import mlflow

    # Enable MLflow tracking
    SETTINGS["mlflow"] = True
    run_name = "Test Run"
    os.environ["MLFLOW_RUN"] = run_name

    # Case: MLFLOW_KEEP_RUN_ACTIVE=True
    os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "True"
    YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")
    # Read the status of the currently active MLflow run
    status = mlflow.active_run().info.status
    assert status == "RUNNING", "MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True"

    run_id = mlflow.active_run().info.run_id

    # Case: MLFLOW_KEEP_RUN_ACTIVE=False
    os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "False"
    YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")
    # Read the status of the run with the recorded run ID
    status = mlflow.get_run(run_id=run_id).info.status
    assert status == "FINISHED", "MLflow run should be ended when MLFLOW_KEEP_RUN_ACTIVE=False"

    # Case: MLFLOW_KEEP_RUN_ACTIVE is unset
    os.environ.pop("MLFLOW_KEEP_RUN_ACTIVE", None)
    YOLO("yolov8n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu")
    # Read the status of the run with the recorded run ID
    status = mlflow.get_run(run_id=run_id).info.status
    assert status == "FINISHED", "MLflow run should be ended by default when MLFLOW_KEEP_RUN_ACTIVE is not set"

# Skip this test when the requirement is not installed
@pytest.mark.skipif(not check_requirements("tritonclient", install=False), reason="tritonclient[all] not installed")
def test_triton():
    """
    Test NVIDIA Triton Server functionalities with YOLO model.

    See https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver.
    """
    # Ensure tritonclient is installed
    check_requirements("tritonclient[all]")
    # Import the Triton inference server HTTP client
    from tritonclient.http import InferenceServerClient  # noqa

    # Create variables
    model_name = "yolo"  # model name served by Triton
    triton_repo = TMP / "triton_repo"  # Triton model repository under the temporary directory
    triton_model = triton_repo / model_name  # path of this model inside the repository

    # Export model to ONNX
    f = YOLO(MODEL).export(format="onnx", dynamic=True)  # export the model to ONNX and keep the output path in f

    # Prepare Triton repo
    (triton_model / "1").mkdir(parents=True, exist_ok=True)  # create version directory "1", ignoring if it exists
    Path(f).rename(triton_model / "1" / "model.onnx")  # move the exported ONNX file into version 1 as model.onnx
    (triton_model / "config.pbtxt").touch()  # create an empty config.pbtxt in the model directory

    # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
    tag = "nvcr.io/nvidia/tritonserver:23.09-py3"  # Docker image tag, about 6.4 GB

    # Pull the image
    subprocess.call(f"docker pull {tag}", shell=True)  # pull the Triton server image

    # Run the Triton server and capture the container ID
    container_id = (
        subprocess.check_output(
            f"docker run -d --rm -v {triton_repo}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models",
            shell=True,
        )
        .decode("utf-8")
        .strip()
    )  # start the Triton server and record the container ID

    # Wait for the Triton server to start
    triton_client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False)  # HTTP client for the local Triton server on port 8000

    # Wait until model is ready
    for _ in range(10):  # poll up to 10 times
        with contextlib.suppress(Exception):  # ignore errors while the server is still starting
            assert triton_client.is_model_ready(model_name)  # check whether the model is ready
            break  # stop polling once the model is ready
        time.sleep(1)  # wait one second between polls

    # Check Triton inference
    YOLO(f"http://localhost:8000/{model_name}", "detect")(SOURCE)  # run inference against the Triton-served model

    # Kill and remove the container at the end of the test
    subprocess.call(f"docker kill {container_id}", shell=True)  # stop the container (removed automatically via --rm)
@pytest.mark.skipif(not check_requirements("pycocotools", install=False), reason="pycocotools not installed")
def test_pycocotools():
    """Validate YOLO model predictions on COCO dataset using pycocotools."""
    from ultralytics.models.yolo.detect import DetectionValidator
    from ultralytics.models.yolo.pose import PoseValidator
    from ultralytics.models.yolo.segment import SegmentationValidator

    # Download annotations after each dataset downloads first
    url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/"

    # 设置检测模型的参数和初始化检测器
    args = {"model": "yolov8n.pt", "data": "coco8.yaml", "save_json": True, "imgsz": 64}
    validator = DetectionValidator(args=args)
    # 运行检测器,执行评估
    validator()
    # 标记为COCO数据集
    validator.is_coco = True
    # 下载实例注释文件
    download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8/annotations")
    # 对评估的JSON文件进行评估
    _ = validator.eval_json(validator.stats)

    # 设置分割模型的参数和初始化分割器
    args = {"model": "yolov8n-seg.pt", "data": "coco8-seg.yaml", "save_json": True, "imgsz": 64}
    validator = SegmentationValidator(args=args)
    # 运行分割器,执行评估
    validator()
    # 标记为COCO数据集
    validator.is_coco = True
    # 下载实例注释文件
    download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8-seg/annotations")
    # 对评估的JSON文件进行评估
    _ = validator.eval_json(validator.stats)

    # 设置姿势估计模型的参数和初始化姿势估计器
    args = {"model": "yolov8n-pose.pt", "data": "coco8-pose.yaml", "save_json": True, "imgsz": 64}
    validator = PoseValidator(args=args)
    # 运行姿势估计器,执行评估
    validator()
    # 标记为COCO数据集
    validator.is_coco = True
    # 下载人体关键点注释文件
    download(f"{url}person_keypoints_val2017.json", dir=DATASETS_DIR / "coco8-pose/annotations")
    # 对评估的JSON文件进行评估
    _ = validator.eval_json(validator.stats)

.\yolov8\tests\test_python.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib  # context management helpers
import urllib  # URL handling
from copy import copy  # shallow copy
from pathlib import Path  # filesystem paths

import cv2  # OpenCV
import numpy as np  # array operations
import pytest  # test framework
import torch  # PyTorch
import yaml  # YAML parsing
from PIL import Image  # Python Imaging Library

from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP  # test constants
from ultralytics import RTDETR, YOLO  # YOLO and RTDETR model classes
from ultralytics.cfg import MODELS, TASK2DATA, TASKS  # configuration mappings
from ultralytics.data.build import load_inference_source  # data loading helper
from ultralytics.utils import (  # utility constants and helpers
    ASSETS,
    DEFAULT_CFG,
    DEFAULT_CFG_PATH,
    LOGGER,
    ONLINE,
    ROOT,
    WEIGHTS_DIR,
    WINDOWS,
    checks,
)
from ultralytics.utils.downloads import download  # download helper
from ultralytics.utils.torch_utils import TORCH_1_9  # PyTorch version check


def test_model_forward():
    """Test the forward pass of the YOLO model."""
    model = YOLO(CFG)  # build a YOLO model from the config
    model(source=None, imgsz=32, augment=True)  # forward pass with assorted arguments


def test_model_methods():
    """Test various methods and properties of the YOLO model to ensure correct functionality."""
    model = YOLO(MODEL)  # load a YOLO model from the checkpoint path

    # Model methods
    model.info(verbose=True, detailed=True)  # print detailed model information
    model = model.reset_weights()  # reset the model weights
    model = model.load(MODEL)  # load the checkpoint again
    model.to("cpu")  # move the model to the CPU
    model.fuse()  # fuse model layers
    model.clear_callback("on_train_start")  # clear a specific callback
    model.reset_callbacks()  # reset all callbacks

    # Model properties
    _ = model.names  # class names
    _ = model.device  # current device
    _ = model.transforms  # data transforms
    _ = model.task_map  # task mapping


def test_model_profile():
    """Test profiling of the YOLO model with `profile=True` to assess performance and resource usage."""
    from ultralytics.nn.tasks import DetectionModel  # detection model class

    model = DetectionModel()  # build a detection model
    im = torch.randn(1, 3, 64, 64)  # input tensor
    _ = model.predict(im, profile=True)  # run prediction with profiling enabled


@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_predict_txt():
    """Tests YOLO predictions with file, directory, and pattern sources listed in a text file."""
    txt_file = TMP / "sources.txt"  # temporary text file path
    with open(txt_file, "w") as f:
        for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]:
            f.write(f"{x}\n")  # write the different source types into the text file

    _ = YOLO(MODEL)(source=txt_file, imgsz=32)  # run predictions on the sources listed in the text file


@pytest.mark.parametrize("model_name", MODELS)
def test_predict_img(model_name):
    """Test YOLO model predictions on various image input types and sources, including online images."""
    model = YOLO(WEIGHTS_DIR / model_name)  # load the YOLO model by name

    im = cv2.imread(str(SOURCE))  # read the input image as a numpy array
    assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1  # PIL image input
    assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1  # numpy array input
    assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2  # batched tensor input
    assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2  # batched list input
    assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2  # streamed input
    assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1  # tensor converted to numpy
    batch = [
        str(SOURCE),  # filename as a string
        Path(SOURCE),  # Path object
        "https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE,  # URI when online, otherwise the local SOURCE
        cv2.imread(str(SOURCE)),  # OpenCV image (numpy array)
        Image.open(SOURCE),  # PIL image
        np.zeros((320, 640, 3), dtype=np.uint8),  # all-zero uint8 numpy array
    ]
    assert len(model(batch, imgsz=32)) == len(batch)  # output length must match the number of inputs

@pytest.mark.parametrize("model", MODELS)
def test_predict_visualize(model):
    """Test model prediction methods with 'visualize=True' to generate and display prediction visualizations."""
    # 使用不同的模型参数化测试模型的预测方法,设置 visualize=True 以生成和显示预测的可视化结果
    YOLO(WEIGHTS_DIR / model)(SOURCE, imgsz=32, visualize=True)


def test_predict_grey_and_4ch():
    """Test YOLO prediction on SOURCE converted to greyscale and 4-channel images with various filenames."""
    # 测试 YOLO 模型在将 SOURCE 转换为灰度图和四通道图像,并使用不同的文件名进行测试
    im = Image.open(SOURCE)
    directory = TMP / "im4"
    directory.mkdir(parents=True, exist_ok=True)

    source_greyscale = directory / "greyscale.jpg"
    source_rgba = directory / "4ch.png"
    source_non_utf = directory / "non_UTF_测试文件_tést_image.jpg"
    source_spaces = directory / "image with spaces.jpg"

    im.convert("L").save(source_greyscale)  # 将图像转换为灰度图并保存
    im.convert("RGBA").save(source_rgba)  # 将图像转换为四通道 PNG 并保存
    im.save(source_non_utf)  # 使用包含非 UTF 字符的文件名保存图像
    im.save(source_spaces)  # 使用包含空格的文件名保存图像

    # 推断过程
    model = YOLO(MODEL)
    for f in source_rgba, source_greyscale, source_non_utf, source_spaces:
        for source in Image.open(f), cv2.imread(str(f)), f:
            # 对每个文件进行模型预测,设置 save=True 和 verbose=True,imgsz=32
            results = model(source, save=True, verbose=True, imgsz=32)
            assert len(results) == 1  # 验证是否运行了一次图像预测
        f.unlink()  # 清理生成的临时文件


@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
def test_youtube():
    """Test YOLO model on a YouTube video stream, handling potential network-related errors."""
    # Run the YOLO model on a YouTube stream, tolerating network-related errors
    model = YOLO(MODEL)
    try:
        model.predict("https://youtu.be/G17sBkb38XQ", imgsz=96, save=True)
    # Handle errors caused by network issues, e.g. 'urllib.error.HTTPError: HTTP Error 429: Too Many Requests'
    except (urllib.error.HTTPError, ConnectionError) as e:
        LOGGER.warning(f"WARNING: YouTube Test Error: {e}")


@pytest.mark.skipif(not ONLINE, reason="environment is offline")
@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
def test_track_stream():
    """
    Tests streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods.

    Note imgsz=160 required for tracking for higher confidence and better matches.
    """
    # Stream-track a short 10-frame video with ByteTrack and different global motion compensation (GMC) methods

    video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov"
    model = YOLO(MODEL)
    model.track(video_url, imgsz=160, tracker="bytetrack.yaml")  # track with the ByteTrack tracker
    model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True)  # test frame saving

    # Test different global motion compensation (GMC) methods
    for gmc in "orb", "sift", "ecc":
        with open(ROOT / "cfg/trackers/botsort.yaml", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        tracker = TMP / f"botsort-{gmc}.yaml"
        data["gmc_method"] = gmc
        with open(tracker, "w", encoding="utf-8") as f:
            yaml.safe_dump(data, f)
        model.track(video_url, imgsz=160, tracker=tracker)


def test_val():
    """Test the validation mode of the YOLO model."""
    # Instantiate YOLO and call val() with:
    #   - data="coco8.yaml": dataset configuration file
    #   - imgsz=32: image size
    #   - save_hybrid=True: save hybrid labels
    YOLO(MODEL).val(data="coco8.yaml", imgsz=32, save_hybrid=True)


def test_train_scratch():
    """Test training the YOLO model from scratch using the provided configuration."""
    # Build a YOLO model from the config CFG
    model = YOLO(CFG)
    # Train on coco8.yaml for 2 epochs at imgsz=32, disk caching, batch=-1, close_mosaic=1, run name "model"
    model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="model")
    # Run the trained model on SOURCE
    model(SOURCE)


def test_train_pretrained():
    """Test training of the YOLO model starting from a pre-trained checkpoint."""
    # Build a YOLO model from the pre-trained checkpoint WEIGHTS_DIR / "yolov8n-seg.pt"
    model = YOLO(WEIGHTS_DIR / "yolov8n-seg.pt")
    # Train on coco8-seg.yaml for 1 epoch at imgsz=32, RAM caching, copy_paste=0.5, mixup=0.5, run name 0
    model.train(data="coco8-seg.yaml", epochs=1, imgsz=32, cache="ram", copy_paste=0.5, mixup=0.5, name=0)
    # Run the trained model on SOURCE
    model(SOURCE)


def test_all_model_yamls():
    """Test YOLO model creation for all available YAML configurations in the `cfg/models` directory."""
    # Iterate over every YAML config under cfg/models
    for m in (ROOT / "cfg" / "models").rglob("*.yaml"):
        # RT-DETR configs
        if "rtdetr" in m.name:
            # RT-DETR requires Torch 1.9 or newer
            if TORCH_1_9:
                # Build an RTDETR model from the config and run it on SOURCE at imgsz=640
                _ = RTDETR(m.name)(SOURCE, imgsz=640)  # must be 640
        else:
            # Build a YOLO model from the config
            YOLO(m.name)


def test_workflow():
    """Test the complete workflow including training, validation, prediction, and exporting."""
    # Build the YOLO model from MODEL
    model = YOLO(MODEL)
    # Train on coco8.yaml for 1 epoch at imgsz=32 using the SGD optimizer
    model.train(data="coco8.yaml", epochs=1, imgsz=32, optimizer="SGD")
    # Validate at imgsz=32
    model.val(imgsz=32)
    # Predict on SOURCE at imgsz=32
    model.predict(SOURCE, imgsz=32)
    # Export the model to TorchScript
    model.export(format="torchscript")


def test_predict_callback_and_setup():
    """Test callback functionality during YOLO prediction setup and execution."""

    def on_predict_batch_end(predictor):
        """Callback function that handles operations at the end of a prediction batch."""
        # Unpack the paths, original images and remaining fields from predictor.batch
        path, im0s, _ = predictor.batch
        # Normalize im0s to a list so single and multi-image batches are handled the same way
        im0s = im0s if isinstance(im0s, list) else [im0s]
        # Pair each result with its original image and the dataset batch size
        bs = [predictor.dataset.bs for _ in range(len(path))]
        predictor.results = zip(predictor.results, im0s, bs)  # results is List[batch_size]

    # Build the YOLO model from MODEL
    model = YOLO(MODEL)
    # Register the on_predict_batch_end callback on the model
    model.add_callback("on_predict_batch_end", on_predict_batch_end)

    # Load the inference source and read the dataset batch size
    dataset = load_inference_source(source=SOURCE)
    bs = dataset.bs  # noqa access predictor properties
    # Predict on the dataset in streaming mode at imgsz=160
    results = model.predict(dataset, stream=True, imgsz=160)  # source already setup
    # Iterate over the (result, image, batch size) tuples produced by the callback
    for r, im0, bs in results:
        # Print the image shape
        print("test_callback", im0.shape)
        # Print the batch size
        print("test_callback", bs)
        # Access the Boxes object holding the bbox outputs
        boxes = r.boxes  # Boxes object for bbox outputs
        print(boxes)


@pytest.mark.parametrize("model", MODELS)
def test_results(model):
    """Ensure YOLO model predictions can be processed and printed in various formats."""
    # 使用指定模型 WEIGHTS_DIR / model 创建 YOLO 模型对象,并对 SOURCE 数据进行预测,图像大小为 160 像素
    results = YOLO(WEIGHTS_DIR / model)([SOURCE, SOURCE], imgsz=160)
    # 遍历预测结果列表
    for r in results:
        # 将结果转换为 CPU 上的 numpy 数组
        r = r.cpu().numpy()
        # 打印 numpy 数组的属性信息及路径
        print(r, len(r), r.path)  # print numpy attributes
        # 将结果转换为 CPU 上的 torch.float32 类型
        r = r.to(device="cpu", dtype=torch.float32)
        # 将结果保存为文本文件,保存置信度信息
        r.save_txt(txt_file=TMP / "runs/tests/label.txt", save_conf=True)
        # 将结果中的区域裁剪保存到指定目录
        r.save_crop(save_dir=TMP / "runs/tests/crops/")
        # 将结果转换为 JSON 格式,并进行规范化处理
        r.tojson(normalize=True)
        # 绘制结果的图像,返回 PIL 图像
        r.plot(pil=True)
        # 绘制结果的置信度图及边界框信息
        r.plot(conf=True, boxes=True)
        # 再次打印结果及路径信息
        print(r, len(r), r.path)  # print after methods


def test_labels_and_crops():
    """Test output from prediction args for saving YOLO detection labels and crops; ensures accurate saving."""
    # Image list: the default SOURCE plus a second asset image
    imgs = [SOURCE, ASSETS / "zidane.jpg"]
    # Run the pretrained YOLO model on the images at imgsz=160, saving label text files and crops
    results = YOLO(WEIGHTS_DIR / "yolov8n.pt")(imgs, imgsz=160, save_txt=True, save_crop=True)
    # The save directory reported by the first result
    save_path = Path(results[0].save_dir)
    # Check each result
    for r in results:
        # Image file stem, used as the label file name
        im_name = Path(r.path).stem
        # Class indices of the detected boxes as a list of ints
        cls_idxs = r.boxes.cls.int().tolist()
        # Path of the label file for this image
        labels = save_path / f"labels/{im_name}.txt"
        assert labels.exists()  # the label file must exist
        # The number of detections must match the number of non-empty lines in the label file
        assert len(r.boxes.data) == len([line for line in labels.read_text().splitlines() if line])
        # Collect all crop files
        crop_dirs = list((save_path / "crops").iterdir())
        crop_files = [f for p in crop_dirs for f in p.glob("*")]
        # Every detected class must have a matching crop sub-directory
        assert all(r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs)
        # The number of crops for this image must match the number of detected boxes
        assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)
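
For orientation, the run directory these assertions walk looks roughly like the sketch below; the concrete file and class names (bus, person, ...) are illustrative, not guaranteed by the test:

<save_dir>/
├── labels/
│   └── bus.txt          # one line per detection: class index plus the normalized box (and conf when save_conf=True)
└── crops/
    └── person/          # one sub-directory per detected class name
        └── bus.jpg      # cropped region for each matching box
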
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
# Skipped when the environment is offline
def test_data_utils():
    """Test utility functions in ultralytics/data/utils.py, including dataset stats and auto-splitting."""
    # Import the functions under test
    from ultralytics.data.utils import HUBDatasetStats, autosplit
    from ultralytics.utils.downloads import zip_directory

    # from ultralytics.utils.files import WorkingDirectory
    # with WorkingDirectory(ROOT.parent / 'tests'):

    # Run the checks for every task
    for task in TASKS:
        # Build the dataset archive path, e.g. coco8.zip
        file = Path(TASK2DATA[task]).with_suffix(".zip")  # i.e. coco8.zip
        # Download the dataset archive
        download(f"https://github.com/ultralytics/hub/raw/main/example_datasets/{file}", unzip=False, dir=TMP)
        # Create the dataset statistics object
        stats = HUBDatasetStats(TMP / file, task=task)
        # Write the dataset statistics JSON
        stats.get_json(save=True)
        # Process the dataset images
        stats.process_images()

    # Auto-split the dataset
    autosplit(TMP / "coco8")
    # Zip the given directory
    zip_directory(TMP / "coco8/images/val")  # zip


@pytest.mark.skipif(not ONLINE, reason="environment is offline")
# Skipped when the environment is offline
def test_data_converter():
    """Test dataset conversion functions from COCO to YOLO format and class mappings."""
    # Import the functions under test
    from ultralytics.data.converter import coco80_to_coco91_class, convert_coco

    # Download the COCO instance annotation file
    file = "instances_val2017.json"
    download(f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{file}", dir=TMP)
    # Convert the COCO annotations to YOLO format
    convert_coco(labels_dir=TMP, save_dir=TMP / "yolo_labels", use_segments=True, use_keypoints=False, cls91to80=True)
    # Map COCO80 class indices to COCO91
    coco80_to_coco91_class()


def test_data_annotator():
    """Automatically annotate data using specified detection and segmentation models."""
    # Import the auto-annotation helper
    from ultralytics.data.annotator import auto_annotate

    # Auto-annotate the assets with the given detection and segmentation models
    auto_annotate(
        ASSETS,
        det_model=WEIGHTS_DIR / "yolov8n.pt",
        sam_model=WEIGHTS_DIR / "mobile_sam.pt",
        output_dir=TMP / "auto_annotate_labels",
    )


def test_events():
    """Test event sending functionality."""
    # Import the events helper
    from ultralytics.hub.utils import Events

    # Create and enable the events object
    events = Events()
    events.enabled = True
    cfg = copy(DEFAULT_CFG)  # does not require deepcopy
    cfg.mode = "test"
    # Send the event
    events(cfg)


def test_cfg_init():
    """Test configuration initialization utilities from the 'ultralytics.cfg' module."""
    # Import the configuration helpers under test
    from ultralytics.cfg import check_dict_alignment, copy_default_cfg, smart_value

    # Check dictionary alignment (a SyntaxError is expected and suppressed)
    with contextlib.suppress(SyntaxError):
        check_dict_alignment({"a": 1}, {"b": 2})
    # Copy the default configuration file
    copy_default_cfg()
    # Delete the copied configuration file
    (Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")).unlink(missing_ok=False)
    # Apply smart value parsing to several strings
    [smart_value(x) for x in ["none", "true", "false"]]


def test_utils_init():
    """Test initialization utilities in the Ultralytics library."""
    # Import the initialization helpers
    from ultralytics.utils import get_git_branch, get_git_origin_url, get_ubuntu_version, is_github_action_running

    # Get the Ubuntu version
    get_ubuntu_version()
    # Check whether running inside a GitHub Action
    is_github_action_running()
    # Get the git remote origin URL
    get_git_origin_url()
    # Get the current git branch
    get_git_branch()


def test_utils_checks():
    """Test various utility checks for filenames, git status, requirements, image sizes, and versions."""
    # Import the checks module
    from ultralytics.utils import checks

    # Check the YOLOv5u filename convention
    checks.check_yolov5u_filename("yolov5n.pt")
    # Describe the git repository state
    checks.git_describe(ROOT)
    # Verify the dependencies listed in requirements.txt
    checks.check_requirements()  # check requirements.txt

    # Check that the image size respects the maximum dimension constraint
    checks.check_imgsz([600, 600], max_dim=1)

    # Check whether images can be displayed, warning if not
    checks.check_imshow(warn=True)

    # Check that the ultralytics package is at least version 8.0.0
    checks.check_version("ultralytics", "8.0.0")

    # Print the current settings and arguments for debugging
    checks.print_args()
@pytest.mark.skipif(WINDOWS, reason="Windows profiling is extremely slow (cause unknown)")
# Skipped on Windows because profiling there is extremely slow (cause unknown)
def test_utils_benchmarks():
    """Benchmark model performance using 'ProfileModels' from 'ultralytics.utils.benchmarks'."""
    # Import ProfileModels to benchmark model performance
    from ultralytics.utils.benchmarks import ProfileModels

    # Profile 'yolov8n.yaml' at imgsz=32 with a minimum run time of 1 s, 3 timed runs and 1 warmup run
    ProfileModels(["yolov8n.yaml"], imgsz=32, min_time=1, num_timed_runs=3, num_warmup_runs=1).profile()


def test_utils_torchutils():
    """Test Torch utility functions including profiling and FLOP calculations."""
    # Import the modules and helpers under test
    from ultralytics.nn.modules.conv import Conv
    from ultralytics.utils.torch_utils import get_flops_with_torch_profiler, profile, time_sync

    # Create a random input tensor
    x = torch.randn(1, 64, 20, 20)
    # Create a Conv module
    m = Conv(64, 64, k=1, s=2)

    # Profile the module m over 3 runs
    profile(x, [m], n=3)
    # Compute the FLOPs of m with the torch profiler
    get_flops_with_torch_profiler(m)
    # Synchronize and time
    time_sync()


@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
# Skipped when the environment is offline
def test_utils_downloads():
    """Test file download utilities from ultralytics.utils.downloads."""
    # Import the Google Drive helper
    from ultralytics.utils.downloads import get_google_drive_file_info

    # Retrieve information about a specific Google Drive file
    get_google_drive_file_info("https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link")


def test_utils_ops():
    """Test utility operations functions for coordinate transformation and normalization."""
    # Import the coordinate conversion and normalization helpers
    from ultralytics.utils.ops import (
        ltwh2xywh,
        ltwh2xyxy,
        make_divisible,
        xywh2ltwh,
        xywh2xyxy,
        xywhn2xyxy,
        xywhr2xyxyxyxy,
        xyxy2ltwh,
        xyxy2xywh,
        xyxy2xywhn,
        xyxyxyxy2xywhr,
    )

    # Round 17 up to a multiple of 8
    make_divisible(17, torch.tensor([8]))

    # Random box coordinates
    boxes = torch.rand(10, 4)  # xywh
    # xywh -> xyxy -> xywh round trip must reproduce the input
    torch.allclose(boxes, xyxy2xywh(xywh2xyxy(boxes)))
    # normalized xywh -> xyxy -> normalized xywh round trip must reproduce the input
    torch.allclose(boxes, xyxy2xywhn(xywhn2xyxy(boxes)))
    # xywh -> ltwh -> xywh round trip must reproduce the input
    torch.allclose(boxes, ltwh2xywh(xywh2ltwh(boxes)))
    # ltwh -> xyxy -> ltwh round trip must reproduce the input
    torch.allclose(boxes, xyxy2ltwh(ltwh2xyxy(boxes)))

    # Random oriented boxes (xywhr for OBB)
    boxes = torch.rand(10, 5)  # xywhr for OBB
    # Random rotation angles
    boxes[:, 4] = torch.randn(10) * 30
    # xywhr -> 4-corner -> xywhr round trip must reproduce the input within rtol=1e-3
    torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3)
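
These round-trip checks rely on the usual box conventions: xywh stores the box center plus width and height, xyxy stores the two opposite corners, and ltwh stores the top-left corner plus width and height. A minimal scalar sketch of the first conversion pair (the library functions do the same thing on whole tensors):

def xywh_to_xyxy(xc, yc, w, h):
    """Center-format box -> corner-format box."""
    return xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2


def xyxy_to_xywh(x1, y1, x2, y2):
    """Corner-format box -> center-format box."""
    return (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1


# Values chosen so the round trip is exact in floating point
assert xyxy_to_xywh(*xywh_to_xyxy(0.5, 0.5, 0.5, 0.25)) == (0.5, 0.5, 0.5, 0.25)
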


def test_utils_files():
    """Test file handling utilities including file age, date, and paths with spaces."""
    # Import the file helpers: file age, file date, latest run, paths with spaces
    from ultralytics.utils.files import file_age, file_date, get_latest_run, spaces_in_path

    # Age of the given file
    file_age(SOURCE)
    # Date of the given file
    file_date(SOURCE)
    # Most recent run under the runs directory
    get_latest_run(ROOT / "runs")

    # Create a temporary directory whose path contains spaces
    path = TMP / "path/with spaces"
    path.mkdir(parents=True, exist_ok=True)
    # Run spaces_in_path on it and print the substituted path
    with spaces_in_path(path) as new_path:
        print(new_path)


@pytest.mark.slow
def test_utils_patches_torch_save():
    """Test torch_save backoff when _torch_save raises RuntimeError to ensure robustness."""
    # Import mocking helpers
    from unittest.mock import MagicMock, patch

    # Import the function under test
    from ultralytics.utils.patches import torch_save

    # Mock that always raises RuntimeError
    mock = MagicMock(side_effect=RuntimeError)

    # Patch _torch_save so every call raises RuntimeError
    with patch("ultralytics.utils.patches._torch_save", new=mock):
        # torch_save must eventually re-raise the RuntimeError
        with pytest.raises(RuntimeError):
            torch_save(torch.zeros(1), TMP / "test.pt")
    # Verify the mocked save was attempted exactly 4 times
    assert mock.call_count == 4, "torch_save was not attempted the expected number of times"
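
The `call_count == 4` assertion implies one initial attempt plus three retries before torch_save finally re-raises. A minimal sketch of that retry-with-backoff pattern (an assumption inferred from this test, not the actual ultralytics implementation):

import time


def save_with_retries(save_fn, *args, retries=3, **kwargs):
    """Call save_fn, retrying on RuntimeError with exponential backoff."""
    for attempt in range(retries + 1):  # 1 initial attempt + 3 retries = 4 calls, matching the assert above
        try:
            return save_fn(*args, **kwargs)
        except RuntimeError:
            if attempt == retries:
                raise  # give up after the final retry
            time.sleep(2**attempt)  # back off before the next attempt
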
def test_nn_modules_conv():
    """Test Convolutional Neural Network modules including CBAM, Conv2, and ConvTranspose."""
    from ultralytics.nn.modules.conv import CBAM, Conv2, ConvTranspose, DWConvTranspose2d, Focus

    c1, c2 = 8, 16  # input and output channel counts
    x = torch.zeros(4, c1, 10, 10)  # BCHW tensor of shape 4x8x10x10 (batch x channels x height x width)

    # Run modules not otherwise covered by the tests
    DWConvTranspose2d(c1, c2)(x)  # depth-wise transposed convolution
    ConvTranspose(c1, c2)(x)  # transposed convolution
    Focus(c1, c2)(x)  # Focus module
    CBAM(c1)(x)  # CBAM attention module

    # Fuse operation
    m = Conv2(c1, c2)  # create a Conv2 module
    m.fuse_convs()  # fuse its convolutions
    m(x)  # run the fused module on x


def test_nn_modules_block():
    """Test various blocks in neural network modules including C1, C3TR, BottleneckCSP, C3Ghost, and C3x."""
    from ultralytics.nn.modules.block import C1, C3TR, BottleneckCSP, C3Ghost, C3x

    c1, c2 = 8, 16  # input and output channel counts
    x = torch.zeros(4, c1, 10, 10)  # BCHW tensor of shape 4x8x10x10 (batch x channels x height x width)

    # Run modules not otherwise covered by the tests
    C1(c1, c2)(x)  # C1 block
    C3x(c1, c2)(x)  # C3x block
    C3TR(c1, c2)(x)  # C3TR block
    C3Ghost(c1, c2)(x)  # C3Ghost block
    BottleneckCSP(c1, c2)(x)  # BottleneckCSP block


@pytest.mark.skipif(not ONLINE, reason="environment is offline")
def test_hub():
    """Test Ultralytics HUB functionalities (e.g. export formats, logout)."""
    from ultralytics.hub import export_fmts_hub, logout
    from ultralytics.hub.utils import smart_request

    export_fmts_hub()  # list the export formats supported by HUB
    logout()  # log out of HUB
    smart_request("GET", "https://github.com", progress=True)  # issue a GET request to GitHub


@pytest.fixture
def image():
    """Load and return an image from a predefined source using OpenCV."""
    return cv2.imread(str(SOURCE))  # load and return the predefined source image with OpenCV


@pytest.mark.parametrize(
    "auto_augment, erasing, force_color_jitter",
    [
        (None, 0.0, False),
        ("randaugment", 0.5, True),
        ("augmix", 0.2, False),
        ("autoaugment", 0.0, True),
    ],
)
def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):
    """Tests classification transforms during training with various augmentations to ensure proper functionality."""
    from ultralytics.data.augment import classify_augmentations

    transform = classify_augmentations(
        size=224,
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        hflip=0.5,
        vflip=0.5,
        auto_augment=auto_augment,
        hsv_h=0.015,
        hsv_s=0.4,
        hsv_v=0.4,
        force_color_jitter=force_color_jitter,
        erasing=erasing,
    )

    transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))

    assert transformed_image.shape == (3, 224, 224)  # the transformed image must have shape (3, 224, 224)
    assert torch.is_tensor(transformed_image)  # the transformed image must be a PyTorch tensor
    assert transformed_image.dtype == torch.float32  # the transformed image must be float32


@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason="environment is offline")
def test_model_tune():
    """Tune YOLO model for performance improvement."""
    YOLO("yolov8n-pose.pt").tune(data="coco8-pose.yaml", plots=False, imgsz=32, epochs=1, iterations=2, device="cpu")
    # Tune the classification model loaded from the "yolov8n-cls.pt" checkpoint
    YOLO("yolov8n-cls.pt").tune(data="imagenet10", plots=False, imgsz=32, epochs=1, iterations=2, device="cpu")


def test_model_embeddings():
    """Test YOLO model embeddings."""
    # Detection model
    model_detect = YOLO(MODEL)
    # Segmentation model loaded from the given checkpoint
    model_segment = YOLO(WEIGHTS_DIR / "yolov8n-seg.pt")

    # Test batch sizes 1 and 2
    for batch in [SOURCE], [SOURCE, SOURCE]:  # test batch size 1 and 2
        # The number of detection embeddings must match the batch size
        assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)
        # The number of segmentation embeddings must match the batch size
        assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)


@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="YOLOWorld with CLIP is not supported in Python 3.12")
def test_yolo_world():
    """Tests YOLO world models with CLIP support, including detection and training scenarios."""
    # Load the YOLO-World detection model
    model = YOLO("yolov8s-world.pt")  # no YOLOv8n-world model yet
    # Restrict the model's classes to ["tree", "window"]
    model.set_classes(["tree", "window"])
    # Run detection with a 0.01 confidence threshold
    model(SOURCE, conf=0.01)

    # Load the YOLO-Worldv2 model
    model = YOLO("yolov8s-worldv2.pt")  # no YOLOv8n-world model yet
    # Train from the pretrained model; the final stage includes evaluation
    # dota8.yaml is used because its small class set keeps CLIP inference time down
    model.train(
        data="dota8.yaml",
        epochs=1,
        imgsz=32,
        cache="disk",
        close_mosaic=1,
    )

    # Test WorldTrainerFromScratch
    from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch

    # Build the YOLO-Worldv2 model from its YAML config
    model = YOLO("yolov8s-worldv2.yaml")  # no YOLOv8n-world model yet
    # Train the model from scratch
    model.train(
        data={"train": {"yolo_data": ["dota8.yaml"]}, "val": {"yolo_data": ["dota8.yaml"]}},
        epochs=1,
        imgsz=32,
        cache="disk",
        close_mosaic=1,
        trainer=WorldTrainerFromScratch,
    )


def test_yolov10():
    """Test YOLOv10 model training, validation, and prediction steps with minimal configurations."""
    # Build the YOLOv10n model from its YAML config
    model = YOLO("yolov10n.yaml")
    # Train on coco8.yaml for 1 epoch at imgsz=32 with disk caching and close_mosaic=1
    model.train(data="coco8.yaml", epochs=1, imgsz=32, close_mosaic=1, cache="disk")
    # Validate on coco8.yaml at imgsz=32
    model.val(data="coco8.yaml", imgsz=32)
    # Predict at imgsz=32, saving label text files and crops, with augmentation
    model.predict(imgsz=32, save_txt=True, save_crop=True, augment=True)
    # Run the model on SOURCE
    model(SOURCE)

.\yolov8\tests\test_solutions.py

# Import required libraries and modules
import cv2  # OpenCV for image and video processing
import pytest  # pytest test framework

# Import the YOLO model and the solutions module from the ultralytics package
from ultralytics import YOLO, solutions
# Import the safe download helper
from ultralytics.utils.downloads import safe_download

# Download URL for the main solutions demo video
MAJOR_SOLUTIONS_DEMO = "https://github.com/ultralytics/assets/releases/download/v0.0.0/solutions_ci_demo.mp4"
# Download URL for the workouts monitoring demo video
WORKOUTS_SOLUTION_DEMO = "https://github.com/ultralytics/assets/releases/download/v0.0.0/solution_ci_pose_demo.mp4"

# Slow test covering the main video analytics solutions
@pytest.mark.slow
def test_major_solutions():
    """Test the object counting, heatmap, speed estimation and queue management solution."""

    # Download the main solutions demo video
    safe_download(url=MAJOR_SOLUTIONS_DEMO)
    # Load the YOLO detection model
    model = YOLO("yolov8n.pt")
    # Class names of the model
    names = model.names
    # Open the demo video
    cap = cv2.VideoCapture("solutions_ci_demo.mp4")
    assert cap.isOpened(), "Error reading video file"

    # Four corner points of the region of interest
    region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]

    # Initialize the solution objects: object counter, heatmap, speed estimator, and queue manager
    counter = solutions.ObjectCounter(reg_pts=region_points, names=names, view_img=False)
    heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, names=names, view_img=False)
    speed = solutions.SpeedEstimator(reg_pts=region_points, names=names, view_img=False)
    queue = solutions.QueueManager(names=names, reg_pts=region_points, view_img=False)

    # Process every frame of the video
    while cap.isOpened():
        success, im0 = cap.read()
        if not success:
            break
        # Keep a copy of the original frame
        original_im0 = im0.copy()

        # Track objects in the frame with the YOLO model
        tracks = model.track(im0, persist=True, show=False)

        # Feed the frame and tracks to each solution
        _ = counter.start_counting(original_im0.copy(), tracks)
        _ = heatmap.generate_heatmap(original_im0.copy(), tracks)
        _ = speed.estimate_speed(original_im0.copy(), tracks)
        _ = queue.process_queue(original_im0.copy(), tracks)

    # Release the video stream
    cap.release()
    # Close all OpenCV windows
    cv2.destroyAllWindows()


# Slow test covering the AI workouts monitoring solution
@pytest.mark.slow
def test_aigym():
    """Test the workouts monitoring solution."""

    # Download the workouts monitoring demo video
    safe_download(url=WORKOUTS_SOLUTION_DEMO)
    # Load the YOLO pose model
    model = YOLO("yolov8n-pose.pt")
    # Open the workouts demo video
    cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")
    assert cap.isOpened(), "Error reading video file"

    # Initialize the AI gym monitoring object
    gym_object = solutions.AIGym(line_thickness=2, pose_type="squat", kpts_to_check=[5, 11, 13])

    # Process every frame of the video
    while cap.isOpened():
        success, im0 = cap.read()
        if not success:
            break
        # Run pose tracking on the frame
        results = model.track(im0, verbose=False)
        # Feed the frame and results to the AI gym counter
        _ = gym_object.start_counting(im0, results)

    # Release the video stream
    cap.release()
    # Close all OpenCV windows
    cv2.destroyAllWindows()


# Slow test covering the instance segmentation solution
@pytest.mark.slow
def test_instance_segmentation():
    """Test the instance segmentation solution."""

    # Import the Annotator and colors helpers
    from ultralytics.utils.plotting import Annotator, colors

    # Load the YOLO segmentation model
    model = YOLO("yolov8n-seg.pt")
    # Class names of the model
    names = model.names
    # Open the main solutions demo video (the same video as in the tests above)
    cap = cv2.VideoCapture("solutions_ci_demo.mp4")
    assert cap.isOpened(), "Error reading video file"
    # Process frames while the video stream is open
    while cap.isOpened():
        # Read one frame, getting the success flag and the image
        success, im0 = cap.read()
        # Stop when no more frames can be read
        if not success:
            break
        # Run segmentation prediction on the frame
        results = model.predict(im0)
        # Create an annotator to draw on the frame
        annotator = Annotator(im0, line_width=2)
        # If the prediction contains instance masks
        if results[0].masks is not None:
            # Get the class indices and mask polygons of the detected instances
            clss = results[0].boxes.cls.cpu().tolist()
            masks = results[0].masks.xy
            # Draw each instance mask with its class label
            for mask, cls in zip(masks, clss):
                # Pick the colour for this class
                color = colors(int(cls), True)
                # Draw the mask outline and the class label on the frame
                annotator.seg_bbox(mask=mask, mask_color=color, label=names[int(cls)])
    # Release the video stream
    cap.release()
    # Close all OpenCV windows and release GUI resources
    cv2.destroyAllWindows()
# Slow test covering the Streamlit live inference solution
@pytest.mark.slow
def test_streamlit_predict():
    """Test streamlit predict live inference solution."""
    # Run the Streamlit-based live inference helper from the solutions module
    solutions.inference()

.\yolov8\tests\__init__.py

# 导入Ultralytics YOLO的相关模块和函数,该项目使用AGPL-3.0许可证

# 从ultralytics.utils模块中导入常量和函数
from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks, is_dir_writeable

# 设置用于测试的常量
# MODEL代表YOLO模型的权重文件路径,包含空格
MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt"  # test spaces in path

# CFG是YOLO配置文件的文件名
CFG = "yolov8n.yaml"

# SOURCE是用于测试的示例图片文件路径
SOURCE = ASSETS / "bus.jpg"

# TMP是用于存储测试文件的临时目录路径
TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files

# 检查临时目录TMP是否可写
IS_TMP_WRITEABLE = is_dir_writeable(TMP)

# 检查CUDA是否可用
CUDA_IS_AVAILABLE = checks.cuda_is_available()

# 获取CUDA设备的数量
CUDA_DEVICE_COUNT = checks.cuda_device_count()

# 导出所有的常量和变量,以便在模块外部使用
__all__ = (
    "MODEL",
    "CFG",
    "SOURCE",
    "TMP",
    "IS_TMP_WRITEABLE",
    "CUDA_IS_AVAILABLE",
    "CUDA_DEVICE_COUNT",
)
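
These constants are shared across the whole test suite. The snippet below is only an illustrative sketch of how a test might consume them; the test name, the imgsz value and the use of TMP as the output project directory are assumptions for demonstration, not taken from the repository.

import pytest
from ultralytics import YOLO

from tests import IS_TMP_WRITEABLE, MODEL, SOURCE, TMP


@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="temporary directory is not writeable")
def test_predict_to_tmp():
    """Run a single prediction on the sample image and save outputs under the shared TMP directory."""
    model = YOLO(MODEL)  # weights path containing spaces, exercising path handling
    results = model.predict(SOURCE, imgsz=32, save=True, project=TMP, name="predict_demo")
    assert len(results) == 1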

Models

Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration files (*.yamls) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.

These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs.

To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've selected a model, you can use the provided *.yaml file to train and deploy your custom YOLO model with ease. See full details at the Ultralytics Docs, and if you need help or have any questions, feel free to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!

Usage

Model *.yaml files may be used directly in the Command Line Interface (CLI) with a yolo command:

# Train a YOLOv8n model using the coco8 dataset for 100 epochs
yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=100

They may also be used directly in a Python environment, and accept the same arguments as in the CLI example above:

from ultralytics import YOLO

# Initialize a YOLOv8n model from a YAML configuration file
model = YOLO("model.yaml")

# If a pre-trained model is available, use it instead
# model = YOLO("model.pt")

# Display model information
model.info()

# Train the model using the COCO8 dataset for 100 epochs
model.train(data="coco8.yaml", epochs=100)

Pre-trained Model Architectures

Ultralytics supports many model architectures. Visit Ultralytics Models to view detailed information and usage. Any of these models can be used by loading their configurations or pretrained checkpoints if available.

Contribute New Models

Have you trained a new YOLO variant or achieved state-of-the-art performance with specific tuning? We'd love to showcase your work in our Models section! Contributions from the community in the form of new models, architectures, or optimizations are highly valued and can significantly enrich our repository.

By contributing to this section, you're helping us offer a wider array of model choices and configurations to the community. It's a fantastic way to share your knowledge and expertise while making the Ultralytics YOLO ecosystem even more versatile.

To get started, please consult our Contributing Guide for step-by-step instructions on how to submit a Pull Request (PR) 🛠️. Your contributions are eagerly awaited!

Let's join hands to extend the range and capabilities of the Ultralytics YOLO models 🙏!

.\yolov8\ultralytics\cfg\__init__.py

# 导入必要的库和模块
import contextlib  # 提供上下文管理工具的模块
import shutil  # 提供高级文件操作功能的模块
import subprocess  # 用于执行外部命令的模块
import sys  # 提供与 Python 解释器及其环境相关的功能
from pathlib import Path  # 提供处理路径的类和函数
from types import SimpleNamespace  # 提供创建简单命名空间的类
from typing import Dict, List, Union  # 提供类型提示支持

# 从Ultralytics的utils模块中导入多个工具和变量
from ultralytics.utils import (
    ASSETS,  # 资源目录的路径
    DEFAULT_CFG,  # 默认配置文件名
    DEFAULT_CFG_DICT,  # 默认配置字典
    DEFAULT_CFG_PATH,  # 默认配置文件的路径
    LOGGER,  # 日志记录器
    RANK,  # 运行的排名
    ROOT,  # 根目录路径
    RUNS_DIR,  # 运行结果保存的目录路径
    SETTINGS,  # 设置信息
    SETTINGS_YAML,  # 设置信息的YAML文件路径
    TESTS_RUNNING,  # 是否正在运行测试的标志
    IterableSimpleNamespace,  # 可迭代的简单命名空间
    __version__,  # Ultralytics工具包的版本信息
    checks,  # 检查函数
    colorstr,  # 带有颜色的字符串处理函数
    deprecation_warn,  # 弃用警告函数
    yaml_load,  # 加载YAML文件的函数
    yaml_print,  # 打印YAML内容的函数
)

# 定义有效的任务和模式集合
MODES = {"train", "val", "predict", "export", "track", "benchmark"}  # 可执行的模式集合
TASKS = {"detect", "segment", "classify", "pose", "obb"}  # 可执行的任务集合

# 将任务映射到其对应的数据文件
TASK2DATA = {
    "detect": "coco8.yaml",
    "segment": "coco8-seg.yaml",
    "classify": "imagenet10",
    "pose": "coco8-pose.yaml",
    "obb": "dota8.yaml",
}

# 将任务映射到其对应的模型文件
TASK2MODEL = {
    "detect": "yolov8n.pt",
    "segment": "yolov8n-seg.pt",
    "classify": "yolov8n-cls.pt",
    "pose": "yolov8n-pose.pt",
    "obb": "yolov8n-obb.pt",
}

# 将任务映射到其对应的指标文件
TASK2METRIC = {
    "detect": "metrics/mAP50-95(B)",
    "segment": "metrics/mAP50-95(M)",
    "classify": "metrics/accuracy_top1",
    "pose": "metrics/mAP50-95(P)",
    "obb": "metrics/mAP50-95(B)",
}

# 从TASKS集合中提取模型文件集合
MODELS = {TASK2MODEL[task] for task in TASKS}
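
These lookup tables are what let the CLI fill in sensible defaults when only a task is supplied. A rough illustration of the mapping, using values taken directly from the dictionaries above:

task = "segment"
print(TASK2MODEL[task])   # yolov8n-seg.pt
print(TASK2DATA[task])    # coco8-seg.yaml
print(TASK2METRIC[task])  # metrics/mAP50-95(M)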

# Capture the command-line arguments; fall back to two empty strings if sys.argv is empty
ARGV = sys.argv or ["", ""]

# 定义CLI帮助信息,说明如何使用Ultralytics 'yolo'命令
CLI_HELP_MSG = f"""
    Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax:

        yolo TASK MODE ARGS

        Where   TASK (optional) is one of {TASKS}
                MODE (required) is one of {MODES}
                ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
                    See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg'

    1. Train a detection model for 10 epochs with an initial learning_rate of 0.01
        yolo train data=coco8.yaml model=yolov8n.pt epochs=10 lr0=0.01

    2. Predict a YouTube video using a pretrained segmentation model at image size 320:
        yolo predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320

    3. Val a pretrained detection model at batch-size 1 and image size 640:
        yolo val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=640

    4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
        yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128

    5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
        yolo explorer data=data.yaml model=yolov8n.pt
    
    6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8
        yolo streamlit-predict
        
    7. Run special commands:
        yolo help
        yolo checks
        yolo version
        yolo settings
        yolo copy-cfg
        yolo cfg

    Docs: https://docs.ultralytics.com
    Community: https://community.ultralytics.com
"""
    GitHub: https://github.com/ultralytics/ultralytics
    """
    GitHub: https://github.com/ultralytics/ultralytics
    # 在代码中添加一个字符串文档注释,指向项目的GitHub页面
    """
# Define keys for arg type checks
CFG_FLOAT_KEYS = {  # integer or float arguments, i.e. x=2 and x=2.0
    "warmup_epochs",
    "box",
    "cls",
    "dfl",
    "degrees",
    "shear",
    "time",
    "workspace",
    "batch",
}
CFG_FRACTION_KEYS = {  # fractional float arguments with 0.0<=values<=1.0
    "dropout",
    "lr0",
    "lrf",
    "momentum",
    "weight_decay",
    "warmup_momentum",
    "warmup_bias_lr",
    "label_smoothing",
    "hsv_h",
    "hsv_s",
    "hsv_v",
    "translate",
    "scale",
    "perspective",
    "flipud",
    "fliplr",
    "bgr",
    "mosaic",
    "mixup",
    "copy_paste",
    "conf",
    "iou",
    "fraction",
}
CFG_INT_KEYS = {  # integer-only arguments
    "epochs",
    "patience",
    "workers",
    "seed",
    "close_mosaic",
    "mask_ratio",
    "max_det",
    "vid_stride",
    "line_width",
    "nbs",
    "save_period",
}
CFG_BOOL_KEYS = {  # boolean-only arguments
    "save",
    "exist_ok",
    "verbose",
    "deterministic",
    "single_cls",
    "rect",
    "cos_lr",
    "overlap_mask",
    "val",
    "save_json",
    "save_hybrid",
    "half",
    "dnn",
    "plots",
    "show",
    "save_txt",
    "save_conf",
    "save_crop",
    "save_frames",
    "show_labels",
    "show_conf",
    "visualize",
    "augment",
    "agnostic_nms",
    "retina_masks",
    "show_boxes",
    "keras",
    "optimize",
    "int8",
    "dynamic",
    "simplify",
    "nms",
    "profile",
    "multi_scale",
}


def cfg2dict(cfg):
    """
    Converts a configuration object to a dictionary.

    Args:
        cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path,
            a string, a dictionary, or a SimpleNamespace object.

    Returns:
        (Dict): Configuration object in dictionary format.

    Examples:
        Convert a YAML file path to a dictionary:
        >>> config_dict = cfg2dict('config.yaml')

        Convert a SimpleNamespace to a dictionary:
        >>> from types import SimpleNamespace
        >>> config_sn = SimpleNamespace(param1='value1', param2='value2')
        >>> config_dict = cfg2dict(config_sn)

        Pass through an already existing dictionary:
        >>> config_dict = cfg2dict({'param1': 'value1', 'param2': 'value2'})

    Notes:
        - If cfg is a path or string, it's loaded as YAML and converted to a dictionary.
        - If cfg is a SimpleNamespace object, it's converted to a dictionary using vars().
        - If cfg is already a dictionary, it's returned unchanged.
    """
    if isinstance(cfg, (str, Path)):
        cfg = yaml_load(cfg)  # load dict from YAML file or string
    elif isinstance(cfg, SimpleNamespace):
        cfg = vars(cfg)  # convert SimpleNamespace to dictionary
    return cfg


def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
    """
    Load and merge configuration data from a file or dictionary, with optional overrides.

    Args:
        cfg (str | Path | Dict | SimpleNamespace): Configuration source to load from.
            Defaults to DEFAULT_CFG_DICT if not provided.
        overrides (Dict): Optional dictionary containing configuration overrides.

    Returns:
        (Dict): Merged configuration dictionary.

    Notes:
        - cfg can be a YAML file path, string, dictionary, or SimpleNamespace object.
        - If overrides are provided, they overwrite values from cfg.
    """
    # 将 cfg 转换为字典形式,统一处理配置数据来源为不同类型的情况(文件路径、字典、SimpleNamespace 对象)
    cfg = cfg2dict(cfg)

    # 合并 overrides
    if overrides:
        # 将 overrides 转换为字典形式
        overrides = cfg2dict(overrides)
        # 如果 cfg 中没有 "save_dir" 键,则在合并过程中忽略 "save_dir" 键
        if "save_dir" not in cfg:
            overrides.pop("save_dir", None)  # 特殊的覆盖键,忽略处理
        # 检查 cfg 和 overrides 字典的对齐性,确保正确性
        check_dict_alignment(cfg, overrides)
        # 合并 cfg 和 overrides 字典,以 overrides 为优先
        cfg = {**cfg, **overrides}  # 合并 cfg 和 overrides 字典(优先使用 overrides)

    # 对于数字类型的 "project" 和 "name" 进行特殊处理,转换为字符串
    for k in "project", "name":
        if k in cfg and isinstance(cfg[k], (int, float)):
            cfg[k] = str(cfg[k])
    
    # 如果配置中 "name" 等于 "model",则将其更新为 "model" 键对应值的第一个点之前的部分
    if cfg.get("name") == "model":
        cfg["name"] = cfg.get("model", "").split(".")[0]
        # 发出警告信息,提示自动更新 "name" 为新值
        LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")

    # 对配置数据进行类型和值的检查
    check_cfg(cfg)

    # 返回包含合并配置的 IterableSimpleNamespace 实例
    return IterableSimpleNamespace(**cfg)
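
A minimal usage sketch of get_cfg with a couple of overrides (the values are illustrative only); an unknown key would be rejected by check_dict_alignment with a SyntaxError:

from ultralytics.cfg import get_cfg

cfg = get_cfg(overrides={"imgsz": 320, "epochs": 3})
print(cfg.imgsz, cfg.epochs)  # 320 3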
# 验证和修正 Ultralytics 库的配置参数类型和值

def check_cfg(cfg, hard=True):
    """
    Checks configuration argument types and values for the Ultralytics library.

    This function validates the types and values of configuration arguments, ensuring correctness and converting
    them if necessary. It checks for specific key types defined in global variables such as CFG_FLOAT_KEYS,
    CFG_FRACTION_KEYS, CFG_INT_KEYS, and CFG_BOOL_KEYS.

    Args:
        cfg (Dict): Configuration dictionary to validate.
        hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them.

    Examples:
        >>> config = {
        ...     'epochs': 50,     # valid integer
        ...     'lr0': 0.01,      # valid float
        ...     'momentum': 1.2,  # invalid float (out of 0.0-1.0 range)
        ...     'save': 'true',   # invalid bool
        ... }
        >>> check_cfg(config, hard=False)
        >>> print(config)
        {'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False}  # corrected 'save' key

    Notes:
        - The function modifies the input dictionary in-place.
        - None values are ignored as they may be from optional arguments.
        - Fraction keys are checked to be within the range [0.0, 1.0].
    """
    # 遍历配置字典中的每个键值对
    for k, v in cfg.items():
        # 忽略值为 None 的情况,因为它们可能是可选参数的结果
        if v is not None:
            # 如果键在浮点数键集合中,但值不是 int 或 float 类型
            if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
                # 如果 hard 为 True,则抛出类型错误异常,否则尝试将值转换为 float 类型
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                    )
                cfg[k] = float(v)
            # 如果键在分数键集合中
            elif k in CFG_FRACTION_KEYS:
                # 如果值不是 int 或 float 类型,进行类型检查和可能的转换
                if not isinstance(v, (int, float)):
                    if hard:
                        raise TypeError(
                            f"'{k}={v}' is of invalid type {type(v).__name__}. "
                            f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                        )
                    cfg[k] = v = float(v)
                # 检查分数值是否在 [0.0, 1.0] 范围内,否则抛出值错误异常
                if not (0.0 <= v <= 1.0):
                    raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
            # 如果键在整数键集合中,但值不是 int 类型
            elif k in CFG_INT_KEYS and not isinstance(v, int):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
                    )
                cfg[k] = int(v)
            # 如果键在布尔键集合中,但值不是 bool 类型
            elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
                if hard:
                    raise TypeError(
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
                        f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
                    )
                cfg[k] = bool(v)
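
With hard=False the checks above coerce values in place instead of raising. A small sketch with illustrative values:

from ultralytics.cfg import check_cfg

cfg = {"epochs": "3", "lr0": "0.5", "save": 0}
check_cfg(cfg, hard=False)
print(cfg)  # {'epochs': 3, 'lr0': 0.5, 'save': False}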


def get_save_dir(args, name=None):
    """
    # 根据参数和默认设置确定输出目录路径。

    # 判断是否存在 args 中的 save_dir 属性,若存在则直接使用该路径
    if getattr(args, "save_dir", None):
        save_dir = args.save_dir
    else:
        # 如果不存在 save_dir 属性,则从 ultralytics.utils.files 中导入 increment_path 函数
        from ultralytics.utils.files import increment_path
        
        # 根据条件设定 project 的路径,若在测试环境中(TESTS_RUNNING 为真),则使用默认路径,否则使用 RUNS_DIR
        project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
        
        # 根据参数或默认值设置 name 的值,优先级顺序是提供的 name > args.name > args.mode
        name = name or args.name or f"{args.mode}"
        
        # 使用 increment_path 函数生成一个递增的路径,以确保路径的唯一性,根据 exist_ok 参数决定是否创建新路径
        save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)

    # 返回生成的路径作为 Path 对象
    return Path(save_dir)
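
A minimal sketch of resolving a run directory from a namespace of arguments (the attribute values are illustrative; the exact path depends on RUNS_DIR and on whether tests are running):

from types import SimpleNamespace

args = SimpleNamespace(project=None, task="detect", name=None, mode="train", exist_ok=False)
print(get_save_dir(args))  # e.g. runs/detect/train, then runs/detect/train2 on a second call
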
def _handle_deprecation(custom):
    """
    Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings.

    Args:
        custom (Dict): Configuration dictionary potentially containing deprecated keys.

    Examples:
        >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2}
        >>> _handle_deprecation(custom_config)
        >>> print(custom_config)
        {'show_boxes': True, 'show_labels': True, 'line_width': 2}

    Notes:
        This function modifies the input dictionary in-place, replacing deprecated keys with their current
        equivalents. It also handles value conversions where necessary, such as inverting boolean values for
        'hide_labels' and 'hide_conf'.
    """

    # 遍历输入字典的副本,以便安全地修改原字典
    for key in custom.copy().keys():
        # 如果发现 'boxes' 键,发出弃用警告,并将其映射到 'show_boxes'
        if key == "boxes":
            deprecation_warn(key, "show_boxes")
            custom["show_boxes"] = custom.pop("boxes")
        # 如果发现 'hide_labels' 键,发出弃用警告,并根据值将其映射到 'show_labels'
        if key == "hide_labels":
            deprecation_warn(key, "show_labels")
            custom["show_labels"] = custom.pop("hide_labels") == "False"
        # 如果发现 'hide_conf' 键,发出弃用警告,并根据值将其映射到 'show_conf'
        if key == "hide_conf":
            deprecation_warn(key, "show_conf")
            custom["show_conf"] = custom.pop("hide_conf") == "False"
        # 如果发现 'line_thickness' 键,发出弃用警告,并将其映射到 'line_width'
        if key == "line_thickness":
            deprecation_warn(key, "line_width")
            custom["line_width"] = custom.pop("line_thickness")

    # 返回更新后的自定义配置字典
    return custom


def check_dict_alignment(base: Dict, custom: Dict, e=None):
    """
    Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
    messages for mismatched keys.

    Args:
        base (Dict): The base configuration dictionary containing valid keys.
        custom (Dict): The custom configuration dictionary to be checked for alignment.
        e (Exception | None): Optional error instance passed by the calling function.

    Raises:
        SyntaxError: If mismatched keys are found between the custom and base dictionaries.

    Examples:
        >>> base_cfg = {'epochs': 50, 'lr0': 0.01, 'batch_size': 16}
        >>> custom_cfg = {'epoch': 100, 'lr': 0.02, 'batch_size': 32}
        >>> try:
        ...     check_dict_alignment(base_cfg, custom_cfg)
        ... except SystemExit:
        ...     print("Mismatched keys found")

    Notes:
        - Suggests corrections for mismatched keys based on similarity to valid keys.
        - Automatically replaces deprecated keys in the custom configuration with updated equivalents.
        - Prints detailed error messages for each mismatched key to help users correct their configurations.
    """

    # 处理自定义配置中的弃用键,将其更新为当前版本的等效键
    custom = _handle_deprecation(custom)
    
    # 获取基础配置和自定义配置的键集合
    base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
    
    # 找出自定义配置中存在但基础配置中不存在的键
    mismatched = [k for k in custom_keys if k not in base_keys]
    # 如果存在不匹配的情况,则执行以下代码块
    if mismatched:
        # 导入模块 difflib 中的 get_close_matches 函数
        from difflib import get_close_matches

        # 初始化空字符串,用于存储错误信息
        string = ""
        
        # 遍历所有不匹配的项
        for x in mismatched:
            # 使用 get_close_matches 函数寻找在 base_keys 中与 x 最接近的匹配项
            matches = get_close_matches(x, base_keys)  # key list
            
            # 将匹配项转换为字符串,如果 base 中存在对应项,则添加其值
            matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches]
            
            # 如果有找到匹配项,生成匹配信息字符串
            match_str = f"Similar arguments are i.e. {matches}." if matches else ""
            
            # 构造错误信息字符串,指出不是有效 YOLO 参数的项及其可能的匹配项
            string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
        
        # 抛出 SyntaxError 异常,包含错误信息和 CLI_HELP_MSG 的帮助信息
        raise SyntaxError(string + CLI_HELP_MSG) from e
# 处理命令行参数列表中隔离的 '=',合并相关参数
def merge_equals_args(args: List[str]) -> List[str]:
    """
    Merges arguments around isolated '=' in a list of strings, handling three cases:
    1. ['arg', '=', 'val'] becomes ['arg=val'],
    2. ['arg=', 'val'] becomes ['arg=val'],
    3. ['arg', '=val'] becomes ['arg=val'].

    Args:
        args (List[str]): A list of strings where each element represents an argument.

    Returns:
        (List[str]): A list of strings where the arguments around isolated '=' are merged.

    Examples:
        >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3"]
        >>> merge_equals_args(args)
        ['arg1=value', 'arg2=value2', 'arg3=value3']
    """
    new_args = []
    for i, arg in enumerate(args):
        if arg == "=" and 0 < i < len(args) - 1:  # merge ['arg', '=', 'val']
            new_args[-1] += f"={args[i + 1]}"
            del args[i + 1]
        elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]:  # merge ['arg=', 'val']
            new_args.append(f"{arg}{args[i + 1]}")
            del args[i + 1]
        elif arg.startswith("=") and i > 0:  # merge ['arg', '=val']
            new_args[-1] += arg
        else:
            new_args.append(arg)
    return new_args


# 处理 Ultralytics HUB 命令行接口 (CLI) 命令,用于认证
def handle_yolo_hub(args: List[str]) -> None:
    """
    Handles Ultralytics HUB command-line interface (CLI) commands for authentication.

    This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a
    script with arguments related to HUB authentication.

    Args:
        args (List[str]): A list of command line arguments. The first argument should be either 'login'
            or 'logout'. For 'login', an optional second argument can be the API key.

    Examples:
        ```bash
        yolo hub login YOUR_API_KEY
        ```

    Notes:
        - The function imports the 'hub' module from ultralytics to perform login and logout operations.
        - For the 'login' command, if no API key is provided, an empty string is passed to the login function.
        - The 'logout' command does not require any additional arguments.
    """
    from ultralytics import hub

    if args[0] == "login":
        key = args[1] if len(args) > 1 else ""
        # 使用提供的 API 密钥登录到 Ultralytics HUB
        hub.login(key)
    elif args[0] == "logout":
        # 从 Ultralytics HUB 注销
        hub.logout()


# 处理 YOLO 设置命令行接口 (CLI) 命令
def handle_yolo_settings(args: List[str]) -> None:
    """
    Handles YOLO settings command-line interface (CLI) commands.

    This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be
    called when executing a script with arguments related to YOLO settings management.

    Args:
        args (List[str]): A list of command line arguments for YOLO settings management.

    """
    url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings"  # 帮助文档的URL

    try:
        # 如果有任何参数
        if any(args):
            # 如果第一个参数是"reset"
            if args[0] == "reset":
                SETTINGS_YAML.unlink()  # 删除设置文件
                SETTINGS.reset()  # 创建新的设置
                LOGGER.info("Settings reset successfully")  # 提示用户设置已成功重置
            else:  # 否则,保存一个新的设置
                # 生成键值对字典,解析每个参数
                new = dict(parse_key_value_pair(a) for a in args)
                # 检查新设置和现有设置的对齐情况
                check_dict_alignment(SETTINGS, new)
                # 更新设置
                SETTINGS.update(new)

        LOGGER.info(f"💡 Learn about settings at {url}")  # 提示用户查看设置文档
        yaml_print(SETTINGS_YAML)  # 打印当前的设置到YAML文件
    except Exception as e:
        # 捕获异常并记录警告信息,提醒用户查看帮助文档
        LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")


def handle_explorer(args: List[str]) -> None:
    """Launches the Ultralytics Explorer dashboard, passing 'model' and 'data' arguments through to Streamlit."""
    # Ensure the 'streamlit' package meets the minimum version requirement (>=1.29.0)
    checks.check_requirements("streamlit>=1.29.0")
    # Log that the Explorer dashboard is being loaded
    LOGGER.info("💡 Loading Explorer dashboard...")
    # Build the command-line argument list for running Streamlit
    cmd = ["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"]
    # Parse the 'key=value' CLI arguments into a dictionary
    new = dict(parse_key_value_pair(a) for a in args)
    # Check the custom keys against the default 'model' and 'data' entries
    check_dict_alignment(base={k: DEFAULT_CFG_DICT[k] for k in ["model", "data"]}, custom=new)
    # Append each custom key and value to the Streamlit command
    for k, v in new.items():
        cmd += [k, v]
    # Run the assembled command to start the Streamlit app
    subprocess.run(cmd)


def parse_key_value_pair(pair):
    """
    Parses a single 'key=value' argument string into a (key, value) tuple.

    Notes:
        - Split the input string `pair` into two parts based on the first '=' character.
        - Remove leading and trailing whitespace from both `k` (key) and `v` (value).
        - Raise an assertion error if `v` (value) becomes empty after stripping.
    """
    k, v = pair.split("=", 1)  # split on first '=' sign
    k, v = k.strip(), v.strip()  # remove spaces
    assert v, f"missing '{k}' value"
    return k, smart_value(v)
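
A rough illustration of the parsing behaviour, assuming smart_value() (defined elsewhere in this module) converts numeric, boolean and None-like strings into native Python types:

print(parse_key_value_pair("imgsz=640"))         # ('imgsz', 640)
print(parse_key_value_pair("save=False"))        # ('save', False)
print(parse_key_value_pair("model=yolov8n.pt"))  # ('model', 'yolov8n.pt')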
# Ultralytics入口函数,用于解析和执行命令行参数
def entrypoint(debug=""):
    """
    Ultralytics entrypoint function for parsing and executing command-line arguments.

    This function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments and
    executing the corresponding tasks such as training, validation, prediction, exporting models, and more.

    Args:
        debug (str): Space-separated string of command-line arguments for debugging purposes.

    Examples:
        Train a detection model for 10 epochs with an initial learning_rate of 0.01:
        >>> entrypoint("train data=coco8.yaml model=yolov8n.pt epochs=10 lr0=0.01")

        Predict a YouTube video using a pretrained segmentation model at image size 320:
        >>> entrypoint("predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320")

        Validate a pretrained detection model at batch-size 1 and image size 640:
        >>> entrypoint("val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=640")

    Notes:
        - If no arguments are passed, the function will display the usage help message.
        - For a list of all available commands and their arguments, see the provided help messages and the
          Ultralytics documentation at https://docs.ultralytics.com.
    """
    # 解析调试参数,若未传入参数则使用全局变量ARGV
    args = (debug.split(" ") if debug else ARGV)[1:]
    # 若没有传入参数,则打印使用帮助信息并返回
    if not args:  # no arguments passed
        LOGGER.info(CLI_HELP_MSG)
        return
    # 定义特殊命令及其对应的操作
    special = {
        "help": lambda: LOGGER.info(CLI_HELP_MSG),  # 打印帮助信息
        "checks": checks.collect_system_info,  # 收集系统信息
        "version": lambda: LOGGER.info(__version__),  # 打印版本信息
        "settings": lambda: handle_yolo_settings(args[1:]),  # 处理设置命令
        "cfg": lambda: yaml_print(DEFAULT_CFG_PATH),  # 打印默认配置路径
        "hub": lambda: handle_yolo_hub(args[1:]),  # 处理hub命令
        "login": lambda: handle_yolo_hub(args),  # 处理登录命令
        "copy-cfg": copy_default_cfg,  # 复制默认配置文件
        "explorer": lambda: handle_explorer(args[1:]),  # 处理explorer命令
        "streamlit-predict": lambda: handle_streamlit_inference(),  # 处理streamlit预测命令
    }
    
    # 将特殊命令合并到完整的参数字典中,包括默认配置、任务和模式
    full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}

    # 定义特殊命令的常见误用,例如-h, -help, --help等,添加到特殊命令字典中
    special.update({k[0]: v for k, v in special.items()})  # 单数形式
    special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")})  # 单数形式
    special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}}
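    # Rough illustration of the effect of the three updates above: starting from keys such as "help" and
    # "checks", the special table now also accepts "h", "c", "check", "-help", "--help", "-checks",
    # "--checks", and so on as the same commands.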

    # 初始化覆盖参数字典
    overrides = {}

    # 遍历合并等号周围的参数,并进行处理
    for a in merge_equals_args(args):
        if a.startswith("--"):
            # 警告:参数'a'不需要前导破折号'--',更新为'{a[2:]}'。
            LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
            a = a[2:]
        if a.endswith(","):
            # 警告:参数'a'不需要尾随逗号',',更新为'{a[:-1]}'。
            LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
            a = a[:-1]
        if "=" in a:
            try:
                # 解析键值对(a),并处理特定情况下的覆盖
                k, v = parse_key_value_pair(a)
                if k == "cfg" and v is not None:  # 如果传递了自定义yaml路径
                    LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}")
                    # 更新覆盖字典,排除键为'cfg'的条目
                    overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"}
                else:
                    overrides[k] = v
            except (NameError, SyntaxError, ValueError, AssertionError) as e:
                # 检查覆盖参数时出现异常
                check_dict_alignment(full_args_dict, {a: ""}, e)
        elif a in TASKS:
            overrides["task"] = a
        elif a in MODES:
            overrides["mode"] = a
        elif a.lower() in special:
            # 如果参数在特殊命令中,则执行对应的操作并返回
            special[a.lower()]()
            return
        elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
            # 对于默认布尔参数,例如'yolo show',自动设为True
            overrides[a] = True
        elif a in DEFAULT_CFG_DICT:
            # 抛出语法错误,提示缺少等号以设置参数值
            raise SyntaxError(
                f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign "
                f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}"
            )
        else:
            # 检查参数字典对齐性,处理未知参数情况
            check_dict_alignment(full_args_dict, {a: ""})

    # 检查参数字典的键对齐性,确保没有漏掉任何参数
    check_dict_alignment(full_args_dict, overrides)

    # 获取覆盖参数中的模式(mode)
    mode = overrides.get("mode")
    if mode is None:
        # 如果 mode 参数为 None,则使用默认值 'predict' 或从 DEFAULT_CFG 中获取的默认模式
        mode = DEFAULT_CFG.mode or "predict"
        # 发出警告日志,指示 'mode' 参数缺失,并显示可用的模式列表 MODES
        LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
    elif mode not in MODES:
        # 如果 mode 参数不在预定义的模式列表 MODES 中,则抛出 ValueError 异常
        raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")

    # Task
    # 从 overrides 字典中弹出 'task' 键对应的值
    task = overrides.pop("task", None)
    if task:
        if task not in TASKS:
            # 如果提供的 task 不在 TASKS 列表中,则抛出 ValueError 异常
            raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
        if "model" not in overrides:
            # 如果 'model' 不在 overrides 中,则设置 'model' 为 TASK2MODEL[task]
            overrides["model"] = TASK2MODEL[task]

    # Model
    # 从 overrides 字典中弹出 'model' 键对应的值,如果不存在,则使用 DEFAULT_CFG 中的默认模型
    model = overrides.pop("model", DEFAULT_CFG.model)
    if model is None:
        # 如果 model 仍为 None,则使用默认模型 'yolov8n.pt',并发出警告日志
        model = "yolov8n.pt"
        LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.")
    # 更新 overrides 字典中的 'model' 键为当前的 model 值
    overrides["model"] = model
    # 获取模型文件的基本文件名,并转换为小写
    stem = Path(model).stem.lower()
    # 根据模型文件名的特征选择合适的模型类
    if "rtdetr" in stem:  # 猜测架构
        from ultralytics import RTDETR
        # 使用 RTDETR 类初始化模型对象,没有指定 task 参数
        model = RTDETR(model)
    elif "fastsam" in stem:
        from ultralytics import FastSAM
        # 使用 FastSAM 类初始化模型对象
        model = FastSAM(model)
    elif "sam" in stem:
        from ultralytics import SAM
        # 使用 SAM 类初始化模型对象
        model = SAM(model)
    else:
        from ultralytics import YOLO
        # 使用 YOLO 类初始化模型对象,并传入 task 参数
        model = YOLO(model, task=task)
    if isinstance(overrides.get("pretrained"), str):
        # 如果 overrides 中的 'pretrained' 是字符串类型,则加载预训练模型
        model.load(overrides["pretrained"])

    # Task Update
    # 如果指定的 task 与 model 的 task 不一致,则更新 task
    if task != model.task:
        if task:
            # 发出警告日志,指示传入的 task 与模型的 task 不匹配
            LOGGER.warning(
                f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. "
                f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model."
            )
        task = model.task

    # Mode
    # 根据 mode 执行不同的逻辑
    if mode in {"predict", "track"} and "source" not in overrides:
        # 如果 mode 是 'predict' 或 'track',并且 overrides 中没有 'source',则使用默认的数据源 ASSETS
        overrides["source"] = DEFAULT_CFG.source or ASSETS
        LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
    elif mode in {"train", "val"}:
        if "data" not in overrides and "resume" not in overrides:
            # 如果 mode 是 'train' 或 'val',并且 overrides 中没有 'data' 和 'resume',则使用默认的数据配置
            overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
            LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
    elif mode == "export":
        if "format" not in overrides:
            # 如果 mode 是 'export',并且 overrides 中没有 'format',则使用默认的导出格式 'torchscript'
            overrides["format"] = DEFAULT_CFG.format or "torchscript"
            LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.")

    # 在模型对象上调用指定的 mode 方法,传入 overrides 字典中的参数
    getattr(model, mode)(**overrides)  # default args from model

    # Show help
    # 输出提示信息,指示用户查阅模式相关的文档
    LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}")
# Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg():
    """
    Copies the default configuration file and creates a new one with '_copy' appended to its name.

    This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it
    with '_copy' appended to its name in the current working directory. It provides a convenient way
    to create a custom configuration file based on the default settings.

    Examples:
        >>> copy_default_cfg()
        # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml
        # Example YOLO command with this new custom cfg:
        #   yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8

    Notes:
        - The new configuration file is created in the current working directory.
        - After copying, the function prints a message with the new file's location and an example
          YOLO command demonstrating how to use the new configuration file.
        - This function is useful for users who want to modify the default configuration without
          altering the original file.
    """
    # 创建新文件路径,将默认配置文件复制到当前工作目录并在文件名末尾添加 '_copy'
    new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
    # 使用 shutil 库的 copy2 函数复制 DEFAULT_CFG_PATH 指定的文件到新的文件路径
    shutil.copy2(DEFAULT_CFG_PATH, new_file)
    # 记录信息到日志,包括已复制的文件路径和示例 YOLO 命令,指导如何使用新的配置文件
    LOGGER.info(
        f"{DEFAULT_CFG_PATH} copied to {new_file}\n"
        f"Example YOLO command with this new custom cfg:\n    yolo cfg='{new_file}' imgsz=320 batch=8"
    )


if __name__ == "__main__":
    # Example: entrypoint(debug='yolo predict model=yolov8n.pt')
    # 当作为主程序运行时,调用 entrypoint 函数并传递一个空的 debug 参数
    entrypoint(debug="")