YOLOv8 Source Code Analysis (32)

.\yolov8\ultralytics\engine\tuner.py

"""
This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
instance segmentation, image classification, pose estimation, and multi-object tracking.

Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.

Example:
    Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
    ```python
    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
    ```
"""

import random
import shutil
import subprocess
import time

import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
from ultralytics.utils.plotting import plot_tune_results


class Tuner:
    """
    Class responsible for hyperparameter tuning of YOLO models.

    The class evolves YOLO model hyperparameters over a given number of iterations
    by mutating them according to the search space and retraining the model to evaluate their performance.

    Attributes:
        space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        tune_csv (Path): Path to the CSV file where evolution logs are saved.

    Methods:
        _mutate(hyp: dict) -> dict:
            Mutates the given hyperparameters within the bounds specified in `self.space`.

        __call__():
            Executes the hyperparameter evolution across multiple iterations.

    Example:
        Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
        ```

        Tune with custom search space.
        ```python
        from ultralytics import YOLO

        model = YOLO('yolov8n.pt')
        model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
        ```
    """
    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
        """
        Initialize the Tuner with configurations.

        Args:
            args (dict, optional): Configuration for hyperparameter evolution.
        """
        # Pop the 'space' key from args; fall back to the default search space if absent
        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
            # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
            "lr0": (1e-5, 1e-1),
            # final OneCycleLR learning rate (lr0 * lrf)
            "lrf": (0.0001, 0.1),
            # SGD momentum/Adam beta1
            "momentum": (0.7, 0.98, 0.3),
            # optimizer weight decay
            "weight_decay": (0.0, 0.001),
            # warmup epochs (fractions ok)
            "warmup_epochs": (0.0, 5.0),
            # warmup initial momentum
            "warmup_momentum": (0.0, 0.95),
            # box loss gain
            "box": (1.0, 20.0),
            # cls loss gain (scales with pixels)
            "cls": (0.2, 4.0),
            # dfl loss gain
            "dfl": (0.4, 6.0),
            # image HSV-Hue augmentation (fraction)
            "hsv_h": (0.0, 0.1),
            # image HSV-Saturation augmentation (fraction)
            "hsv_s": (0.0, 0.9),
            # image HSV-Value augmentation (fraction)
            "hsv_v": (0.0, 0.9),
            # image rotation (+/- deg)
            "degrees": (0.0, 45.0),
            # image translation (+/- fraction)
            "translate": (0.0, 0.9),
            # image scale (+/- gain)
            "scale": (0.0, 0.95),
            # image shear (+/- deg)
            "shear": (0.0, 10.0),
            # image perspective (+/- fraction), range 0-0.001
            "perspective": (0.0, 0.001),
            # image flip up-down (probability)
            "flipud": (0.0, 1.0),
            # image flip left-right (probability)
            "fliplr": (0.0, 1.0),
            # image channel BGR swap (probability)
            "bgr": (0.0, 1.0),
            # image mosaic (probability)
            "mosaic": (0.0, 1.0),
            # image mixup (probability)
            "mixup": (0.0, 1.0),
            # segment copy-paste (probability)
            "copy_paste": (0.0, 1.0),
        }
        # Build the run configuration from the remaining arguments
        self.args = get_cfg(overrides=args)
        # Directory where tuning results are saved
        self.tune_dir = get_save_dir(self.args, name="tune")
        # CSV file where per-iteration results are logged
        self.tune_csv = self.tune_dir / "tune_results.csv"
        # Use the provided callbacks or fall back to the defaults
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        # Colored log prefix
        self.prefix = colorstr("Tuner: ")
        # Register integration callbacks
        callbacks.add_integration_callbacks(self)
        # Log initialization info
        LOGGER.info(
            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
        )

    # Mutate hyperparameters according to the bounds and scaling factors in self.space
    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
        """
        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.

        Returns:
            (dict): A dictionary containing mutated hyperparameters.
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s)
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
            if parent == "single" or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == "weighted":
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate
            r = np.random  # method
            r.seed(int(time.time()))
            g = np.array([v[2] if len(v) == 3 else 1.0 for k, v in self.space.items()])  # gains 0-1
            ng = len(self.space)
            v = np.ones(ng)
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
        else:
            # No tuning CSV yet: initialize hyperparameters from the values in self.args
            hyp = {k: getattr(self.args, k) for k in self.space.keys()}

        # Constrain hyperparameters to the limits defined in self.space
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp
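

# --- Annotator's note: a minimal standalone sketch of the mutation rule used in
# _mutate() above, applied to a hypothetical two-hyperparameter space. The parent
# vector is made up for illustration; in Tuner._mutate it comes from tune_results.csv.
import numpy as np

space = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98, 0.3)}  # key: (min, max, gain(optional))
parent = np.array([0.5, 0.01, 0.9])  # [fitness, lr0, momentum], as one CSV row

g = np.array([v[2] if len(v) == 3 else 1.0 for v in space.values()])  # per-key gains
ng = len(space)
rng = np.random.default_rng(0)
v = np.ones(ng)
while all(v == 1):  # keep sampling until at least one factor changes
    v = (g * (rng.random(ng) < 0.8) * rng.standard_normal(ng) * rng.random() * 0.2 + 1).clip(0.3, 3.0)

hyp = {k: float(parent[i + 1] * v[i]) for i, k in enumerate(space)}
for k, bounds in space.items():
    hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)  # constrain to limits
print(hyp)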

.\yolov8\ultralytics\engine\validator.py

# Standard library imports
import json  # JSON serialization of validation results
import time  # timing utilities
from pathlib import Path  # filesystem path handling

import numpy as np  # numerical operations
import torch  # PyTorch deep learning framework

# Ultralytics modules and helper functions
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode


# Base validator class
class BaseValidator:
    """
    BaseValidator.

    A base class for creating validators.

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names.
        seen: Records the number of images seen so far during validation.
        stats: Placeholder for statistics during validation.
        confusion_matrix: Placeholder for a confusion matrix.
        nc: Number of classes.
        iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
        jdict (dict): Dictionary to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
                      batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
    """
    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initializes a BaseValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (dict): Dictionary to store various callback functions.
        """
        # Initialize the BaseValidator instance with the given arguments
        self.args = get_cfg(overrides=args)  # merge overrides into the default config
        self.dataloader = dataloader  # dataloader used for validation
        self.pbar = pbar  # progress bar object
        self.stride = None  # model stride
        self.data = None  # dataset dictionary
        self.device = None  # device used for validation
        self.batch_i = None  # current batch index
        self.training = True  # whether validation runs within training
        self.names = None  # class names
        self.seen = None  # number of images seen so far
        self.stats = None  # accumulated statistics
        self.confusion_matrix = None  # confusion matrix
        self.nc = None  # number of classes
        self.iouv = None  # IoU thresholds
        self.jdict = None  # JSON results dictionary
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}  # per-stage times (ms)

        self.save_dir = save_dir or get_save_dir(self.args)  # results directory, defaulted if not provided
        # Create a 'labels' subdirectory when saving txt labels, otherwise just the save directory itself
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default confidence threshold for validation
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)  # check and correct the image size

        self.plots = {}  # plots registry
        self.callbacks = _callbacks or callbacks.get_default_callbacks()  # callbacks, defaulted if not provided

    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
        """
        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
            true_classes (torch.Tensor): Target class indices of shape(M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
            use_scipy (bool): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
        """
        # Boolean array of shape (num_predictions, num_iou_thresholds) holding the match results
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)

        # Boolean array of shape (num_targets, num_predictions): True where the classes agree
        correct_class = true_classes[:, None] == pred_classes

        # Zero out IoU values for class-mismatched pairs so they can never be matched
        iou = iou * correct_class
        iou = iou.cpu().numpy()  # move to CPU and convert to a NumPy array
        
        # Iterate over each IoU threshold
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # Use scipy's optimal assignment for matching
                import scipy  # scoped import to avoid importing scipy for every command

                # Cost matrix keeping only IoU values at or above the current threshold
                cost_matrix = iou * (iou >= threshold)

                # Solve the assignment that maximizes the total IoU
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                # Greedy matching: take all class-matched pairs whose IoU reaches the threshold
                matches = np.nonzero(iou >= threshold)
                matches = np.array(matches).T
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True

        # Return a boolean tensor marking which predictions matched at each threshold
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Appends the given callback to the list associated with the event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    # Dataset construction must be implemented by the concrete validator
    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    # Preprocess an input batch; the base implementation returns it unchanged
    def preprocess(self, batch):
        """Preprocesses an input batch."""
        return batch

    # Post-process predictions; the base implementation returns them unchanged
    def postprocess(self, preds):
        """Post-processes the predictions."""
        return preds

    # No-op in the base class; subclasses initialize their performance metrics here
    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    # No-op in the base class; subclasses update metrics from predictions and batches here
    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    # No-op in the base class; subclasses finalize and return all metrics here
    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    # Returns statistics about model performance; the base class returns an empty dict
    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    # No-op in the base class; subclasses may validate their statistics here
    def check_stats(self, stats):
        """Checks statistics."""
        pass

    # No-op in the base class; subclasses print their prediction results here
    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    # No-op in the base class; subclasses return a description of the results table here
    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    # Property listing the metric keys used in YOLO training/validation; empty in the base class
    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    # Register plot data (e.g. to be consumed in callbacks)
    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)"""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    # No-op in the base class; subclasses plot validation samples during training
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    # No-op in the base class; subclasses plot model predictions on batch images
    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    # No-op in the base class; subclasses convert predictions to JSON format
    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    # No-op in the base class; subclasses evaluate prediction statistics in JSON format
    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass
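

# --- Annotator's note: a self-contained toy walk-through of the greedy
# (non-scipy) branch of match_predictions() above. All values are hypothetical:
# two predictions vs. two ground-truth boxes at a single IoU threshold.
import numpy as np
import torch

iou = torch.tensor([[0.80, 0.05],  # iou[m, n]: ground truth m vs. prediction n
                    [0.10, 0.60]])
pred_classes = torch.tensor([0, 1])
true_classes = torch.tensor([0, 1])

correct_class = true_classes[:, None] == pred_classes  # (M, N) same-class mask
iou = (iou * correct_class).numpy()  # zero out cross-class pairs

threshold = 0.5
matches = np.array(np.nonzero(iou >= threshold)).T  # rows of (gt_idx, pred_idx)
if matches.shape[0] > 1:
    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]  # best IoU first
    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # one ground truth per prediction
    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # one prediction per ground truth
print(matches)  # [[0 0], [1 1]] -> both predictions count as correct at IoU 0.5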

.\yolov8\ultralytics\engine\__init__.py

# The file contains only the license header
# Ultralytics YOLO 🚀, AGPL-3.0 license

.\yolov8\ultralytics\hub\auth.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# requests is used to send HTTP requests
import requests

# Constants and helpers from ultralytics.hub.utils
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
# Shared utilities from ultralytics.utils
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis

# URL of the HUB settings page where API keys are managed
API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"

# Auth manages the authentication flow: API key handling, cookie-based auth, and header generation
class Auth:
    """
    Manages authentication processes including API key handling, cookie-based authentication, and header generation.

    The class supports different methods of authentication:
    1. Directly using an API key.
    2. Authenticating using browser cookies (specifically in Google Colab).
    3. Prompting the user to enter an API key.

    Attributes:
        id_token (str or bool): Token used for identity verification, initialized as False.
        api_key (str or bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.
    """

    # Class attributes: identity token, API key, and model key, all initialized to False
    id_token = api_key = model_key = False

    def __init__(self, api_key="", verbose=False):
        """
        Initialize the Auth class with an optional API key.

        Args:
            api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
        """
        # If the key is a combined key_model ID, keep only the API key part before the underscore
        api_key = api_key.split("_")[0]

        # Use the provided API key, or fall back to the one stored in SETTINGS
        self.api_key = api_key or SETTINGS.get("api_key", "")

        # If an API key is available
        if self.api_key:
            # If it matches the key already stored in SETTINGS
            if self.api_key == SETTINGS.get("api_key"):
                # The user is already logged in
                if verbose:
                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided API key
                success = self.authenticate()
        # If no API key was given and we are running in a Google Colab notebook
        elif IS_COLAB:
            # Attempt to authenticate using browser cookies
            success = self.auth_with_cookies()
        else:
            # Prompt the user to enter an API key
            success = self.request_api_key()

        # On successful authentication, persist the API key in SETTINGS
        if success:
            SETTINGS.update({"api_key": self.api_key})
            # Log the new successful login
            if verbose:
                LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
            # On failure, point the user to API_KEY_URL to obtain a key
            LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo hub login API_KEY'")

    # Prompt the user for an API key, trying up to max_attempts times
    def request_api_key(self, max_attempts=3):
        """
        Prompt the user to input their API key.

        Returns the model ID.
        """
        import getpass  # getpass hides the API key as it is typed

        # Try up to max_attempts times to obtain a valid API key
        for attempts in range(max_attempts):
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")  # prompt for the API key
            self.api_key = input_key.split("_")[0]  # strip any model ID after the underscore
            if self.authenticate():  # verify the key against the server
                return True
        # Raise a ConnectionError after exhausting all attempts
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))


    # Verify the API key or id_token against the server
    def authenticate(self) -> bool:
        """
        Attempt to authenticate with the server using either id_token or API key.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        try:
            if header := self.get_auth_header():  # build the authentication header
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)  # send the auth request
                if not r.json().get("success", False):  # check whether authentication succeeded
                    raise ConnectionError("Unable to authenticate.")
                return True
            raise ConnectionError("User has not authenticated locally.")  # no local credentials available
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid id_token and api_key
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
            return False


    # Attempt cookie-based authentication and set id_token
    def auth_with_cookies(self) -> bool:
        """
        Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
        supported browser.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        if not IS_COLAB:
            return False  # currently only supported in Colab
        try:
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")  # request automatic auth using browser cookies
            if authn.get("success", False):  # check whether authentication succeeded
                self.id_token = authn.get("data", {}).get("idToken", None)  # store the id_token
                self.authenticate()  # validate the new credentials
                return True
            raise ConnectionError("Unable to fetch browser authentication details.")  # browser auth details unavailable
        except ConnectionError:
            self.id_token = False  # reset the invalid id_token
            return False


    # Build the authentication header for API requests
    def get_auth_header(self):
        """
        Get the authentication header for making API requests.

        Returns:
            (dict): The authentication header if id_token or API key is set, None otherwise.
        """
        if self.id_token:
            return {"authorization": f"Bearer {self.id_token}"}  # bearer-token header
        elif self.api_key:
            return {"x-api-key": self.api_key}  # API-key header
        # Implicitly returns None if neither credential is set
.\yolov8\ultralytics\hub\google\__init__.py

# Imports for concurrency, statistics, timing, typing, and HTTP
import concurrent.futures  # concurrent task execution
import statistics  # mean, median, standard deviation, etc.
import time  # sleeping and timing
from typing import List, Optional, Tuple  # type hints

import requests  # HTTP requests


class GCPRegions:
    """
    A class for managing and analyzing Google Cloud Platform (GCP) regions.

    This class provides functionality to initialize, categorize, and analyze GCP regions based on their
    geographical location, tier classification, and network latency.

    Attributes:
        regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.

    Methods:
        tier1: Returns a list of tier 1 GCP regions.
        tier2: Returns a list of tier 2 GCP regions.
        lowest_latency: Determines the GCP region(s) with the lowest network latency.

    Examples:
        >>> from ultralytics.hub.google import GCPRegions
        >>> regions = GCPRegions()
        >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)
        >>> print(f"Lowest latency region: {lowest_latency_region[0][0]}")
    """
    def __init__(self):
        """Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details."""
        # Dictionary mapping each GCP region to (tier, city, country)
        self.regions = {
            "asia-east1": (1, "Taiwan", "China"),
            "asia-east2": (2, "Hong Kong", "China"),
            "asia-northeast1": (1, "Tokyo", "Japan"),
            "asia-northeast2": (1, "Osaka", "Japan"),
            "asia-northeast3": (2, "Seoul", "South Korea"),
            "asia-south1": (2, "Mumbai", "India"),
            "asia-south2": (2, "Delhi", "India"),
            "asia-southeast1": (2, "Jurong West", "Singapore"),
            "asia-southeast2": (2, "Jakarta", "Indonesia"),
            "australia-southeast1": (2, "Sydney", "Australia"),
            "australia-southeast2": (2, "Melbourne", "Australia"),
            "europe-central2": (2, "Warsaw", "Poland"),
            "europe-north1": (1, "Hamina", "Finland"),
            "europe-southwest1": (1, "Madrid", "Spain"),
            "europe-west1": (1, "St. Ghislain", "Belgium"),
            "europe-west10": (2, "Berlin", "Germany"),
            "europe-west12": (2, "Turin", "Italy"),
            "europe-west2": (2, "London", "United Kingdom"),
            "europe-west3": (2, "Frankfurt", "Germany"),
            "europe-west4": (1, "Eemshaven", "Netherlands"),
            "europe-west6": (2, "Zurich", "Switzerland"),
            "europe-west8": (1, "Milan", "Italy"),
            "europe-west9": (1, "Paris", "France"),
            "me-central1": (2, "Doha", "Qatar"),
            "me-west1": (1, "Tel Aviv", "Israel"),
            "northamerica-northeast1": (2, "Montreal", "Canada"),
            "northamerica-northeast2": (2, "Toronto", "Canada"),
            "southamerica-east1": (2, "São Paulo", "Brazil"),
            "southamerica-west1": (2, "Santiago", "Chile"),
            "us-central1": (1, "Iowa", "United States"),
            "us-east1": (1, "South Carolina", "United States"),
            "us-east4": (1, "Northern Virginia", "United States"),
            "us-east5": (1, "Columbus", "United States"),
            "us-south1": (1, "Dallas", "United States"),
            "us-west1": (1, "Oregon", "United States"),
            "us-west2": (2, "Los Angeles", "United States"),
            "us-west3": (2, "Salt Lake City", "United States"),
            "us-west4": (2, "Las Vegas", "United States"),
        }

    def tier1(self) -> List[str]:
        """Returns a list of GCP regions classified as tier 1 based on predefined criteria."""
        # Keep regions whose tier value is 1
        return [region for region, info in self.regions.items() if info[0] == 1]

    def tier2(self) -> List[str]:
        """Returns a list of GCP regions classified as tier 2 based on predefined criteria."""
        # Keep regions whose tier value is 2
        return [region for region, info in self.regions.items() if info[0] == 2]

    @staticmethod
    def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]:
        """Pings a specified GCP region and returns latency statistics: mean, min, max, and standard deviation."""
        # Artifact Registry endpoint for the given GCP region
        url = f"https://{region}-docker.pkg.dev"
        # Latencies of the successful requests
        latencies = []
        # Ping the endpoint the requested number of times
        for _ in range(attempts):
            try:
                # Record the request start time
                start_time = time.time()
                # Send a HEAD request with a 5-second timeout
                _ = requests.head(url, timeout=5)
                # Compute the elapsed latency in milliseconds
                latency = (time.time() - start_time) * 1000  # convert latency to milliseconds
                # Keep finite latencies only
                if latency != float("inf"):
                    latencies.append(latency)
            except requests.RequestException:
                pass
        # If no request succeeded, report infinite latency statistics
        if not latencies:
            return region, float("inf"), float("inf"), float("inf"), float("inf")

        # Standard deviation needs more than one sample
        std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0
        # Return the region with its latency statistics: mean, std dev, min, max
        return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)

    def lowest_latency(
        self,
        top: int = 1,
        verbose: bool = False,
        tier: Optional[int] = None,
        attempts: int = 1,
    ) -> List[Tuple[str, float, float, float, float]]:
        """
        Determines the GCP regions with the lowest latency based on ping tests.

        Args:
            top (int): Number of top regions to return.
            verbose (bool): If True, prints detailed latency information for all tested regions.
            tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
            attempts (int): Number of ping attempts per region.

        Returns:
            (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
            latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).

        Examples:
            >>> regions = GCPRegions()
            >>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)
            >>> print(results[0][0])  # Print the name of the lowest latency region
        """
        # If verbose, announce the latency test that is about to run
        if verbose:
            print(f"Testing GCP regions for latency (with {attempts} {'attempt' if attempts == 1 else 'attempts'})...")

        # Filter the regions to test by tier, if one was given
        regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys())

        # Ping all candidate regions concurrently
        with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
            results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))

        # Sort the results by mean latency
        sorted_results = sorted(results, key=lambda x: x[1])

        # If verbose, print a detailed latency table
        if verbose:
            print(f"{'Region':<25} {'Location':<35} {'Tier':<5} {'Latency (ms)'}")
            for region, mean, std, min_, max_ in sorted_results:
                tier, city, country = self.regions[region]
                location = f"{city}, {country}"
                if mean == float("inf"):
                    print(f"{region:<25} {location:<35} {tier:<5} {'Timeout'}")
                else:
                    print(f"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})")
            print(f"\nLowest latency region{'s' if top > 1 else ''}:")
            for region, mean, std, min_, max_ in sorted_results[:top]:
                tier, city, country = self.regions[region]
                location = f"{city}, {country}"
                print(f"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))")

        # Return the top regions with the lowest latency
        return sorted_results[:top]
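

# --- Annotator's note: a condensed, standalone sketch of the _ping_region()
# pattern above, using only `requests` plus the standard library modules already
# imported in this file. The URL in the commented call is illustrative.
def mean_latency_ms(url: str, attempts: int = 3) -> float:
    """Return the mean HEAD-request latency to `url` in milliseconds (inf if all attempts fail)."""
    samples = []
    for _ in range(attempts):
        try:
            t0 = time.time()
            requests.head(url, timeout=5)
            samples.append((time.time() - t0) * 1000)  # seconds -> milliseconds
        except requests.RequestException:
            pass
    return statistics.mean(samples) if samples else float("inf")

# print(mean_latency_ms("https://us-central1-docker.pkg.dev"))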


# Demo when this module is executed directly (not imported)
if __name__ == "__main__":
    # Create a GCPRegions instance
    regions = GCPRegions()
    # Query the three lowest-latency tier-1 regions, with verbose output
    # and three ping attempts per region
    top_3_latency_tier1 = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=3)

.\yolov8\ultralytics\hub\session.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import threading  # background threads for non-blocking requests
import time  # timing and sleep
from http import HTTPStatus  # HTTP status code constants
from pathlib import Path  # filesystem path handling

import requests  # HTTP requests

from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM  # Ultralytics HUB helpers
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis  # shared utilities and constants
from ultralytics.utils.errors import HUBModelError  # custom error class

AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local"  # agent name depends on whether we run in Colab


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    Attributes:
        model_id (str): Identifier for the YOLO model being trained.
        model_url (str): URL for the model in Ultralytics HUB.
        rate_limits (dict): Rate limits for different API calls (in seconds).
        timers (dict): Timers for rate limiting.
        metrics_queue (dict): Queue for the model's metrics.
        model (dict): Model data fetched from Ultralytics HUB.
    """

    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session.
                It can be a URL string or a model key with specific format.

        Raises:
            ValueError: If the provided model identifier is invalid.
            ConnectionError: If connecting with global API key is not supported.
            ModuleNotFoundError: If hub-sdk package is not installed.
        """
        from hub_sdk import HUBClient  # HUBClient handles the API interaction with Ultralytics HUB

        self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # rate limits per API call (seconds)
        self.metrics_queue = {}  # holds each epoch's metrics until they are uploaded
        self.metrics_upload_failed_queue = {}  # holds per-epoch metrics whose upload failed
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py
        self.model = None  # model data
        self.model_url = None  # model URL
        self.model_file = None  # model file

        # Parse the input identifier
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None  # set credentials

        # Initialize the client
        self.client = HUBClient(credentials)

        # Load the model if authenticated
        if self.client.authenticated:
            if model_id:
                self.load_model(model_id)  # load an existing model
            else:
                self.model = self.client.model()  # load an empty model

    @classmethod
    def create_session(cls, identifier, args=None):
        """Class method to create an authenticated HUBTrainingSession or return None."""
        try:
            # Attempt to create a session for the given identifier
            session = cls(identifier)
            # Check whether the client authenticated successfully
            if not session.client.authenticated:
                # If unauthenticated and the identifier is a HUB model URL, warn and exit
                if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
                    LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
                    exit()
                return None
            # If arguments were provided and the identifier is not a HUB model URL, create a model
            if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"):  # not a HUB model URL
                session.create_model(args)
                # Assert that the model loaded correctly
                assert session.model.id, "HUB model not loaded correctly"
            # Return the created session
            return session
        # PermissionError and ModuleNotFoundError indicate that hub-sdk is not installed
        except (PermissionError, ModuleNotFoundError, AssertionError):
            return None

    def load_model(self, model_id):
        """Loads an existing model from Ultralytics HUB using the provided model identifier."""
        # Fetch the existing model by its identifier
        self.model = self.client.model(model_id)
        # Raise a ValueError if the model data does not exist
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

        # Set the model URL
        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
        # If the model has already been trained
        if self.model.is_trained():
            # Announce that a trained HUB model is being loaded
            print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
            # Fetch the URL of the best weights
            self.model_file = self.model.get_weights_url("best")
            return

        # Set training arguments and start heartbeats so HUB can monitor the agent
        self._set_train_args()
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        # Log the model URL
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    def create_model(self, model_args):
        """Initializes a HUB training session with the specified model identifier."""
        # Build the payload containing the training configuration
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),  # batch size, default -1 (auto)
                "epochs": model_args.get("epochs", 300),  # number of training epochs, default 300
                "imageSize": model_args.get("imgsz", 640),  # image size, default 640
                "patience": model_args.get("patience", 100),  # early-stopping patience, default 100
                "device": str(model_args.get("device", "")),  # device; None becomes a string
                "cache": str(model_args.get("cache", "ram")),  # cache; True/False/None become strings
            },
            "dataset": {"name": model_args.get("data")},  # dataset name
            "lineage": {
                "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},  # architecture name
                "parent": {},  # parent model info
            },
            "meta": {"name": self.filename},  # model metadata name
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename  # a .pt file names the parent model

        self.model.create_model(payload)  # create the model on HUB using the payload

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return None

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"  # build the model URL

        # Start heartbeats so HUB can monitor the agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")  # log the model URL

    @staticmethod
    def _parse_identifier(identifier):
        """
        Parses the given identifier to determine the type of identifier and extract relevant components.
        
        The method supports different identifier formats:
            - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
            - An identifier containing an API key and a model ID separated by an underscore
            - An identifier that is solely a model ID of a fixed length
            - A local filename that ends with '.pt' or '.yaml'
        
        Args:
            identifier (str): The identifier string to be parsed.
        
        Returns:
            (tuple): A tuple containing the API key, model ID, and filename as applicable.
        
        Raises:
            HUBModelError: If the identifier format is not recognized.
        """

        # Initialize variables to None
        api_key, model_id, filename = None, None, None

        # Check if identifier is a HUB URL
        if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
            # Extract the model_id after the HUB_WEB_ROOT URL
            model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
        else:
            # Split the identifier based on underscores only if it's not a HUB URL
            parts = identifier.split("_")

            # Check if identifier is in the format of API key and model ID
            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
                api_key, model_id = parts
            # Check if identifier is a single model ID
            elif len(parts) == 1 and len(parts[0]) == 20:
                model_id = parts[0]
            # Check if identifier is a local filename
            elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
                filename = identifier
            else:
                # Raise an error if identifier format does not match any supported format
                raise HUBModelError(
                    f"model='{identifier}' could not be parsed. Check format is correct. "
                    f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
                )

        # Return the extracted components as a tuple
        return api_key, model_id, filename
    def _set_train_args(self):
        """
        Initializes training arguments and creates a model entry on the Ultralytics HUB.

        This method sets up training arguments based on the model's state and updates them with any additional
        arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
        or requires specific file setup.

        Raises:
            ValueError: If the model is already trained, if required dataset information is missing, or if there are
                issues with the provided training arguments.
        """

        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
            self.model_file = self.model.get_weights_url("last")
        else:
            # Model has no saved weights
            self.train_args = self.model.data.get("train_args")  # fetch the training arguments from the model data
            # Set the model file as either a *.pt or a *.yaml file
            self.model_file = (
                self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
            )

        if "data" not in self.train_args:
            # RF bug - datasets are sometimes not exported
            raise ValueError("Dataset may still be processing. Please wait a minute and try again.")

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # check and correct the filename
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        stream_response=None,
        *args,
        **kwargs,
    ):
        """
        Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress.
        """

        def retry_request():
            """
            Attempts to call `request_func` with retries, timeout, and optional threading.
            """
            t0 = time.time()  # Record the start time for the timeout
            response = None
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
                    time.sleep(2**i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                if progress_total:
                    self._show_upload_progress(progress_total, response)
                elif stream_response:
                    self._iterate_content(response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    # if request related to metrics upload
                    if kwargs.get("metrics"):
                        self.metrics_upload_failed_queue = {}
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")

                if not self._should_retry(response.status_code):
                    LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
                    break  # Not an error that should be retried, exit loop

                time.sleep(2**i)  # Exponential backoff for retries

            # if request related to metrics upload and exceed retries
            if response is None and kwargs.get("metrics"):
                self.metrics_upload_failed_queue.update(kwargs.get("metrics", None))

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    @staticmethod
    def _should_retry(status_code):
        """
        Determines if a request should be retried based on the HTTP status code.
        """
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
        """
        Generate a retry message based on the response status code.

        Args:
            response: The HTTP response object.
            retry: The number of retry attempts allowed.
            timeout: The maximum timeout duration.

        Returns:
            (str): The retry message.
        """
        # If the status code is retryable, return a message with the retry count and timeout
        if self._should_retry(response.status_code):
            return f"Retrying {retry}x for {timeout}s." if retry else ""
        # A 429 (Too Many Requests) response reports the rate-limit details instead
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (
                f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                f"Please retry after {headers['Retry-After']}s."
            )
        else:
            try:
                # Try to read a JSON message from the response body
                return response.json().get("message", "No JSON message.")
            except AttributeError:
                # Fall back when the body cannot be parsed as JSON
                return "Unable to read JSON."

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        # Queue the metrics upload on a background thread and return the request-queue result
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model.
            final (bool): Indicates if the model is the final model after training.
        """
        # If the given weights file exists
        if Path(weights).is_file():
            # Show upload progress only for the final model
            progress_total = Path(weights).stat().st_size if final else None  # Only show progress if final
            # Queue the checkpoint upload to Ultralytics HUB with retry and timeout handling
            self.request_queue(
                self.model.upload_model,
                epoch=epoch,
                weights=weights,
                is_best=is_best,
                map=map,
                final=final,
                retry=10,
                timeout=3600,
                thread=not final,
                progress_total=progress_total,
                stream_response=True,
            )
        else:
            # Warn when the weights file is missing
            LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")

    @staticmethod
    def _show_upload_progress(content_length: int, response: requests.Response) -> None:
        """
        Display a progress bar to track the progress of a streamed model upload.

        Args:
            content_length (int): The total size of the content in bytes.
            response (requests.Response): The streamed response object from the upload request.

        Returns:
            None
        """
        # Create a TQDM progress bar sized to content_length, counting bytes with automatic unit scaling
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            # Advance the bar as each chunk of the response stream is consumed
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))

    @staticmethod
    # Consume a streamed HTTP response without storing the data
    def _iterate_content(response: requests.Response) -> None:
        """
        Process the streamed HTTP response data.

        Args:
            response (requests.Response): The response object from the file download request.

        Returns:
            None
        """
        # Iterate over the response chunks, discarding the data
        for _ in response.iter_content(chunk_size=1024):
            pass  # Do nothing with data chunks
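

# --- Annotator's note: the retry loop inside request_queue() above, reduced to
# a standalone sketch. `flaky_call` is a hypothetical stand-in for request_func;
# sleeps follow the same 2**i exponential backoff (1s, 2s, 4s, ...).
def retry_with_backoff(flaky_call, retry=3, timeout=30):
    """Call flaky_call until it returns a non-None result, backing off exponentially."""
    t0 = time.time()
    for i in range(retry + 1):
        if (time.time() - t0) > timeout:
            break  # overall deadline reached, give up
        result = flaky_call()
        if result is not None:
            return result  # success
        time.sleep(2**i)  # back off before the next attempt
    return None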

.\yolov8\ultralytics\hub\utils.py

# Standard library imports
import os
import platform
import random
import threading
import time
from pathlib import Path

# Third-party HTTP library
import requests

# Shared utilities and constants from ultralytics.utils
from ultralytics.utils import (
    ARGV,
    ENVIRONMENT,
    IS_COLAB,
    IS_GIT_DIR,
    IS_PIP_PACKAGE,
    LOGGER,
    ONLINE,
    RANK,
    SETTINGS,
    TESTS_RUNNING,
    TQDM,
    TryExcept,
    __version__,
    colorstr,
    get_git_origin_url,
)
# GITHUB_ASSETS_NAMES from ultralytics.utils.downloads
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES

# HUB_API_ROOT and HUB_WEB_ROOT default to the public endpoints unless the ULTRALYTICS_HUB_API / ULTRALYTICS_HUB_WEB environment variables override them
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

# PREFIX is a colored log prefix built with colorstr
PREFIX = colorstr("Ultralytics HUB: ")
# Help message appended to HUB warnings
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."


def request_with_credentials(url: str) -> any:
    """
    Make an AJAX request with cookies attached in a Google Colab environment.

    Args:
        url (str): The URL to make the request to.

    Returns:
        (any): The response data of the AJAX request.

    Raises:
        OSError: If the function is not run in a Google Colab environment.
    """
    # Raise an OSError when not running inside Colab
    if not IS_COLAB:
        raise OSError("request_with_credentials() must run in a Colab environment")

    # Colab-specific modules needed for the browser request
    from google.colab import output  # noqa
    from IPython import display  # noqa

    # Use display.Javascript to issue an AJAX request that includes the browser cookies
    display.display(
        display.Javascript(
            """
            window._hub_tmp = new Promise((resolve, reject) => {
                const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
                fetch("%s", {
                    method: 'POST',
                    credentials: 'include'
                })
                    .then((response) => resolve(response.json()))
                    .then((json) => {
                    clearTimeout(timeout);
                    }).catch((err) => {
                    clearTimeout(timeout);
                    reject(err);
                });
            });
            """
            % url
        )
    )
    # Return the evaluated JavaScript promise result
    return output.eval_js("_hub_tmp")


def requests_with_progress(method, url, **kwargs):
    """
    Make an HTTP request using the specified method and URL, with an optional progress bar.

    Args:
        method (str): The HTTP method to use (e.g. 'GET', 'POST').
        url (str): The URL to send the request to.
        **kwargs (any): Additional keyword arguments passed to the underlying `requests.request` function.

    Returns:
        (requests.Response): The response object from the HTTP request.

    Note:
        - If 'progress' is set to True, the progress bar will display download progress for responses with a known
          content length.
        - If 'progress' is a number, the progress bar will display assuming the content length equals 'progress'.
    """
    # Pop the progress flag from kwargs, defaulting to False
    progress = kwargs.pop("progress", False)
    # Without progress tracking, issue the request directly
    if not progress:
        return requests.request(method, url, **kwargs)
    # Issue the request with streaming enabled so the body can be consumed in chunks
    response = requests.request(method, url, stream=True, **kwargs)
    # Total size: the content-length header if progress is a bool, otherwise the value of progress itself
    total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress)  # total size
    try:
        # Initialize a byte-scaled progress bar with the total size
        pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
        # Stream the response in chunks, updating the bar for each chunk
        for data in response.iter_content(chunk_size=1024):
            pbar.update(len(data))
        # Close the progress bar
        pbar.close()
    except requests.exceptions.ChunkedEncodingError:  # avoid 'Connection broken: IncompleteRead' warnings
        # Close the response to handle the exception
        response.close()
    # Return the full HTTP response object
    return response
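

# --- Annotator's note: hypothetical usage of requests_with_progress() above.
# progress=True sizes the bar from the Content-Length header, while a number
# sizes it explicitly; the URLs are illustrative placeholders.
# resp = requests_with_progress("GET", "https://example.com/weights.pt", progress=True)
# resp = requests_with_progress("GET", "https://example.com/stream", progress=10 * 1024 * 1024)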
    """
    Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.

    Args:
        method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
        url (str): The URL to make the request to.
        retry (int, optional): Number of retries to attempt before giving up. Default is 3.
        timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
        thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
        code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
        verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
        progress (bool, optional): Whether to show a progress bar during the request. Default is False.
        **kwargs (any): Keyword arguments to be passed to the requests function specified in method.

    Returns:
        (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
    """
    retry_codes = (408, 500)  # retry only these codes

    # Decorator to handle exceptions and log messages
    @TryExcept(verbose=verbose)
    def func(func_method, func_url, **func_kwargs):
        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
        r = None  # response object
        t0 = time.time()  # start time for timeout
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                break
            # Perform HTTP request with progress tracking if enabled
            r = requests_with_progress(func_method, func_url, **func_kwargs)
            # Check if response status code indicates success
            if r.status_code < 300:
                break
            try:
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = "Unable to read JSON."
            # Handle retry logic based on response status code
            if i == 0:
                if r.status_code in retry_codes:
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit exceeded
                    h = r.headers  # response headers
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
                # Return response if no need to retry
                if r.status_code not in retry_codes:
                    return r
            time.sleep(2**i)  # exponential backoff wait
        return r

    # Prepare arguments and pass progress flag to kwargs
    args = method, url
    kwargs["progress"] = progress
    # When thread is True, run func in a new daemon thread (fire and forget)
    if thread:
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    # Otherwise call func directly and return its result
    else:
        return func(*args, **kwargs)
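

# --- Annotator's note: hypothetical usage of smart_request() above. With
# thread=False the call blocks and returns the Response object; progress=True
# streams the body through a TQDM bar. The URL is an illustrative placeholder.
# r = smart_request("get", "https://example.com/some/file.zip", thread=False, progress=True)
# print(r.status_code)
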
class Events:
    """
    A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
    disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.

    Attributes:
        url (str): The URL to send anonymous events.
        rate_limit (float): The rate limit in seconds for sending events.
        metadata (dict): A dictionary containing metadata about the environment.
        enabled (bool): A flag to enable or disable Events based on certain conditions.
    """

    # URL for the Google Analytics Measurement Protocol endpoint collecting anonymous events
    url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self):
        """Initializes the Events object with default values for events, rate_limit, and metadata."""
        self.events = []  # events list
        self.rate_limit = 30.0  # rate limit (seconds)
        self.t = 0.0  # rate limit timer (seconds)
        # Environment metadata attached to every event
        self.metadata = {
            "cli": Path(ARGV[0]).name == "yolo",  # whether the CLI entry point is 'yolo'
            "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",  # install method
            "python": ".".join(platform.python_version_tuple()[:2]),  # Python version, i.e. 3.10
            "version": __version__,  # package version
            "env": ENVIRONMENT,  # runtime environment name
            "session_id": round(random.random() * 1e15),  # random session ID
            "engagement_time_msec": 1000,  # engagement time in milliseconds
        }
        # Enable event collection only if all of the following conditions hold
        self.enabled = (
            SETTINGS["sync"]  # sync enabled in settings
            and RANK in {-1, 0}  # main process only (rank -1 or 0)
            and not TESTS_RUNNING  # not while tests are running
            and ONLINE  # internet connectivity available
            and (IS_PIP_PACKAGE or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")  # official pip package or git clone of the ultralytics repo
        )
    # The __call__() special method lets the instance be invoked like a function, i.e. events(cfg)
    def __call__(self, cfg):
        """
        Attempts to add a new event to the events list and send events if the rate limit is reached.

        Args:
            cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
        """
        # If events are disabled, do nothing
        if not self.enabled:
            return

        # Attempt to add the event to the events list
        if len(self.events) < 25:  # the list is capped at 25 events; events beyond that are discarded
            # Build the event params from the metadata plus the task and model from cfg
            params = {
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
            }
            # For export mode, also record the export format
            if cfg.mode == "export":
                params["format"] = cfg.format
            # Append the new event as a dict to the events list
            self.events.append({"name": cfg.mode, "params": params})

        # Check the send rate limit
        t = time.time()
        if (t - self.t) < self.rate_limit:
            # Still within the rate-limit window, wait to send
            return

        # Over the rate limit, send the queued events now
        data = {"client_id": SETTINGS["uuid"], "events": self.events}  # SHA-256 anonymized UUID hash and events list

        # POST equivalent to requests.post(self.url, json=data), with no retries and no verbose output
        smart_request("post", self.url, json=data, retry=0, verbose=False)

        # Reset the events list and the rate-limit timer
        self.events = []
        self.t = t
# Run the code below on hub/utils init
events = Events()
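
A rough sketch of how this `events` singleton gets driven elsewhere in the codebase; the `cfg` below is a hand-built stand-in for the `IterableSimpleNamespace` that the CLI actually passes:

```python
from types import SimpleNamespace

from ultralytics.hub.utils import events

# Stand-in config; real callers pass an IterableSimpleNamespace with these fields
cfg = SimpleNamespace(mode="train", task="detect", model="yolov8n.pt", format=None)

# Each call queues one event; queued events are flushed at most once per 30 s window
events(cfg)
```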

.\yolov8\ultralytics\hub\__init__.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import requests  # HTTP requests

from ultralytics.data.utils import HUBDatasetStats  # dataset statistics utility
from ultralytics.hub.auth import Auth  # authentication
from ultralytics.hub.session import HUBTrainingSession  # training session handling
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events  # constants and events
from ultralytics.utils import LOGGER, SETTINGS, checks  # logger, settings and checks

__all__ = (
    "PREFIX",
    "HUB_WEB_ROOT",
    "HUBTrainingSession",
    "login",
    "logout",
    "reset_model",
    "export_fmts_hub",
    "export_model",
    "get_export",
    "check_dataset",
    "events",
)


def login(api_key: str = None, save=True) -> bool:
    """
    Log in to the Ultralytics HUB API using the provided API key.

    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
    environment variable if successfully authenticated.

    Args:
        api_key (str, optional): API key to use for authentication.
            If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

    Returns:
        (bool): True if authentication is successful, False otherwise.
    """
    checks.check_requirements("hub-sdk>=0.0.8")  # verify the minimum required SDK version
    from hub_sdk import HUBClient  # client for the HUB API

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # URL of the API keys settings page
    saved_key = SETTINGS.get("api_key")  # API key saved in SETTINGS
    active_key = api_key or saved_key  # prefer the provided key, fall back to the saved one

    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials

    client = HUBClient(credentials)  # initialize the HUBClient

    if client.authenticated:
        # Successfully authenticated with HUB

        if save and client.api_key != saved_key:
            SETTINGS.update({"api_key": client.api_key})  # update SETTINGS with the valid API key

        # Choose the log message depending on whether a new key was provided or a saved one was reused
        log_message = (
            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
        )
        LOGGER.info(f"{PREFIX}{log_message}")  # log the successful authentication

        return True
    else:
        # Failed to authenticate with HUB
        LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo hub login API_KEY'")
        return False
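
A short usage sketch (the key string is a placeholder):

```python
from ultralytics import hub

# Authenticate with an explicit key (placeholder) and persist it to SETTINGS
hub.login(api_key="your_api_key", save=True)

# Or rely on a previously saved key / the HUB_API_KEY environment variable
authenticated = hub.login()
```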


def logout():
    """
    Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.

    Example:
        ```py
        from ultralytics import hub

        hub.logout()
        ```
    """
    SETTINGS["api_key"] = ""  # 清空SETTINGS中的API密钥
    SETTINGS.save()  # 保存SETTINGS变更
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")  # 记录退出登录信息到日志


def reset_model(model_id=""):
    """Reset a trained model to an untrained state."""
    # POST to the HUB API to reset the model with the given model_id to an untrained state
    r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
    if r.status_code == 200:
        # Reset succeeded
        LOGGER.info(f"{PREFIX}Model reset successfully")
        return

    # Reset failed; log a warning with the status code and reason
    LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")
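
Usage is a single call; the model ID below is a placeholder, and authentication comes from the API key saved via `Auth()`:

```python
from ultralytics.hub import reset_model

reset_model(model_id="your_model_id")  # logs success on HTTP 200, a warning otherwise
```
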
def export_fmts_hub():
    """Returns a list of HUB-supported export formats."""
    # Import export_formats from the ultralytics.engine.exporter module
    from ultralytics.engine.exporter import export_formats

    # Return all export format arguments except the first one, plus two HUB-specific formats
    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]


def export_model(model_id="", format="torchscript"):
    """Export a model to all formats."""
    # Assert that the requested format is supported; otherwise raise an AssertionError
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    # POST the export request for the given model and format, authenticating with the API key
    r = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    # Assert a 200 status code; otherwise raise an AssertionError with the failure details
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    # Log that the export has started
    LOGGER.info(f"{PREFIX}{format} export started ✅")


def get_export(model_id="", format="torchscript"):
    """Get an exported model dictionary with download URL."""
    # Assert that the requested format is supported; otherwise raise an AssertionError
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    # POST to fetch the exported model dictionary with its download URL, authenticating with the API key
    r = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    # Assert a 200 status code; otherwise raise an AssertionError with the failure details
    assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
    # Return the exported model dictionary parsed from the JSON response
    return r.json()
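
Putting the three helpers together, a hedged end-to-end sketch (the model ID is a placeholder, and the export runs asynchronously server-side):

```python
from ultralytics.hub import export_fmts_hub, export_model, get_export

print(export_fmts_hub())  # supported formats, ending in 'ultralytics_tflite', 'ultralytics_coreml'

model_id = "your_model_id"  # placeholder
export_model(model_id=model_id, format="onnx")  # starts the export

# Later, once the export has completed, fetch the dictionary with the download URL
export = get_export(model_id=model_id, format="onnx")
```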


def check_dataset(path: str, task: str) -> None:
    """
    Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
    to the HUB. Usage examples are given below.

    Args:
        path (str): Path to data.zip (with data.yaml inside data.zip).
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.

    Example:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
            i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
        ```py
        from ultralytics.hub import check_dataset

        check_dataset('path/to/coco8.zip', task='detect')  # detect dataset
        check_dataset('path/to/coco8-seg.zip', task='segment')  # segment dataset
        check_dataset('path/to/coco8-pose.zip', task='pose')  # pose dataset
        check_dataset('path/to/dota8.zip', task='obb')  # OBB dataset
        check_dataset('path/to/imagenet10.zip', task='classify')  # classification dataset
        ```
    """
    # Use HUBDatasetStats to validate the dataset zip at the given path and produce JSON statistics for the task
    HUBDatasetStats(path=path, task=task).get_json()
    # Log successful completion of the checks
    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")

.\yolov8\ultralytics\models\fastsam\model.py

# Import the required modules and classes
from pathlib import Path

from ultralytics.engine.model import Model

from .predict import FastSAMPredictor
from .val import FastSAMValidator


# FastSAM class, inheriting from Model
class FastSAM(Model):
    """
    FastSAM model interface.

    Example:
        ```python
        from ultralytics import FastSAM

        model = FastSAM('last.pt')
        results = model.predict('ultralytics/assets/bus.jpg')
        ```
    """

    def __init__(self, model="FastSAM-x.pt"):
        """Calls the __init__ method of the parent class (YOLO) with the updated default model name."""
        # Map the generic name "FastSAM.pt" to the concrete checkpoint "FastSAM-x.pt"
        if str(model) == "FastSAM.pt":
            model = "FastSAM-x.pt"
        # FastSAM only supports pre-trained weights, so YAML/YML configs are rejected
        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
        # Initialize the parent class with the model name and the "segment" task
        super().__init__(model=model, task="segment")

    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
        """
        Perform segmentation prediction on the given image or video source.

        Args:
            source (str): Path to an image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): Bounding box coordinates used as segmentation prompts. Defaults to None.
            points (list, optional): Points used as segmentation prompts. Defaults to None.
            labels (list, optional): Labels used as segmentation prompts. Defaults to None.
            texts (list, optional): Texts used as segmentation prompts. Defaults to None.

        Returns:
            (list): The model's prediction results.
        """
        # Gather the prompts into a dictionary
        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
        # Delegate to the parent predict method, forwarding the prompts
        return super().predict(source, stream, prompts=prompts, **kwargs)

    @property
    def task_map(self):
        """Returns a dictionary mapping the segment task to its corresponding predictor and validator classes."""
        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}

.\yolov8\ultralytics\models\fastsam\predict.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from PIL import Image

from ultralytics.models.yolo.segment import SegmentationPredictor  # 导入分割预测器类
from ultralytics.utils import DEFAULT_CFG, checks  # 导入默认配置和检查工具
from ultralytics.utils.metrics import box_iou  # 导入 IoU 计算工具
from ultralytics.utils.ops import scale_masks  # 导入 mask 缩放操作

from .utils import adjust_bboxes_to_image_border  # 导入边界框调整函数


class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
    YOLO framework.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
    class segmentation.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        # Call the parent constructor to initialize the FastSAMPredictor
        super().__init__(cfg, overrides, _callbacks)
        # Prompts start out empty and are filled via set_prompts()
        self.prompts = {}

    def postprocess(self, preds, img, orig_imgs):
        """Applies box postprocess for FastSAM predictions."""
        # Pop the bounding box, point, label and text prompts
        bboxes = self.prompts.pop("bboxes", None)
        points = self.prompts.pop("points", None)
        labels = self.prompts.pop("labels", None)
        texts = self.prompts.pop("texts", None)
        # Run the parent postprocessing on the raw predictions
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            # Build a box covering the entire image
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            # Snap boxes near the border onto the image border
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            # Find boxes whose IoU with the full-image box exceeds 0.9
            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
            # Replace any such boxes with the full-image box
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box

        # Forward the results together with the original prompts
        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
    def _clip_inference(self, images, texts):
        """
        CLIP Inference process.

        Args:
            images (List[PIL.Image]): A list of source images and each of them should be PIL.Image type with RGB channel order.
            texts (List[str]): A list of prompt texts and each of them should be string object.

        Returns:
            (torch.Tensor): The similarity between given images and texts.
        """
        try:
            import clip  # try to import the CLIP library
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")  # install it if missing
            import clip  # retry the import

        # Lazily load the CLIP model and preprocessor on first use
        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)

        # Preprocess the images, stack them into a tensor and move them to the device
        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])

        # Tokenize the prompt texts and move them to the device
        tokenized_text = clip.tokenize(texts).to(self.device)

        # Encode images and texts into feature vectors with CLIP
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(tokenized_text)

        image_features /= image_features.norm(dim=-1, keepdim=True)  # L2-normalize image features
        text_features /= text_features.norm(dim=-1, keepdim=True)  # L2-normalize text features

        # Similarity between every text and every image
        return (image_features * text_features[:, None]).sum(-1)

    def set_prompts(self, prompts):
        """Set prompts in advance."""
        # Store the prompts on the instance for use during postprocessing
        self.prompts = prompts
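
Because both feature sets are L2-normalized, the final multiply-and-sum is just cosine similarity. A standalone sketch of that computation on dummy feature tensors (no CLIP required):

```python
import torch

M, N, D = 4, 2, 512  # M images, N texts, D-dim embeddings (dummy sizes)
image_features = torch.randn(M, D)
text_features = torch.randn(N, D)

# L2-normalize, exactly as in _clip_inference
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# (N, 1, D) * (M, D) broadcasts to (N, M, D); summing the last dim gives an (N, M) similarity matrix
sim = (image_features * text_features[:, None]).sum(-1)
print(sim.shape)  # torch.Size([2, 4])
```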

.\yolov8\ultralytics\models\fastsam\utils.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# Adjust bounding boxes to stick to the image border when they lie within a threshold

def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """
    Adjust bounding boxes to stick to image border if they are within a certain threshold.

    Args:
        boxes (torch.Tensor): (n, 4) bounding box coordinates
        image_shape (tuple): (height, width) of the image
        threshold (int): pixel threshold

    Returns:
        adjusted_boxes (torch.Tensor): adjusted bounding boxes
    """

    # Image dimensions
    h, w = image_shape

    # Snap coordinates within `threshold` pixels of a border onto that border
    boxes[boxes[:, 0] < threshold, 0] = 0  # x1, top-left x
    boxes[boxes[:, 1] < threshold, 1] = 0  # y1, top-left y
    boxes[boxes[:, 2] > w - threshold, 2] = w  # x2, bottom-right x
    boxes[boxes[:, 3] > h - threshold, 3] = h  # y2, bottom-right y
    return boxes
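
A quick check of the snapping behavior on a toy tensor; the coordinates are chosen to sit just inside and well outside the default 20-pixel threshold:

```python
import torch

from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor(
    [
        [5.0, 30.0, 630.0, 470.0],  # x1, x2 and y2 within 20 px of a border; y1 is not
        [100.0, 100.0, 300.0, 300.0],  # fully interior, left untouched
    ]
)
print(adjust_bboxes_to_image_border(boxes, image_shape=(480, 640)))
# tensor([[  0.,  30., 640., 480.],
#         [100., 100., 300., 300.]])
```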

.\yolov8\ultralytics\models\fastsam\val.py

# Import the relevant modules and classes from the Ultralytics YOLO framework
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.utils.metrics import SegmentMetrics

# FastSAMValidator class, inheriting from SegmentationValidator
class FastSAMValidator(SegmentationValidator):
    """
    Custom validation class for fast SAM (Segment Anything Model) segmentation in the Ultralytics YOLO framework.

    Extends SegmentationValidator, customizing the validation process specifically for fast SAM. This class sets the
    task to 'segment' and uses SegmentMetrics for evaluation. In addition, plotting features are disabled to avoid
    errors during validation.

    Attributes:
        dataloader: The data loader object used for validation.
        save_dir (str): The directory where validation results are saved.
        pbar: A progress bar object for displaying progress.
        args: Additional arguments for customization.
        _callbacks: A list of callback functions to be invoked during validation.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (dict): Dictionary to store various callback functions.

        Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
        """
        # Initialize via the parent constructor
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        # Set the task to 'segment'
        self.args.task = "segment"
        # Disable ConfusionMatrix and other plots to avoid errors
        self.args.plots = False
        # Create SegmentMetrics with the save directory and the on_plot callback
        self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

.\yolov8\ultralytics\models\fastsam\__init__.py

# Import the public members: the FastSAM model, FastSAMPredictor and FastSAMValidator
from .model import FastSAM
from .predict import FastSAMPredictor
from .val import FastSAMValidator

# __all__ lists the symbols visible outside this module
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"

.\yolov8\ultralytics\models\nas\model.py

# Path class for handling filesystem paths
from pathlib import Path

import torch

# Base Model class from the Ultralytics engine
from ultralytics.engine.model import Model

# Download helper from Ultralytics utils
from ultralytics.utils.downloads import attempt_download_asset

# PyTorch-related utility functions
from ultralytics.utils.torch_utils import model_info

# NASPredictor and NASValidator from this package
from .predict import NASPredictor
from .val import NASValidator


class NAS(Model):
    """
    YOLO NAS model for object detection.

    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

    Example:
        ```python
        from ultralytics import NAS

        model = NAS('yolo_nas_s')
        results = model.predict('ultralytics/assets/bus.jpg')
        ```

    Attributes:
        model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.

    Note:
        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
    """

    def __init__(self, model="yolo_nas_s.pt") -> None:
        """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
        # YOLO-NAS only supports pre-trained weights, so YAML config files are rejected
        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
        # Initialize the parent Model class with the model path and the "detect" task
        super().__init__(model, task="detect")

    def _load(self, weights: str, task=None) -> None:
        """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
        # Import super_gradients dynamically to load the model weights
        import super_gradients

        # Get the weight file suffix
        suffix = Path(weights).suffix
        # For a ".pt" file, load the checkpoint directly
        if suffix == ".pt":
            self.model = torch.load(attempt_download_asset(weights))
        # For a bare model name, fetch COCO-pretrained weights by name
        elif suffix == "":
            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")

        # Override the model's forward method to ignore additional arguments
        def new_forward(x, *args, **kwargs):
            """Ignore additional __call__ arguments."""
            return self.model._original_forward(x)

        # Keep the original forward and install the new one
        self.model._original_forward = self.model.forward
        self.model.forward = new_forward

        # Standardize model attributes
        self.model.fuse = lambda verbose=True: self.model
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = "detect"  # for export()

    # Method for logging model information
    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.
        """
        # Delegate to model_info, passing the model and the display options
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """Returns a dictionary mapping tasks to respective predictor and validator classes."""
        # Map the detect task to its predictor and validator classes
        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}