"""
This module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection,
instance segmentation, image classification, pose estimation, and multi-object tracking.
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
Example:
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
```
"""

import random
import shutil
import subprocess
import time
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
from ultralytics.utils.plotting import plot_tune_results


class Tuner:
"""
Class responsible for hyperparameter tuning of YOLO models.
The class evolves YOLO model hyperparameters over a given number of iterations
by mutating them according to the search space and retraining the model to evaluate their performance.
Attributes:
space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
tune_dir (Path): Directory where evolution logs and results will be saved.
tune_csv (Path): Path to the CSV file where evolution logs are saved.
Methods:
_mutate(hyp: dict) -> dict:
Mutates the given hyperparameters within the bounds specified in `self.space`.
__call__():
Executes the hyperparameter evolution across multiple iterations.
Example:
Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
```
Tune with custom search space.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
```
"""def__init__(self, args=DEFAULT_CFG, _callbacks=None):
"""
Initialize the Tuner with configurations.
Args:
args (dict, optional): Configuration for hyperparameter evolution.
"""# 将参数中的'space'键弹出,如果不存在则使用默认空间字典
self.space = args.pop("space", None) or { # key: (min, max, gain(optional))# 初始学习率范围 (例如 SGD=1E-2, Adam=1E-3)"lr0": (1e-5, 1e-1),
# 最终的 OneCycleLR 学习率范围 (lr0 * lrf)"lrf": (0.0001, 0.1),
# SGD 动量/Adam beta1 范围"momentum": (0.7, 0.98, 0.3),
# 优化器权重衰减范围"weight_decay": (0.0, 0.001),
# 温升 epochs 范围 (可以是小数)"warmup_epochs": (0.0, 5.0),
# 温升初始动量范围"warmup_momentum": (0.0, 0.95),
# box 损失增益范围"box": (1.0, 20.0),
# cls 损失增益范围 (与像素缩放相关)"cls": (0.2, 4.0),
# dfl 损失增益范围"dfl": (0.4, 6.0),
# 图像 HSV-Hue 增强范围 (分数)"hsv_h": (0.0, 0.1),
# 图像 HSV-Saturation 增强范围 (分数)"hsv_s": (0.0, 0.9),
# 图像 HSV-Value 增强范围 (分数)"hsv_v": (0.0, 0.9),
# 图像旋转范围 (+/- 度数)"degrees": (0.0, 45.0),
# 图像平移范围 (+/- 分数)"translate": (0.0, 0.9),
# 图像缩放范围 (+/- 增益)"scale": (0.0, 0.95),
# 图像剪切范围 (+/- 度数)"shear": (0.0, 10.0),
# 图像透视范围 (+/- 分数),范围 0-0.001"perspective": (0.0, 0.001),
# 图像上下翻转概率"flipud": (0.0, 1.0),
# 图像左右翻转概率"fliplr": (0.0, 1.0),
# 图像通道 bgr 变换概率"bgr": (0.0, 1.0),
# 图像混合概率"mosaic": (0.0, 1.0),
# 图像 mixup 概率"mixup": (0.0, 1.0),
# 分割复制粘贴概率"copy_paste": (0.0, 1.0),
}
# 使用参数获取配置并初始化
self.args = get_cfg(overrides=args)
# 获取保存目录路径
self.tune_dir = get_save_dir(self.args, name="tune")
# 定义保存结果的 CSV 文件路径
self.tune_csv = self.tune_dir / "tune_results.csv"# 获取回调函数或者使用默认回调函数列表
self.callbacks = _callbacks or callbacks.get_default_callbacks()
# 设置前缀字符串
self.prefix = colorstr("Tuner: ")
# 添加整合回调函数
callbacks.add_integration_callbacks(self)
# 记录初始化信息
LOGGER.info(
f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
)
    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
        """
        Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

        Args:
            parent (str): Parent selection method: 'single' or 'weighted'.
            n (int): Number of parents to consider.
            mutation (float): Probability of a parameter mutation in any given iteration.
            sigma (float): Standard deviation for Gaussian random number generator.

        Returns:
            (dict): A dictionary containing mutated hyperparameters.
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s)
            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
            fitness = x[:, 0]  # first column
            n = min(n, len(x))  # number of previous results to consider
            x = x[np.argsort(-fitness)][:n]  # top n mutations
            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
            if parent == "single" or len(x) == 1:
                # x = x[random.randint(0, n - 1)]  # random selection
                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
            elif parent == "weighted":
                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

            # Mutate
            r = np.random  # method
            r.seed(int(time.time()))
            g = np.array([v[2] if len(v) == 3 else 1.0 for k, v in self.space.items()])  # gains 0-1
            ng = len(self.space)
            v = np.ones(ng)
            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
        else:
            # No tuning CSV yet: initialize hyperparameters from the values in self.args
            hyp = {k: getattr(self.args, k) for k in self.space.keys()}

        # Constrain to limits
        for k, v in self.space.items():
            hyp[k] = max(hyp[k], v[0])  # lower limit
            hyp[k] = min(hyp[k], v[1])  # upper limit
            hyp[k] = round(hyp[k], 5)  # significant digits

        return hyp
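To see the mutation rule in isolation: each gene is left untouched with probability `1 - mutation`, and otherwise multiplied by a clipped Gaussian factor scaled by its gain. A minimal standalone sketch with toy values (the parent vector here is hypothetical, not read from a tune_results.csv):

```python
import numpy as np

space = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98, 0.3)}  # (min, max, optional gain)
parent = np.array([0.01, 0.9])  # hypothetical parent hyperparameter values
mutation, sigma = 0.8, 0.2

r = np.random.default_rng(0)
g = np.array([v[2] if len(v) == 3 else 1.0 for v in space.values()])  # per-gene mutation gains
ng = len(space)
v = np.ones(ng)
while all(v == 1):  # retry until at least one gene actually changes
    v = (g * (r.random(ng) < mutation) * r.standard_normal(ng) * r.random() * sigma + 1).clip(0.3, 3.0)

child = {k: float(np.clip(parent[i] * v[i], s[0], s[1])) for i, (k, s) in enumerate(space.items())}
print(child)  # mutated hyperparameters, clamped to their bounds
```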
.\yolov8\ultralytics\engine\validator.py
# Required libraries
import json  # JSON serialization of validation results
import time  # timing utilities
from pathlib import Path  # filesystem paths

import numpy as np
import torch

# Ultralytics modules and helpers
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode


class BaseValidator:
"""
BaseValidator.
A base class for creating validators.
Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
names (dict): Class names.
seen: Records the number of images seen so far during validation.
stats: Placeholder for statistics during validation.
confusion_matrix: Placeholder for a confusion matrix.
nc: Number of classes.
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
jdict (dict): Dictionary to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
"""def__init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""
Initializes a BaseValidator instance.
Args:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path, optional): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
"""# 使用给定的参数初始化 BaseValidator 实例
self.args = get_cfg(overrides=args) # 获取配置参数,并用其覆盖默认配置
self.dataloader = dataloader # 存储数据加载器
self.pbar = pbar # 存储进度条对象
self.stride = None# 初始化步长为 None
self.data = None# 初始化数据为 None
self.device = None# 初始化设备为 None
self.batch_i = None# 初始化批次索引为 None
self.training = True# 标记当前为训练模式
self.names = None# 初始化名称列表为 None
self.seen = None# 初始化 seen 为 None
self.stats = None# 初始化统计信息为 None
self.confusion_matrix = None# 初始化混淆矩阵为 None
self.nc = None# 初始化类别数为 None
self.iouv = None# 初始化 iouv 为 None
self.jdict = None# 初始化 jdict 为 None
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} # 初始化速度字典为各项均为 0.0
self.save_dir = save_dir or get_save_dir(self.args) # 设置保存结果的目录,如果未提供 save_dir,则使用默认目录
(self.save_dir / "labels"if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
# 如果保存为文本标签,则在保存目录下创建 'labels' 子目录;否则直接创建保存目录,并确保父目录存在if self.args.conf isNone:
self.args.conf = 0.001# 如果未提供 conf 参数,则设置默认的 conf=0.001
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) # 检查并修正图像尺寸参数
self.plots = {} # 初始化绘图字典为空
self.callbacks = _callbacks or callbacks.get_default_callbacks() # 设置回调函数字典,如果未提供,则获取默认回调函数defmatch_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
"""
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
Args:
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
true_classes (torch.Tensor): Target class indices of shape(M,).
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
use_scipy (bool): Whether to use scipy for matching (more precise).
Returns:
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
"""# 创建一个全零的形状为 (预测类别数, IoU 阈值数) 的布尔类型数组,用于存储正确匹配结果
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
# 创建一个形状为 (真实类别数, 预测类别数) 的布尔类型数组,标记哪些预测类别与真实类别相匹配
correct_class = true_classes[:, None] == pred_classes
# 将 IoU 值与正确类别对应位置的元素置为零,排除不匹配的类别影响
iou = iou * correct_class
iou = iou.cpu().numpy() # 将计算后的 IoU 转换为 NumPy 数组# 遍历每个 IoU 阈值for i, threshold inenumerate(self.iouv.cpu().tolist()):
if use_scipy:
# 如果使用 scipy 匹配import scipy # 仅在需要时引入以节省资源# 构建成本矩阵,仅保留大于等于当前阈值的 IoU 值
cost_matrix = iou * (iou >= threshold)
# 使用线性求和匹配最大化方法求解最优匹配if cost_matrix.any():
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
valid = cost_matrix[labels_idx, detections_idx] > 0if valid.any():
correct[detections_idx[valid], i] = Trueelse:
# 如果不使用 scipy 匹配,直接寻找满足 IoU 大于阈值且类别匹配的预测与真实标签
matches = np.nonzero(iou >= threshold)
matches = np.array(matches).T
if matches.shape[0]:
if matches.shape[0] > 1:
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True# 返回布尔类型的 Torch 张量,表示每个预测是否正确匹配的结果return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
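A toy walk-through of the greedy branch, with a hypothetical 2x2 IoU matrix (rows are labels, columns are detections):

```python
import numpy as np

# Hypothetical IoU matrix: iou[m, n] = IoU(label m, detection n)
iou = np.array([[0.80, 0.30],
                [0.65, 0.55]])
threshold = 0.5

matches = np.array(np.nonzero(iou >= threshold)).T  # candidate (label, detection) pairs
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]  # best IoU first
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # one match per detection
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # one match per label
print(matches)  # [[0 0] [1 1]]: label 0 <-> det 0 (0.80), label 1 <-> det 1 (0.55)
```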
    def add_callback(self, event: str, callback):
        """Appends the given callback to the list associated with the event."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocesses an input batch."""
        return batch

    def postprocess(self, preds):
        """Post-processes the predictions."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Checks statistics."""
        pass

    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """Evaluate and return JSON format of prediction statistics."""
        pass
# Ultralytics YOLO 🚀, AGPL-3.0 license

import requests  # HTTP requests

from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis

API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # URL of the HUB API-key settings page


class Auth:
    """
Manages authentication processes including API key handling, cookie-based authentication, and header generation.
The class supports different methods of authentication:
1. Directly using an API key.
2. Authenticating using browser cookies (specifically in Google Colab).
3. Prompting the user to enter an API key.
Attributes:
id_token (str or bool): Token used for identity verification, initialized as False.
api_key (str or bool): API key for authentication, initialized as False.
model_key (bool): Placeholder for model key, initialized as False.
"""# 类属性:身份令牌 id_token、API 密钥 api_key 和模型密钥 model_key 的初始化
id_token = api_key = model_key = False
    def __init__(self, api_key="", verbose=False):
        """
        Initialize the Auth class with an optional API key.

        Args:
            api_key (str, optional): May be an API key or a combined API key and model ID, i.e. key_id.
        """
        # If the key contains an underscore, keep only the part before it as the API key
        api_key = api_key.split("_")[0]

        # Use the provided key, or fall back to the key stored in SETTINGS
        self.api_key = api_key or SETTINGS.get("api_key", "")

        if self.api_key:
            if self.api_key == SETTINGS.get("api_key"):
                # The key matches the stored key, so the user is already authenticated
                if verbose:
                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                return
            else:
                # Attempt to authenticate with the provided key
                success = self.authenticate()
        elif IS_COLAB:
            # No key given but running in Google Colab: try browser-cookie authentication
            success = self.auth_with_cookies()
        else:
            # Otherwise prompt the user for an API key
            success = self.request_api_key()

        # On success, persist the API key in SETTINGS
        if success:
            SETTINGS.update({"api_key": self.api_key})
            if verbose:
                LOGGER.info(f"{PREFIX}New authentication successful ✅")
        elif verbose:
            LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo hub login API_KEY'")
    def request_api_key(self, max_attempts=3):
        """
        Prompt the user to input their API key.

        Returns the model ID.
        """
        import getpass  # scoped import; hides the typed key

        for attempts in range(max_attempts):
            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
            self.api_key = input_key.split("_")[0]  # strip any trailing model ID
            if self.authenticate():
                return True
        # Maximum attempts exhausted without success
        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
    def authenticate(self) -> bool:
        """
        Attempt to authenticate with the server using either id_token or API key.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        try:
            if header := self.get_auth_header():
                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
                if not r.json().get("success", False):
                    raise ConnectionError("Unable to authenticate.")
                return True
            raise ConnectionError("User has not authenticated locally.")
        except ConnectionError:
            self.id_token = self.api_key = False  # reset invalid credentials
            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
            return False
    def auth_with_cookies(self) -> bool:
        """
        Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a
        supported browser.

        Returns:
            (bool): True if authentication is successful, False otherwise.
        """
        if not IS_COLAB:
            return False  # currently only works in Colab
        try:
            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
            if authn.get("success", False):
                self.id_token = authn.get("data", {}).get("idToken", None)
                self.authenticate()
                return True
            raise ConnectionError("Unable to fetch browser authentication details.")
        except ConnectionError:
            self.id_token = False  # reset invalid id_token
            return False
    def get_auth_header(self):
        """
        Get the authentication header for making API requests.

        Returns:
            (dict): The authentication header if id_token or API key is set, None otherwise.
        """
        if self.id_token:
            return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
            return {"x-api-key": self.api_key}
        # Returns None implicitly if neither credential is set
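A short usage sketch; "YOUR_API_KEY" is a placeholder, not a real key:

```python
from ultralytics.hub.auth import Auth

auth = Auth(api_key="YOUR_API_KEY", verbose=True)  # placeholder key; runs authenticate() internally
headers = auth.get_auth_header()  # {'x-api-key': ...} if the key was accepted, else None
```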
.\yolov8\ultralytics\hub\google\__init__.py
# 导入所需的库和模块import concurrent.futures # 用于并发执行任务import statistics # 提供统计函数,如计算均值、中位数等import time # 提供时间相关的功能,如睡眠、计时等from typing importList, Optional, Tuple# 导入类型提示相关的模块import requests # 提供进行 HTTP 请求的功能classGCPRegions:
"""
A class for managing and analyzing Google Cloud Platform (GCP) regions.
This class provides functionality to initialize, categorize, and analyze GCP regions based on their
geographical location, tier classification, and network latency.
Attributes:
regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.
Methods:
tier1: Returns a list of tier 1 GCP regions.
tier2: Returns a list of tier 2 GCP regions.
lowest_latency: Determines the GCP region(s) with the lowest network latency.
Examples:
>>> from ultralytics.hub.google import GCPRegions
>>> regions = GCPRegions()
>>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3)
>>> print(f"Lowest latency region: {lowest_latency_region[0][0]}")
"""def__init__(self):
"""Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details."""# 定义包含各个谷歌云平台地区及其详细信息的字典
self.regions = {
"asia-east1": (1, "Taiwan", "China"),
"asia-east2": (2, "Hong Kong", "China"),
"asia-northeast1": (1, "Tokyo", "Japan"),
"asia-northeast2": (1, "Osaka", "Japan"),
"asia-northeast3": (2, "Seoul", "South Korea"),
"asia-south1": (2, "Mumbai", "India"),
"asia-south2": (2, "Delhi", "India"),
"asia-southeast1": (2, "Jurong West", "Singapore"),
"asia-southeast2": (2, "Jakarta", "Indonesia"),
"australia-southeast1": (2, "Sydney", "Australia"),
"australia-southeast2": (2, "Melbourne", "Australia"),
"europe-central2": (2, "Warsaw", "Poland"),
"europe-north1": (1, "Hamina", "Finland"),
"europe-southwest1": (1, "Madrid", "Spain"),
"europe-west1": (1, "St. Ghislain", "Belgium"),
"europe-west10": (2, "Berlin", "Germany"),
"europe-west12": (2, "Turin", "Italy"),
"europe-west2": (2, "London", "United Kingdom"),
"europe-west3": (2, "Frankfurt", "Germany"),
"europe-west4": (1, "Eemshaven", "Netherlands"),
"europe-west6": (2, "Zurich", "Switzerland"),
"europe-west8": (1, "Milan", "Italy"),
"europe-west9": (1, "Paris", "France"),
"me-central1": (2, "Doha", "Qatar"),
"me-west1": (1, "Tel Aviv", "Israel"),
"northamerica-northeast1": (2, "Montreal", "Canada"),
"northamerica-northeast2": (2, "Toronto", "Canada"),
"southamerica-east1": (2, "São Paulo", "Brazil"),
"southamerica-west1": (2, "Santiago", "Chile"),
"us-central1": (1, "Iowa", "United States"),
"us-east1": (1, "South Carolina", "United States"),
"us-east4": (1, "Northern Virginia", "United States"),
"us-east5": (1, "Columbus", "United States"),
"us-south1": (1, "Dallas", "United States"),
"us-west1": (1, "Oregon", "United States"),
"us-west2": (2, "Los Angeles", "United States"),
"us-west3": (2, "Salt Lake City", "United States"),
"us-west4": (2, "Las Vegas", "United States"),
}
    def tier1(self) -> List[str]:
        """Returns a list of GCP regions classified as tier 1 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 1]

    def tier2(self) -> List[str]:
        """Returns a list of GCP regions classified as tier 2 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 2]
    @staticmethod
    def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]:
        """Pings a specified GCP region and returns latency statistics: mean, min, max, and standard deviation."""
        url = f"https://{region}-docker.pkg.dev"  # endpoint built from the region name
        latencies = []  # latency of each attempt
        for _ in range(attempts):
            try:
                start_time = time.time()
                _ = requests.head(url, timeout=5)  # HEAD request with a 5-second timeout
                latency = (time.time() - start_time) * 1000  # convert latency to milliseconds
                if latency != float("inf"):
                    latencies.append(latency)
            except requests.RequestException:
                pass
        if not latencies:
            # No successful attempts: report infinite latency statistics
            return region, float("inf"), float("inf"), float("inf"), float("inf")

        std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0
        # Return the region name with its mean, std dev, min, and max latency
        return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies)
    def lowest_latency(
        self,
        top: int = 1,
        verbose: bool = False,
        tier: Optional[int] = None,
        attempts: int = 1,
    ) -> List[Tuple[str, float, float, float, float]]:
"""
Determines the GCP regions with the lowest latency based on ping tests.
Args:
top (int): Number of top regions to return.
verbose (bool): If True, prints detailed latency information for all tested regions.
tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
attempts (int): Number of ping attempts per region.
Returns:
(List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
Examples:
>>> regions = GCPRegions()
>>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)
>>> print(results[0][0]) # Print the name of the lowest latency region
"""# 如果 verbose 为 True,打印正在进行的 ping 测试信息if verbose:
print(f"Testing GCP regions for latency (with {attempts}{'retry'if attempts == 1else'attempts'})...")
# 根据 tier 条件过滤要测试的地区列表
regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier elselist(self.regions.keys())
# 使用 ThreadPoolExecutor 并发执行 ping 测试with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test))
# 根据平均延迟对结果进行排序
sorted_results = sorted(results, key=lambda x: x[1])
# 如果 verbose 为 True,打印详细的延迟信息表格if verbose:
print(f"{'Region':<25}{'Location':<35}{'Tier':<5}{'Latency (ms)'}")
for region, mean, std, min_, max_ in sorted_results:
tier, city, country = self.regions[region]
location = f"{city}, {country}"if mean == float("inf"):
print(f"{region:<25}{location:<35}{tier:<5}{'Timeout'}")
else:
print(f"{region:<25}{location:<35}{tier:<5}{mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})")
print(f"\nLowest latency region{'s'if top > 1else''}:")
for region, mean, std, min_, max_ in sorted_results[:top]:
tier, city, country = self.regions[region]
location = f"{city}, {country}"print(f"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))")
# 返回延迟最低的前 top 个地区的信息列表return sorted_results[:top]


# Run only when this module is executed directly (not imported)
if __name__ == "__main__":
    regions = GCPRegions()
    # Ping tier-1 regions 3 times each and report the 3 with the lowest latency
    top_3_latency_tier1 = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=3)
.\yolov8\ultralytics\hub\session.py
# Ultralytics YOLO 🚀, AGPL-3.0 license

import threading  # threading support
import time  # timing utilities
from http import HTTPStatus  # HTTP status codes
from pathlib import Path  # filesystem paths

import requests  # HTTP requests

from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM  # Ultralytics HUB helpers
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis
from ultralytics.utils.errors import HUBModelError  # custom error class

AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local"  # agent name per environment


class HUBTrainingSession:
    """
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
Attributes:
model_id (str): Identifier for the YOLO model being trained.
model_url (str): URL for the model in Ultralytics HUB.
rate_limits (dict): Rate limits for different API calls (in seconds).
timers (dict): Timers for rate limiting.
metrics_queue (dict): Queue for the model's metrics.
model (dict): Model data fetched from Ultralytics HUB.
"""def__init__(self, identifier):
"""
Initialize the HUBTrainingSession with the provided model identifier.
Args:
identifier (str): Model identifier used to initialize the HUB training session.
It can be a URL string or a model key with specific format.
Raises:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
ModuleNotFoundError: If hub-sdk package is not installed.
"""
        """
        from hub_sdk import HUBClient  # client for the Ultralytics HUB API

        self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300}  # rate limits per API call (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until they are uploaded
        self.metrics_upload_failed_queue = {}  # holds per-epoch metrics whose upload failed
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py
        self.model = None
        self.model_url = None
        self.model_file = None

        # Parse the input identifier
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials
        active_key = api_key or SETTINGS.get("api_key")
        credentials = {"api_key": active_key} if active_key else None

        # Initialize the client, then load the model if authentication succeeded
        self.client = HUBClient(credentials)
        if self.client.authenticated:
            if model_id:
                self.load_model(model_id)  # load an existing model
            else:
                self.model = self.client.model()  # load an empty model

    @classmethod
    def create_session(cls, identifier, args=None):
"""Class method to create an authenticated HUBTrainingSession or return None."""try:# 尝试创建一个指定标识符的会话对象
session = cls(identifier)
# 检查客户端是否已认证ifnot session.client.authenticated:# 如果未认证且标识符以指定路径开始,则警告并退出程序if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
exit()
return None
# 如果提供了参数且标识符不是 HUB 模型的 URL,则创建模型if args andnot identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
session.create_model(args)
# 断言模型已加载正确
assert session.model.id, "HUB model not loaded correctly"# 返回创建的会话对象return session
# 处理权限错误或模块未找到异常,表明 hub-sdk 未安装
except (PermissionError, ModuleNotFoundError, AssertionError):
return None
    def load_model(self, model_id):
        """Loads an existing model from Ultralytics HUB using the provided model identifier."""
        self.model = self.client.model(model_id)
        if not self.model.data:  # then model does not exist
            raise ValueError(emojis("❌ The specified HUB model does not exist"))  # TODO: improve error handling

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
        if self.model.is_trained():
            print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
            self.model_file = self.model.get_weights_url("best")
            return

        # Set training args and start heartbeats so HUB can monitor the agent
        self._set_train_args()
        self.model.start_heartbeat(self.rate_limits["heartbeat"])
        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
    def create_model(self, model_args):
        """Initializes a HUB training session with the specified model identifier."""
        # Build the payload of training arguments
        payload = {
            "config": {
                "batchSize": model_args.get("batch", -1),
                "epochs": model_args.get("epochs", 300),
                "imageSize": model_args.get("imgsz", 640),
                "patience": model_args.get("patience", 100),
                "device": str(model_args.get("device", "")),  # convert None to string
                "cache": str(model_args.get("cache", "ram")),  # convert True, False, None to string
            },
            "dataset": {"name": model_args.get("data")},
            "lineage": {
                "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
                "parent": {},
            },
            "meta": {"name": self.filename},
        }

        if self.filename.endswith(".pt"):
            payload["lineage"]["parent"]["name"] = self.filename  # a .pt file acts as the parent model

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return None

        self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits["heartbeat"])

        LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")

    @staticmethod
    def _parse_identifier(identifier):
"""
Parses the given identifier to determine the type of identifier and extract relevant components.
The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml'
Args:
identifier (str): The identifier string to be parsed.
Returns:
(tuple): A tuple containing the API key, model ID, and filename as applicable.
Raises:
HUBModelError: If the identifier format is not recognized.
"""# Initialize variables to None
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URLif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
else:# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split("_")
# Check if identifier is in the format of API key and model IDif len(parts) == 2and len(parts[0]) == 42and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:# Raise an error if identifier format does not match any supported format
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
# Return the extracted components as a tuplereturn api_key, model_id, filename
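The four accepted identifier shapes, illustrated with made-up placeholder values of the required lengths (assuming the default HUB_WEB_ROOT):

```python
key, mid = "k" * 42, "m" * 20  # hypothetical 42-char API key and 20-char model ID

HUBTrainingSession._parse_identifier(f"https://hub.ultralytics.com/models/{mid}")  # (None, mid, None)
HUBTrainingSession._parse_identifier(f"{key}_{mid}")  # (key, mid, None)
HUBTrainingSession._parse_identifier(mid)  # (None, mid, None)
HUBTrainingSession._parse_identifier("yolov8n.pt")  # (None, None, 'yolov8n.pt')
```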
    def _set_train_args(self):
"""
Initializes training arguments and creates a model entry on the Ultralytics HUB.
This method sets up training arguments based on the model's state and updates them with any additional
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
or requires specific file setup.
Raises:
ValueError: If the model is already trained, if required dataset information is missing, or if there are
issues with the provided training arguments.
"""ifself.model.is_resumable():
# Model has saved weightsself.train_args = {"data": self.model.get_dataset_url(), "resume": True}
self.model_file = self.model.get_weights_url("last")
else:# Model has no saved weightsself.train_args = self.model.data.get("train_args") # 从模型数据中获取训练参数# 设置模型文件,可以是 *.pt 或 *.yaml 文件self.model_file = (
self.model.get_weights_url("parent") ifself.model.is_pretrained() elseself.model.get_architecture()
)
if"data"notinself.train_args:# RF bug - datasets are sometimes not exported
raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # 检查并纠正文件名self.model_id = self.model.id
    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        stream_response=None,
        *args,
        **kwargs,
    ):
"""
Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress.
"""defretry_request():
"""
Attempts to call `request_func` with retries, timeout, and optional threading.
"""
t0 = time.time() # Record the start time for the timeout
response = None
for i in range(retry + 1):
if (time.time() - t0) > timeout:
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
break# Timeout reached, exit loop
response = request_func(*args, **kwargs)
if response is None:
LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
time.sleep(2**i) # Exponential backoff before retrying
continue # Skip further processing and retryifprogress_total:self._show_upload_progress(progress_total, response)
elif stream_response:self._iterate_content(response)
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:# if request related to metrics uploadif kwargs.get("metrics"):
self.metrics_upload_failed_queue = {}
return response # Success, no need to retryif i == 0:
# Initial attempt, check status code and provide messages
message = self._get_failure_message(response, retry, timeout)
ifverbose:
LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
ifnotself._should_retry(response.status_code):
LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code})")
break# Not an error that should be retried, exit loop
time.sleep(2**i) # Exponential backoff for retries# if request related to metrics upload and exceed retriesif response is None and kwargs.get("metrics"):
self.metrics_upload_failed_queue.update(kwargs.get("metrics", None))
return response
ifthread:# Start a new thread to run the retry_request function
threading.Thread(target=retry_request, daemon=True).start()
else:# If running in the main thread, call retry_request directlyreturn retry_request()
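The sleep between attempts is `2**i` seconds, so the schedule grows geometrically; a quick check of the budget (standalone arithmetic, not HUB code):

```python
retry, timeout = 3, 30
delays = [2**i for i in range(retry + 1)]  # sleeps after attempts 0..3
print(delays, sum(delays))  # [1, 2, 4, 8] 15 -> the full schedule fits inside the 30 s timeout
```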
    @staticmethod
    def _should_retry(status_code):
        """Determines if a request should be retried based on the HTTP status code."""
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT,
        }
        return status_code in retry_codes
    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
"""
Generate a retry message based on the response status code.
Args:
response: The HTTP response object.
retry: The number of retry attempts allowed.
timeout: The maximum timeout duration.
Returns:
(str): The retry message.
"""# 如果应该重试,返回重试信息,包括重试次数和超时时间ifself._should_retry(response.status_code):
return f"Retrying {retry}x for {timeout}s."ifretryelse""# 如果响应状态码为429(太多请求),则显示速率限制信息
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:# rate limit
headers = response.headers
return (
f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
f"Please retry after {headers['Retry-After']}s."
)
else:try:# 尝试从响应中读取JSON格式的消息,如果无法读取则返回默认消息return response.json().get("message", "No JSON message.")
except AttributeError:# 如果无法读取JSON,则返回无法读取JSON的提示信息return"Unable to read JSON."defupload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""# 将模型指标上传到Ultralytics HUB,并返回请求队列的结果returnself.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
Upload a model checkpoint to Ultralytics HUB.
Args:
epoch (int): The current training epoch.
weights (str): Path to the model weights file.
is_best (bool): Indicates if the current model is the best one so far.
map (float): Mean average precision of the model.
final (bool): Indicates if the model is the final model after training.
"""# 如果指定的模型权重文件存在if Path(weights).is_file():
# 获取模型文件的总大小(仅在最终上传时显示进度)
progress_total = Path(weights).stat().st_size if final else None # Only show progress if final# 请求队列将模型上传到Ultralytics HUB,包括各种参数和选项self.request_queue(
self.model.upload_model,
epoch=epoch,
weights=weights,
is_best=is_best,
map=map,
final=final,
retry=10,
timeout=3600,
thread=not final,
progress_total=progress_total,
stream_response=True,
)
else:# 如果指定的模型权重文件不存在,则记录警告信息
LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
    @staticmethod
    def _show_upload_progress(content_length: int, response: requests.Response) -> None:
        """
        Display a progress bar to track the progress of a streamed file upload.

        Args:
            content_length (int): The total size of the content to be uploaded in bytes.
            response (requests.Response): The streamed response object from the upload request.

        Returns:
            None
        """
        # Progress bar sized to content_length, with automatic byte-unit scaling
        with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            # Advance the bar by the size of each received chunk
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))
    @staticmethod
    def _iterate_content(response: requests.Response) -> None:
        """
        Process the streamed HTTP response data.

        Args:
            response (requests.Response): The streamed response object.

        Returns:
            None
        """
        # Consume the streamed chunks without doing anything with them
        for _ in response.iter_content(chunk_size=1024):
            pass  # do nothing with data chunks
.\yolov8\ultralytics\hub\utils.py
# Required libraries
import os
import platform
import random
import threading
import time
from pathlib import Path

import requests  # third-party HTTP library

from ultralytics.utils import (
ARGV,
ENVIRONMENT,
IS_COLAB,
IS_GIT_DIR,
IS_PIP_PACKAGE,
LOGGER,
ONLINE,
RANK,
SETTINGS,
TESTS_RUNNING,
TQDM,
TryExcept,
__version__,
colorstr,
get_git_origin_url,
)
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES

# HUB_API_ROOT and HUB_WEB_ROOT fall back to these defaults when the environment variables are unset
HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com")
HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com")

PREFIX = colorstr("Ultralytics HUB: ")  # colored log prefix
HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."


def request_with_credentials(url: str) -> any:
"""
在 Google Colab 环境中发送带有附加 cookies 的 AJAX 请求。
Args:
url (str): 要发送请求的 URL。
Returns:
(any): AJAX 请求的响应数据。
Raises:
OSError: 如果函数不在 Google Colab 环境中运行。
"""# 如果不在 Colab 环境中,则抛出 OSError 异常ifnot IS_COLAB:
raise OSError("request_with_credentials() must run in a Colab environment")
# 导入必要的 Colab 相关库from google.colab import output # noqafrom IPython import display # noqa# 使用 display.Javascript 创建一个 AJAX 请求,并附加 cookies
display.display(
display.Javascript(
"""
window._hub_tmp = new Promise((resolve, reject) => {
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
fetch("%s", {
method: 'POST',
credentials: 'include'
})
.then((response) => resolve(response.json()))
.then((json) => {
clearTimeout(timeout);
}).catch((err) => {
clearTimeout(timeout);
reject(err);
});
});
"""
% url
)
)
    # Return the result of the JS promise
    return output.eval_js("_hub_tmp")


def requests_with_progress(method, url, **kwargs):
"""
使用指定的方法和 URL 发送 HTTP 请求,支持可选的进度条显示。
Args:
method (str): 要使用的 HTTP 方法 (例如 'GET'、'POST')。
url (str): 要发送请求的 URL。
**kwargs (any): 传递给底层 `requests.request` 函数的其他关键字参数。
Returns:
(requests.Response): HTTP 请求的响应对象。
Note:
- 如果 'progress' 设置为 True,则进度条将显示已知内容长度的下载进度。
- 如果 'progress' 是一个数字,则进度条将显示假设内容长度为 'progress' 的下载进度。
"""# 弹出 kwargs 中的 progress 参数,默认为 False
progress = kwargs.pop("progress", False)
# 如果 progress 为 False,则直接发送请求ifnot progress:
return requests.request(method, url, **kwargs)
# 发起 HTTP 请求并获取响应
response = requests.request(method, url, stream=True, **kwargs)
# 从响应头中获取内容长度信息,如果 progress 参数是布尔值则返回内容长度,否则返回 progress 参数的值作为总大小
total = int(response.headers.get("content-length", 0) ifisinstance(progress, bool) else progress) # total sizetry:
# 初始化进度条对象,显示总大小并按照适当的单位进行缩放
pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
# 逐块迭代响应数据流,每次更新进度条for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
# 关闭进度条
pbar.close()
except requests.exceptions.ChunkedEncodingError: # 避免出现 'Connection broken: IncompleteRead' 的警告# 关闭响应以处理异常
response.close()
# 返回完整的 HTTP 响应对象return response
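A quick usage sketch (the URL is illustrative; any downloadable file works):

```python
from ultralytics.hub.utils import requests_with_progress

r = requests_with_progress("GET", "https://ultralytics.com/assets/coco8.zip", progress=True)
print(r.status_code)  # the progress bar was drawn while the body streamed in
```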
"""
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
Args:
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
url (str): The URL to make the request to.
retry (int, optional): Number of retries to attempt before giving up. Default is 3.
timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
progress (bool, optional): Whether to show a progress bar during the request. Default is False.
**kwargs (any): Keyword arguments to be passed to the requests function specified in method.
Returns:
(requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
"""
    retry_codes = (408, 500)  # retry only these codes

    @TryExcept(verbose=verbose)  # decorator to handle exceptions and log messages
    def func(func_method, func_url, **func_kwargs):
        """Make HTTP requests with retries and timeouts, with optional progress tracking."""
        r = None  # response object
        t0 = time.time()  # start time for timeout
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                break

            # Perform HTTP request with progress tracking if enabled
            r = requests_with_progress(func_method, func_url, **func_kwargs)

            # Status codes below 300 indicate success
            if r.status_code < 300:
                break

            try:
                m = r.json().get("message", "No JSON message.")
            except AttributeError:
                m = "Unable to read JSON."

            # Handle retry logic based on response status code
            if i == 0:
                if r.status_code in retry_codes:
                    m += f" Retrying {retry}x for {timeout}s." if retry else ""
                elif r.status_code == 429:  # rate limit exceeded
                    h = r.headers  # response headers
                    m = (
                        f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
                        f"Please retry after {h['Retry-After']}s."
                    )
                if verbose:
                    LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")

            # Return response if no need to retry
            if r.status_code not in retry_codes:
                return r

            time.sleep(2**i)  # exponential backoff wait
        return r
    # Prepare arguments and pass the progress flag through to kwargs
    args = method, url
    kwargs["progress"] = progress

    if thread:
        # thread=True: run func in a new daemon thread
        threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
    else:
        # thread=False: run func in the main thread and return its result
        return func(*args, **kwargs)
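Usage mirrors `requests.request`; for example, a synchronous POST with two retries (the endpoint below is illustrative only, not a documented API route):

```python
from ultralytics.hub.utils import smart_request

# thread=False returns the Response; thread=True fires the request in a daemon thread and returns None
r = smart_request("post", "https://api.ultralytics.com/v1/example", json={"ping": 1}, retry=2, thread=False)
```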


class Events:
"""
A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.
Attributes:
url (str): The URL to send anonymous events.
rate_limit (float): The rate limit in seconds for sending events.
metadata (dict): A dictionary containing metadata about the environment.
enabled (bool): A flag to enable or disable Events based on certain conditions.
"""# 设置 Google Analytics 收集匿名事件的 URL
url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"def__init__(self):
"""Initializes the Events object with default values for events, rate_limit, and metadata."""# 初始化事件列表
self.events = [] # events list# 设置事件发送的速率限制(单位:秒)
self.rate_limit = 30.0# rate limit (seconds)# 初始化事件发送的计时器(单位:秒)
self.t = 0.0# rate limit timer (seconds)# 设置环境的元数据
self.metadata = {
"cli": Path(ARGV[0]).name == "yolo", # 检查命令行是否为 'yolo'"install": "git"if IS_GIT_DIR else"pip"if IS_PIP_PACKAGE else"other", # 检查安装方式是 git 还是 pip 或其他"python": ".".join(platform.python_version_tuple()[:2]), # Python 版本号,例如 3.10"version": __version__, # 从模块中获取版本号"env": ENVIRONMENT, # 获取环境变量"session_id": round(random.random() * 1e15), # 创建随机会话 ID"engagement_time_msec": 1000, # 设置参与时间(毫秒)
}
# 根据设置和其他条件,确定是否启用事件收集
self.enabled = (
SETTINGS["sync"] # 检查是否设置为同步and RANK in {-1, 0} # 检查当前排名是否为 -1 或 0andnot TESTS_RUNNING # 确保没有正在运行的测试and ONLINE # 确保在线状态and (IS_PIP_PACKAGE or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git") # 检查安装来源是否为指定的 GitHub 仓库
)
    def __call__(self, cfg):
"""
Attempts to add a new event to the events list and send events if the rate limit is reached.
Args:
cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
"""# 如果事件功能未启用,直接返回,不执行任何操作ifnot self.enabled:
# Events disabled, do nothingreturn# 尝试添加事件到事件列表iflen(self.events) < 25: # 事件列表最多包含 25 个事件,超过部分将被丢弃# 构建事件参数字典,包括元数据和配置的任务和模型信息
params = {
**self.metadata,
"task": cfg.task,
"model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else"custom",
}
# 如果配置模式为 "export",则添加格式信息到参数字典中if cfg.mode == "export":
params["format"] = cfg.format# 将新事件以字典形式添加到事件列表中
self.events.append({"name": cfg.mode, "params": params})
# 检查发送速率限制
t = time.time()
if (t - self.t) < self.rate_limit:
# 如果发送时间间隔未超过限制,等待发送return# 如果时间间隔超过限制,立即发送事件数据
data = {"client_id": SETTINGS["uuid"], "events": self.events} # 使用 SHA-256 匿名化的 UUID 哈希和事件列表# 发送 POST 请求,相当于 requests.post(self.url, json=data),不进行重试和输出详细信息
smart_request("post", self.url, json=data, retry=0, verbose=False)
# 重置事件列表和发送时间计时器
self.events = []
self.t = t
# 在 hub/utils 初始化中运行以下代码
events = Events()
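The rate limiting above is just a timestamp comparison; the same pattern in isolation (toy class, not the Events implementation):

```python
import time


class RateLimited:
    """Toy sender that flushes queued items at most once per `limit` seconds."""

    def __init__(self, limit=30.0):
        self.limit, self.t, self.queue = limit, 0.0, []

    def send(self, item):
        self.queue.append(item)
        now = time.time()
        if (now - self.t) < self.limit:
            return  # too soon: keep queueing
        print("flushing", len(self.queue), "items")  # stand-in for the HTTP POST
        self.queue, self.t = [], now
```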
.\yolov8\ultralytics\hub\__init__.py
# Ultralytics YOLO 🚀, AGPL-3.0 license

import requests  # HTTP requests

from ultralytics.data.utils import HUBDatasetStats  # dataset statistics helper
from ultralytics.hub.auth import Auth  # authentication
from ultralytics.hub.session import HUBTrainingSession  # training session handling
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events  # constants and events
from ultralytics.utils import LOGGER, SETTINGS, checks  # logger, settings, and check utilities
__all__ = (
"PREFIX",
"HUB_WEB_ROOT",
"HUBTrainingSession",
"login",
"logout",
"reset_model",
"export_fmts_hub",
"export_model",
"get_export",
"check_dataset",
"events",
)


def login(api_key: str = None, save=True) -> bool:
    """
Log in to the Ultralytics HUB API using the provided API key.
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
environment variable if successfully authenticated.
Args:
api_key (str, optional): API key to use for authentication.
If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
Returns:
(bool): True if authentication is successful, False otherwise.
"""
    checks.check_requirements("hub-sdk>=0.0.8")  # verify the minimum hub-sdk version
    from hub_sdk import HUBClient  # HUB API client

    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # redirect URL of the API-key settings page
    saved_key = SETTINGS.get("api_key")  # API key saved in SETTINGS
    active_key = api_key or saved_key  # prefer the provided key over the saved one
    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials

    client = HUBClient(credentials)  # initialize the HUBClient

    if client.authenticated:
        # Successfully authenticated with HUB
        if save and client.api_key != saved_key:
            SETTINGS.update({"api_key": client.api_key})  # update SETTINGS with the valid API key

        # Choose the message depending on whether a new key was supplied or a stored one was reused
        log_message = (
            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
        )
        LOGGER.info(f"{PREFIX}{log_message}")
        return True
    else:
        # Failed to authenticate with HUB
        LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo hub login API_KEY'")
        return False
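Typical usage, with a placeholder key:

```python
from ultralytics import hub

hub.login("YOUR_API_KEY")  # hypothetical key; hub.login() alone reuses the key saved in SETTINGS
```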


def logout():
"""
Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
Example:
```py
from ultralytics import hub
hub.logout()
```
"""
SETTINGS["api_key"] = ""# 清空SETTINGS中的API密钥
SETTINGS.save() # 保存SETTINGS变更
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.") # 记录退出登录信息到日志defreset_model(model_id=""):
"""Reset a trained model to an untrained state."""
r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
# 发送POST请求到HUB API以重置指定model_id的模型为未训练状态# 检查 HTTP 响应状态码是否为 200if r.status_code == 200:
# 如果响应状态码为 200,记录信息日志,表示模型重置成功
LOGGER.info(f"{PREFIX}Model reset successfully")
# 返回空,结束函数执行return# 如果响应状态码不为 200,记录警告日志,表示模型重置失败,并包含响应的状态码和原因
LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")


def export_fmts_hub():
    """Returns a list of HUB-supported export formats."""
    from ultralytics.engine.exporter import export_formats  # scoped import

    # All exporter formats except the first (PyTorch), plus two HUB-specific formats
    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]


def export_model(model_id="", format="torchscript"):
    """Export a model to all formats."""
    # Make sure the requested format is supported before calling the API
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"

    # POST the export request, authenticating with the API key
    r = requests.post(
        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
    )
    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
    LOGGER.info(f"{PREFIX}{format} export started ✅")


def get_export(model_id="", format="torchscript"):
    """Get an exported model dictionary with download URL."""
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"

    # POST the request for the exported-model dictionary, authenticating with the API key
    r = requests.post(
        f"{HUB_API_ROOT}/get-export",
        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
        headers={"x-api-key": Auth().api_key},
    )
    assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
    return r.json()  # the exported-model dictionary parsed from the JSON response


def check_dataset(path: str, task: str) -> None:
    """
Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
to the HUB. Usage examples are given below.
Args:
path (str): Path to data.zip (with data.yaml inside data.zip).
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.
Example:
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
```py
from ultralytics.hub import check_dataset
check_dataset('path/to/coco8.zip', task='detect') # detect dataset
check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
check_dataset('path/to/dota8.zip', task='obb') # OBB dataset
check_dataset('path/to/imagenet10.zip', task='classify') # classification dataset
```
"""# 使用 HUBDatasetStats 类检查指定路径下的数据集文件(zip 格式),并为指定任务类型生成 JSON 格式的统计信息
HUBDatasetStats(path=path, task=task).get_json()
# 记录检查操作成功完成的信息
LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
.\yolov8\ultralytics\models\nas\model.py
from pathlib import Path  # filesystem paths

import torch  # PyTorch

from ultralytics.engine.model import Model  # base Model class
from ultralytics.utils.downloads import attempt_download_asset  # asset download helper
from ultralytics.utils.torch_utils import model_info  # model summary helper

from .predict import NASPredictor  # NAS predictor
from .val import NASValidator  # NAS validator


class NAS(Model):
"""
YOLO NAS model for object detection.
This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
Example:
```python
from ultralytics import NAS
model = NAS('yolo_nas_s')
results = model.predict('ultralytics/assets/bus.jpg')
```
Attributes:
model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
Note:
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
"""def__init__(self, model="yolo_nas_s.pt") -> None:
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""# 断言所提供的模型文件不是 YAML 配置文件,因为 YOLO-NAS 模型仅支持预训练模型assert Path(model).suffix notin {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."# 调用父类 Model 的初始化方法,传入模型路径和任务类型为 "detect"super().__init__(model, task="detect")
    def _load(self, weights: str, task=None) -> None:
        """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
        import super_gradients  # scoped import used to load model weights

        suffix = Path(weights).suffix
        if suffix == ".pt":
            # Load weights from a local or downloadable .pt checkpoint
            self.model = torch.load(attempt_download_asset(weights))
        elif suffix == "":
            # Bare model name: fetch COCO-pretrained weights by name
            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")

        # Override the forward method to ignore additional arguments
        def new_forward(x, *args, **kwargs):
            """Ignore additional __call__ arguments."""
            return self.model._original_forward(x)

        self.model._original_forward = self.model.forward
        self.model.forward = new_forward

        # Standardize model attributes expected by the Ultralytics engine
        self.model.fuse = lambda verbose=True: self.model
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = "detect"  # for export()

    def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""# 调用 model_info 函数,传入模型对象和其他参数,并返回结果return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
@property# 定义一个属性,返回一个字典,将任务映射到相应的预测器和验证器类deftask_map(self):
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""# 返回包含映射关系的字典return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}