YOLOv8 Source Code Analysis (Part 33)

.\yolov8\ultralytics\models\nas\predict.py

# Import the PyTorch library
import torch

# Import the base predictor, the Results container and the ops utilities from the Ultralytics engine
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops


class NASPredictor(BasePredictor):
    """
    Ultralytics YOLO NAS 预测器,用于目标检测。

    这个类扩展了 Ultralytics 引擎中的 `BasePredictor`,负责对 YOLO NAS 模型生成的原始预测进行后处理。
    它应用了非极大值抑制和缩放边界框以适应原始图像尺寸等操作。

    Attributes:
        args (Namespace): 包含各种后处理配置的命名空间。

    Example:
        ```python
        from ultralytics import NAS

        model = NAS('yolo_nas_s')
        predictor = model.predictor
        # 假设 raw_preds, img, orig_imgs 可用
        results = predictor.postprocess(raw_preds, img, orig_imgs)
        ```py

    Note:
        通常情况下,不会直接实例化这个类,而是在 `NAS` 类的内部使用。

    """

    def postprocess(self, preds_in, img, orig_imgs):
        """后处理预测结果并返回 Results 对象的列表。"""

        # 将预测结果转换为 xywh 格式的边界框
        boxes = ops.xyxy2xywh(preds_in[0][0])
        # 将边界框和类别分数连接起来,并进行维度变换
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)

        # 应用非极大值抑制处理预测结果
        preds = ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            agnostic=self.args.agnostic_nms,
            max_det=self.args.max_det,
            classes=self.args.classes,
        )

        # 如果输入图像不是列表而是 torch.Tensor,则转换为 numpy 数组的批量
        if not isinstance(orig_imgs, list):
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        # 初始化结果列表
        results = []
        # 遍历每个预测结果、原始图像和图像路径,生成 Results 对象并添加到 results 列表中
        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
            # 缩放边界框以适应原始图像尺寸
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        # 返回最终的 results 列表
        return results
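
A minimal usage sketch of how this postprocessing is normally exercised through the public `NAS` interface (assuming the `yolo_nas_s.pt` weights and a local `bus.jpg` image are available); `model.predict()` calls `NASPredictor.postprocess()` internally on the raw model output:

```python
from ultralytics import NAS

# Load a pretrained YOLO-NAS small model
model = NAS("yolo_nas_s.pt")

# Run inference; conf/iou/max_det are forwarded to non_max_suppression via self.args
results = model.predict("bus.jpg", conf=0.25, iou=0.7, max_det=300)

for r in results:
    # r.boxes holds the NMS-filtered detections rescaled to the original image
    print(r.boxes.xyxy, r.boxes.conf, r.boxes.cls)
```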

.\yolov8\ultralytics\models\nas\val.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch  # Import the PyTorch library

from ultralytics.models.yolo.detect import DetectionValidator  # Import the detection validator class
from ultralytics.utils import ops  # Import utility operations

__all__ = ["NASValidator"]  # Public interface of this module: only the NASValidator class

class NASValidator(DetectionValidator):
    """
    Ultralytics YOLO NAS Validator for object detection.

    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
    generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
    ultimately producing the final detections.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
        lb (torch.Tensor): Optional tensor for multilabel NMS.

    Example:
        ```python
        from ultralytics import NAS

        model = NAS('yolo_nas_s')
        validator = model.validator
        # Assumes that raw_preds are available
        final_preds = validator.postprocess(raw_preds)
        ```

    Note:
        This class is generally not instantiated directly but is used internally within the `NAS` class.
    """

    def postprocess(self, preds_in):
        """Apply Non-maximum suppression to prediction outputs."""
        # Convert the predicted boxes to center-x, center-y, width, height format
        boxes = ops.xyxy2xywh(preds_in[0][0])
        # Concatenate the boxes with the class scores and permute the dimensions
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
        # Run non-maximum suppression to filter the final detections
        return ops.non_max_suppression(
            preds,
            self.args.conf,  # confidence threshold
            self.args.iou,  # IoU threshold
            labels=self.lb,  # optional label tensor for multilabel NMS
            multi_label=False,  # whether to run multilabel NMS
            agnostic=self.args.single_cls,  # whether to run class-agnostic NMS
            max_det=self.args.max_det,  # maximum number of detections
            max_time_img=0.5,  # maximum processing time per image
        )
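
To make the tensor reshaping above concrete, here is a small self-contained sketch with dummy tensors (shapes are illustrative only, and the local `xyxy2xywh` helper is a simplified stand-in for `ops.xyxy2xywh`): the NAS head returns xyxy boxes and per-class scores separately, and `postprocess` repacks them into the `(batch, 4 + num_classes, num_boxes)` layout that `ops.non_max_suppression` expects.

```python
import torch

# Illustrative shapes only: 2 images, 100 candidate boxes, 80 classes
xyxy_boxes = torch.rand(2, 100, 4) * 640   # raw NAS output: boxes in xyxy pixels
class_scores = torch.rand(2, 100, 80)      # raw NAS output: per-class scores


def xyxy2xywh(x):
    """Simplified stand-in for ops.xyxy2xywh: convert corner boxes to center/width/height."""
    y = x.clone()
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # center x
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # center y
    y[..., 2] = x[..., 2] - x[..., 0]        # width
    y[..., 3] = x[..., 3] - x[..., 1]        # height
    return y


preds = torch.cat((xyxy2xywh(xyxy_boxes), class_scores), -1).permute(0, 2, 1)
print(preds.shape)  # torch.Size([2, 84, 100]) -> (batch, 4 + num_classes, num_boxes)
```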

.\yolov8\ultralytics\models\nas\__init__.py

# Import the NAS model class from the custom module
from .model import NAS

# Import the NASPredictor class from the custom module
from .predict import NASPredictor

# Import the NASValidator class from the custom module
from .val import NASValidator

# Set __all__ to specify the symbols imported by `from module import *`
__all__ = "NASPredictor", "NASValidator", "NAS"

.\yolov8\ultralytics\models\rtdetr\model.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
hybrid encoder and IoU-aware query selection for enhanced detection accuracy.

For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
"""

from ultralytics.engine.model import Model  # Import the Model base class
from ultralytics.nn.tasks import RTDETRDetectionModel  # Import the RTDETRDetectionModel class

from .predict import RTDETRPredictor  # Import the RTDETRPredictor class
from .train import RTDETRTrainer  # Import the RTDETRTrainer class
from .val import RTDETRValidator  # Import the RTDETRValidator class


class RTDETR(Model):
    """
    Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
    with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.

    Attributes:
        model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
    """

    def __init__(self, model="rtdetr-l.pt") -> None:
        """
        Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.

        Args:
            model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.

        Raises:
            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
        """
        # Call the parent Model constructor, initializing the model with the 'detect' task
        super().__init__(model=model, task="detect")

    @property
    def task_map(self) -> dict:
        """
        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.

        Returns:
            dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
        """
        # Return a task map dictionary that maps task names to the corresponding Ultralytics classes
        return {
            "detect": {
                "predictor": RTDETRPredictor,  # predictor class
                "validator": RTDETRValidator,  # validator class
                "trainer": RTDETRTrainer,  # trainer class
                "model": RTDETRDetectionModel,  # detection model class
            }
        }
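
A short usage sketch for the interface above (assuming the `rtdetr-l.pt` weights, a sample image and the `coco8.yaml` dataset config are available); the `task_map` property is what routes `predict` and `val` calls to the RT-DETR specific classes:

```python
from ultralytics import RTDETR

# Initialize with pretrained weights; the task is fixed to 'detect' by RTDETR.__init__
model = RTDETR("rtdetr-l.pt")

# Inference, dispatched to RTDETRPredictor via task_map["detect"]["predictor"]
results = model.predict("bus.jpg", conf=0.5)

# Validation, dispatched to RTDETRValidator via task_map["detect"]["validator"]
metrics = model.val(data="coco8.yaml", imgsz=640)
```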

.\yolov8\ultralytics\models\rtdetr\predict.py

# Import the PyTorch library
import torch

# Import the LetterBox class from the image augmentation module
from ultralytics.data.augment import LetterBox
# Import the BasePredictor base class
from ultralytics.engine.predictor import BasePredictor
# Import the Results container class
from ultralytics.engine.results import Results
# Import the ops utility module
from ultralytics.utils import ops

# RT-DETR predictor class, extending BasePredictor
class RTDETRPredictor(BasePredictor):
    """
    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
    Baidu's RT-DETR model.

    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
    It supports efficient hybrid encoding and IoU-aware query selection.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.rtdetr import RTDETRPredictor

        args = dict(model='rtdetr-l.pt', source=ASSETS)
        predictor = RTDETRPredictor(overrides=args)
        predictor.predict_cli()
        ```

    Attributes:
        imgsz (int): Image size for inference (must be square and scale-filled).
        args (dict): Argument overrides for the predictor.
    """

    # Post-processing method that generates bounding boxes and confidence scores from the model's raw predictions
    def postprocess(self, preds, img, orig_imgs):
        """
        后处理方法,从模型的原始预测生成边界框和置信度分数。

        该方法基于置信度和类别(如果在self.args中指定)筛选检测结果。

        Args:
            preds (list): 模型的预测结果列表。
            img (torch.Tensor): 处理过的输入图像。
            orig_imgs (list or torch.Tensor): 原始未处理的图像。

        Returns:
            (list[Results]): 包含后处理边界框、置信度分数和类别标签的Results对象列表。
        """
        if not isinstance(preds, (list, tuple)):  # 对于PyTorch推断,预测结果是列表,但对于导出推断,预测结果是列表中的第一个张量
            preds = [preds, None]

        nd = preds[0].shape[-1]
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

        if not isinstance(orig_imgs, list):  # the input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
            if self.args.classes is not None:
                idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
            oh, ow = orig_img.shape[:2]
            pred[..., [0, 2]] *= ow
            pred[..., [1, 3]] *= oh
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results
    # Method that pre-transforms the input images before model inference.
    # The input images are letterboxed to ensure a square aspect ratio and are scale-filled.
    # The image size must be square (e.g. 640x640) and scaleFilled.

    def pre_transform(self, im):
        """
        Pre-transforms the input images before feeding them into the model for inference. The input images are
        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.

        Args:
            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.

        Returns:
            (list): List of pre-transformed images ready for model inference.
        """
        # Create a LetterBox transform that resizes to self.imgsz with scale-fill (no aspect-preserving auto padding)
        letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
        # Apply the LetterBox transform to each input image
        return [letterbox(image=x) for x in im]
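
Because RT-DETR predicts a fixed set of queries with normalized boxes, `postprocess` needs no NMS: it only splits boxes from scores, keeps the best class per query, filters by confidence and rescales to the original image. A self-contained sketch of that logic with dummy tensors (shapes and the inline xywh-to-xyxy conversion are illustrative; the real code uses `ops.xywh2xyxy`):

```python
import torch

# Illustrative only: 1 image, 300 queries, 80 classes; nd = 4 + 80
raw = torch.rand(1, 300, 84)
nd = raw.shape[-1]
bboxes, scores = raw.split((4, nd - 4), dim=-1)  # normalized xywh boxes + class scores

conf_thres = 0.25
oh, ow = 480, 640  # original image height and width

for bbox, score in zip(bboxes, scores):
    # xywh -> xyxy (same role as ops.xywh2xyxy in the real code)
    xy, wh = bbox[:, :2], bbox[:, 2:]
    bbox = torch.cat([xy - wh / 2, xy + wh / 2], dim=-1)
    max_score, cls = score.max(-1, keepdim=True)
    keep = max_score.squeeze(-1) > conf_thres
    pred = torch.cat([bbox, max_score, cls.float()], dim=-1)[keep]
    # Boxes are normalized to [0, 1], so multiply by the original width/height
    pred[..., [0, 2]] *= ow
    pred[..., [1, 3]] *= oh
    print(pred.shape)  # (num_kept, 6): x1, y1, x2, y2, conf, cls
```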

.\yolov8\ultralytics\models\rtdetr\train.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# Import the required modules and libraries
from copy import copy
import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator

# RTDETRTrainer class, extending the DetectionTrainer class
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
    class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
    Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    # Initialize and return an RT-DETR model for object detection tasks
    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    # Build and return an RT-DETR dataset object for training or validation
    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    # Return a DetectionValidator-derived validator suited for RT-DETR model validation
    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
    # Overrides the parent method to preprocess a batch of images: scales and converts the images to float format.
    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        # Call the parent preprocessing to get the scaled/float batch
        batch = super().preprocess_batch(batch)

        # Number of images in the batch
        bs = len(batch["img"])

        # Per-box batch indices for the current batch
        batch_idx = batch["batch_idx"]

        # Lists collecting the ground-truth bounding boxes and classes per image
        gt_bbox, gt_class = [], []

        # Iterate over each image in the batch
        for i in range(bs):
            # Append the bounding boxes whose batch index equals i, moved to the appropriate device
            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))

            # Append the classes whose batch index equals i, moved to the appropriate device as long integers
            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))

        # Return the preprocessed batch
        return batch
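
The grouping loop above relies on the flattened Ultralytics label layout, where the boxes of all images in a batch are stacked and `batch_idx` records which image each box belongs to. A toy sketch of that layout and of the per-image grouping (values are made up for illustration):

```python
import torch

# Toy batch: 2 images, 3 boxes in total (two belong to image 0, one to image 1)
batch = {
    "img": torch.zeros(2, 3, 640, 640),
    "batch_idx": torch.tensor([0, 0, 1]),
    "bboxes": torch.rand(3, 4),                  # normalized xywh boxes
    "cls": torch.tensor([[5.0], [2.0], [7.0]]),  # class ids stored as floats
}

bs = len(batch["img"])
batch_idx = batch["batch_idx"]
gt_bbox, gt_class = [], []
for i in range(bs):
    gt_bbox.append(batch["bboxes"][batch_idx == i])
    gt_class.append(batch["cls"][batch_idx == i].to(dtype=torch.long))

print([b.shape for b in gt_bbox])          # [torch.Size([2, 4]), torch.Size([1, 4])]
print([c.squeeze(-1) for c in gt_class])   # [tensor([5, 2]), tensor([7])]
```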

.\yolov8\ultralytics\models\rtdetr\val.py

import torch  # Import the PyTorch library

from ultralytics.data import YOLODataset  # Import the YOLODataset class
from ultralytics.data.augment import Compose, Format, v8_transforms  # Import augmentation classes and functions
from ultralytics.models.yolo.detect import DetectionValidator  # Import the detection validator class
from ultralytics.utils import colorstr, ops  # Import the colorstr helper and the ops utilities

__all__ = ("RTDETRValidator",)  # Public interface of this module, a single-element tuple

class RTDETRDataset(YOLODataset):
    """
    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.

    This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
    real-time detection and tracking tasks.
    """

    def __init__(self, *args, data=None, **kwargs):
        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
        super().__init__(*args, data=data, **kwargs)  # Call the parent YOLODataset initializer

    # NOTE: add stretch version load_image for RTDETR mosaic
    def load_image(self, i, rect_mode=False):
        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
        return super().load_image(i=i, rect_mode=rect_mode)  # Delegate to the parent YOLODataset.load_image

    def build_transforms(self, hyp=None):
        """Temporary, only for evaluation."""
        if self.augment:
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)  # 使用v8_transforms函数创建变换列表
        else:
            # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
            transforms = Compose([])  # 如果不进行数据增强,则使用空的变换列表
        transforms.append(
            Format(
                bbox_format="xywh",  # 边界框格式设置为(x, y, width, height)
                normalize=True,  # 归一化图像像素值
                return_mask=self.use_segments,  # 根据use_segments参数返回掩膜
                return_keypoint=self.use_keypoints,  # 根据use_keypoints参数返回关键点
                batch_idx=True,  # 返回带有批次索引的数据
                mask_ratio=hyp.mask_ratio,  # 掩膜比率
                mask_overlap=hyp.overlap_mask,  # 掩膜重叠
            )
        )
        return transforms  # 返回最终的数据变换列表

class RTDETRValidator(DetectionValidator):
    """
    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
    the RT-DETR (Real-Time DETR) object detection model.

    The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
    post-processing, and updates evaluation metrics accordingly.

    Example:
        ```python
        from ultralytics.models.rtdetr import RTDETRValidator

        args = dict(model='rtdetr-l.pt', data='coco8.yaml')
        validator = RTDETRValidator(args=args)
        validator()
        ```

    Note:
        For further details on the attributes and methods, refer to the parent DetectionValidator class.
    """
    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build an RTDETR Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        # Construct an RTDETRDataset object for this dataset split
        return RTDETRDataset(
            img_path=img_path,  # path to the image folder
            imgsz=self.args.imgsz,  # image size
            batch_size=batch,  # batch size, used for `rect`
            augment=False,  # no augmentation
            hyp=self.args,  # model hyperparameters
            rect=False,  # no rect operation
            cache=self.args.cache or None,  # cache setting, None if not set
            prefix=colorstr(f"{mode}: "),  # log prefix based on the mode
            data=self.data,  # data object
        )

    def postprocess(self, preds):
        """Apply Non-maximum suppression to prediction outputs."""
        if not isinstance(preds, (list, tuple)):  # if preds is not a list or tuple
            preds = [preds, None]  # wrap preds in a list

        bs, _, nd = preds[0].shape  # shape of the prediction tensor
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)  # split the predictions into boxes and class scores
        bboxes *= self.args.imgsz  # scale the box coordinates by the image size
        outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs  # initialize the output list

        for i, bbox in enumerate(bboxes):  # iterate over each image's boxes
            bbox = ops.xywh2xyxy(bbox)  # convert boxes from (x, y, w, h) to (x1, y1, x2, y2)
            score, cls = scores[i].max(-1)  # maximum score and the corresponding class per box
            # Do not need threshold for evaluation as only got 300 boxes here
            pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # combine boxes, scores and classes
            # Sort by confidence to correctly get internal metrics
            pred = pred[score.argsort(descending=True)]  # sort by score in descending order
            outputs[i] = pred  # store the sorted predictions

        return outputs  # return the processed prediction list

    def _prepare_batch(self, si, batch):
        """Prepares a batch for training or inference by applying transformations."""
        idx = batch["batch_idx"] == si  # indices of labels belonging to image si
        cls = batch["cls"][idx].squeeze(-1)  # class labels for this image, with the trailing dimension removed
        bbox = batch["bboxes"][idx]  # bounding boxes for this image
        ori_shape = batch["ori_shape"][si]  # original image shape
        imgsz = batch["img"].shape[2:]  # input image size
        ratio_pad = batch["ratio_pad"][si]  # ratio and padding information

        if len(cls):  # if there are labels for this image
            bbox = ops.xywh2xyxy(bbox)  # convert boxes from (x, y, w, h) to (x1, y1, x2, y2)
            bbox[..., [0, 2]] *= ori_shape[1]  # scale x coordinates by the original width
            bbox[..., [1, 3]] *= ori_shape[0]  # scale y coordinates by the original height

        return {
            "cls": cls,  # class labels
            "bbox": bbox,  # bounding boxes
            "ori_shape": ori_shape,  # original image shape
            "imgsz": imgsz,  # input image size
            "ratio_pad": ratio_pad,  # ratio and padding information
        }

    def _prepare_pred(self, pred, pbatch):
        """Prepares and returns a batch with transformed bounding boxes and class labels."""
        predn = pred.clone()  # copy the predictions
        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # rescale x coordinates to the original width
        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # rescale y coordinates to the original height
        return predn.float()  # return the rescaled predictions
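
A small numeric sketch of the two-step rescaling used above (values are illustrative): `postprocess` multiplies the normalized boxes by `imgsz`, and `_prepare_pred` then multiplies x coordinates by `ori_w / imgsz` and y coordinates by `ori_h / imgsz`, which composes to a direct scaling by the original width and height:

```python
import torch

imgsz = 640
ori_shape = (480, 768)  # (height, width) of the original image

# A single normalized box; indices [0, 2] are x-like and [1, 3] are y-like in both xywh and xyxy
bbox_norm = torch.tensor([[0.5, 0.5, 0.2, 0.4]])

# Step 1 (postprocess): scale into the square network input space
bbox_net = bbox_norm * imgsz

# Step 2 (_prepare_pred): rescale to the original image size
pred = bbox_net.clone()
pred[..., [0, 2]] *= ori_shape[1] / imgsz
pred[..., [1, 3]] *= ori_shape[0] / imgsz

# Composition check: equivalent to multiplying directly by (w, h, w, h)
direct = bbox_norm * torch.tensor([ori_shape[1], ori_shape[0], ori_shape[1], ori_shape[0]])
print(torch.allclose(pred, direct))  # True
```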

.\yolov8\ultralytics\models\rtdetr\__init__.py

# Import the related modules and classes from the current package using relative imports
from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator

# Define __all__ to specify the public interface of the package for `from package import *`
__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"

.\yolov8\ultralytics\models\sam\amg.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# Import the standard math library
import math
# Import the product function from itertools
from itertools import product
# Import typing helpers
from typing import Any, Generator, List, Tuple

# Import the third-party libraries numpy and torch
import numpy as np
import torch


def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """Return a boolean tensor indicating if boxes are near the crop edge."""
    # Convert crop_box and orig_box to torch tensors on the same device as boxes
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
    # Uncrop the boxes and cast the result to float
    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
    # Check whether boxes are near the crop edge, using absolute tolerance atol
    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    # Check whether boxes are near the original image edge, using absolute tolerance atol
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    # AND with ~near_image_edge so boxes touching the original image border are excluded
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    # Return a boolean tensor: True if any coordinate of a box is near the crop edge
    return torch.any(near_crop_edge, dim=1)


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    """Yield batches of data from the input arguments."""
    # Assert that args is non-empty and all inputs have the same length for batched iteration
    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
    # Number of batches to generate
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    # Generator: yield one slice of each input argument per batch
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
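
A quick usage sketch of `batch_iterator` with plain Python lists (any equal-length indexable sequences work the same way):

```python
# Two parallel lists of 5 elements, iterated in batches of 2
points = [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]]
labels = [1, 1, 0, 1, 0]

for batch_points, batch_labels in batch_iterator(2, points, labels):
    print(batch_points, batch_labels)
# [[0, 0], [1, 1]] [1, 1]
# [[2, 2], [3, 3]] [0, 1]
# [[4, 4]] [0]
```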


def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
    """
    Computes the stability score for a batch of masks.

    The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
    and low values.

    Notes:
        - One mask is always contained inside the other.
        - Save memory by preventing unnecessary cast to torch.int64
    """
    # Intersections and unions of the binary masks obtained at the high and low thresholds
    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    # The stability score is the intersection divided by the union
    return intersections / unions


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
    # Offset of the evenly spaced points from each side
    offset = 1 / (2 * n_per_side)
    # n_per_side evenly spaced points in [offset, 1 - offset]
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    # Tile the 1D points into a full grid
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    # Stack the x and y coordinates and reshape into the final (n*n, 2) point grid
    return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
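
A worked example of the grid produced for `n_per_side=2`: the offset is 1 / (2 * 2) = 0.25, so the points sit at 0.25 and 0.75 on each axis, in normalized [0, 1] x [0, 1] coordinates:

```python
grid = build_point_grid(2)
print(grid.shape)  # (4, 2)
print(grid)
# [[0.25 0.25]
#  [0.75 0.25]
#  [0.25 0.75]
#  [0.75 0.75]]
```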


def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
    """Generate point grids for all crop layers."""
    # Generate a point grid for every crop layer, shrinking the grid by scale_per_layer at each level
    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes.
    
    # 代码未完成,需要继续补充完整
    Each layer has (2**i)**2 boxes for the ith layer.
    """

    # Initialize empty lists for the crop boxes and their layer indices
    crop_boxes, layer_idxs = [], []
    # Height and width of the input image
    im_h, im_w = im_size
    # Shorter side of the image
    short_side = min(im_h, im_w)

    # Crop box for the original image, covering the whole image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        """Crops bounding boxes to the size of the input image."""
        # Compute the side length of each crop from the number of crops and the overlap
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    # Generate the crop boxes for every layer
    for i_layer in range(n_layers):
        # Each layer has 2**(i_layer + 1) crops per side
        n_crops_per_side = 2 ** (i_layer + 1)
        # Size of the overlap region
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        # Width and height of each crop box
        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        # Top-left coordinates of each crop box
        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYWH format
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            # Build the crop box from the top-left corner and the crop width/height, clipped to the image
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            # Append the crop box
            crop_boxes.append(box)
            # Record the layer this crop box belongs to
            layer_idxs.append(i_layer + 1)

    # Return the list of crop boxes and the list of layer indices
    return crop_boxes, layer_idxs
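
A small worked example of the returned structure, assuming an image of height 768 and width 1024 with one extra crop layer and 25% overlap:

```python
crop_boxes, layer_idxs = generate_crop_boxes((768, 1024), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes))   # 5 -> the full image plus (2**1)**2 = 4 crops for layer 1
print(layer_idxs)        # [0, 1, 1, 1, 1]
print(crop_boxes[0])     # [0, 0, 1024, 768] -> the crop box covering the whole image
```
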
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Uncrop bounding boxes by adding the crop box offset."""
    # Extract the top-left corner coordinates of the crop box
    x0, y0, _, _ = crop_box
    # Create an offset tensor based on the crop box coordinates
    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # Check if the boxes tensor has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    # Add the offset to the boxes tensor to uncrop them
    return boxes + offset


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    """Uncrop points by adding the crop box offset."""
    # Extract the top-left corner coordinates of the crop box
    x0, y0, _, _ = crop_box
    # Create an offset tensor based on the crop box coordinates
    offset = torch.tensor([[x0, y0]], device=points.device)
    # Check if the points tensor has a channel dimension
    if len(points.shape) == 3:
        offset = offset.unsqueeze(1)
    # Add the offset to the points tensor to uncrop them
    return points + offset


def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
    """Uncrop masks by padding them to the original image size."""
    # Extract the crop box coordinates
    x0, y0, x1, y1 = crop_box
    # Check if the crop box covers the entire original image
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Calculate the padding required to restore the masks to original size
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    pad = (x0, pad_x - x0, y0, pad_y - y0)
    # Pad the masks tensor to the original size with zeros
    return torch.nn.functional.pad(masks, pad, value=0)


def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
    """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
    import cv2  # type: ignore

    # Ensure the mode is valid
    assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
    # Determine whether to correct holes or islands based on mode
    correct_holes = mode == "holes"
    # Convert mask to binary and invert if correcting holes
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    # Perform connected component analysis to find regions
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    # Extract region sizes
    sizes = stats[:, -1][1:]  # Row 0 is background label
    # Identify small regions based on area threshold
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    # If no small regions found, return original mask
    if not small_regions:
        return mask, False
    # Create list of labels to fill (small regions)
    fill_labels = [0] + small_regions
    # If not correcting holes, keep only the largest region if all are below threshold
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
    # Generate mask with only specified fill labels
    mask = np.isin(regions, fill_labels)
    return mask, True
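
A usage sketch (assuming OpenCV is installed, since the function imports `cv2` lazily) showing both modes on tiny synthetic masks:

```python
import numpy as np

# "holes" mode: a filled square with a 1-pixel hole in the middle
mask = np.zeros((10, 10), dtype=bool)
mask[2:8, 2:8] = True
mask[5, 5] = False
filled, changed = remove_small_regions(mask, area_thresh=4, mode="holes")
print(changed, filled[5, 5])  # True True -> the small hole was filled

# "islands" mode: a 1-pixel island next to a large region
mask2 = np.zeros((10, 10), dtype=bool)
mask2[0, 0] = True        # tiny isolated island
mask2[2:8, 2:8] = True    # large 36-pixel region
cleaned, changed = remove_small_regions(mask2, area_thresh=4, mode="islands")
print(changed, cleaned[0, 0], cleaned[4, 4])  # True False True -> the island was removed
```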


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculates boxes in XYXY format around masks.

    Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # Return zeros if masks tensor is empty
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize masks to shape CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
    # Get top and bottom edges
    # Row occupancy of the masks
    in_height, _ = torch.max(masks, dim=-1)
    # Row coordinates of occupied rows
    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
    # Bottom edges are the maximum occupied row coordinate
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    # Shift coordinates of empty rows past the image height so they do not win the min
    in_height_coords = in_height_coords + h * (~in_height)
    # Top edges are the minimum of the adjusted row coordinates
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    # Column occupancy of the masks
    in_width, _ = torch.max(masks, dim=-2)
    # Column coordinates of occupied columns
    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
    # Right edges are the maximum occupied column coordinate
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    # Shift coordinates of empty columns past the image width so they do not win the min
    in_width_coords = in_width_coords + w * (~in_width)
    # Left edges are the minimum of the adjusted column coordinates
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty, the right edge will be left of the left edge, or the bottom edge above the top edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    # Stack the left, top, right, bottom edge coordinates into boxes
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    # Zero out the boxes flagged as empty
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to the original shape
    return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
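
A tiny worked example of the XYXY boxes produced (coordinates are inclusive pixel indices, and empty masks map to [0, 0, 0, 0]):

```python
import torch

# Two 6x6 masks: one with a filled rectangle (rows 1-3, columns 2-4), one empty
masks = torch.zeros(2, 6, 6, dtype=torch.bool)
masks[0, 1:4, 2:5] = True

boxes = batched_mask_to_box(masks)
print(boxes)
# tensor([[2, 1, 4, 3],
#         [0, 0, 0, 0]])
```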

.\yolov8\ultralytics\models\sam\build.py

# Import the partial function from the functools module
from functools import partial

# Import the torch library
import torch

# Import the attempt_download_asset download helper
from ultralytics.utils.downloads import attempt_download_asset

# Import the classes used to assemble the model
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer


def build_sam_vit_h(checkpoint=None):
    """Builds and returns a SAM h-size model."""
    return _build_sam(
        encoder_embed_dim=1280,  # encoder embedding dimension
        encoder_depth=32,  # encoder depth
        encoder_num_heads=16,  # number of encoder attention heads
        encoder_global_attn_indexes=[7, 15, 23, 31],  # global attention indexes
        checkpoint=checkpoint,  # checkpoint
    )


def build_sam_vit_l(checkpoint=None):
    """Builds and returns a SAM l-size model."""
    return _build_sam(
        encoder_embed_dim=1024,  # encoder embedding dimension
        encoder_depth=24,  # encoder depth
        encoder_num_heads=16,  # number of encoder attention heads
        encoder_global_attn_indexes=[5, 11, 17, 23],  # global attention indexes
        checkpoint=checkpoint,  # checkpoint
    )


def build_sam_vit_b(checkpoint=None):
    """Builds and returns a SAM b-size model."""
    return _build_sam(
        encoder_embed_dim=768,  # encoder embedding dimension
        encoder_depth=12,  # encoder depth
        encoder_num_heads=12,  # number of encoder attention heads
        encoder_global_attn_indexes=[2, 5, 8, 11],  # global attention indexes
        checkpoint=checkpoint,  # checkpoint
    )


def build_mobile_sam(checkpoint=None):
    """Builds and returns a Mobile-SAM model."""
    return _build_sam(
        encoder_embed_dim=[64, 128, 160, 320],  # list of encoder embedding dimensions
        encoder_depth=[2, 2, 6, 2],  # list of encoder depths
        encoder_num_heads=[2, 4, 5, 10],  # list of encoder head counts
        encoder_global_attn_indexes=None,  # global attention indexes
        mobile_sam=True,  # whether this is Mobile-SAM
        checkpoint=checkpoint,  # checkpoint
    )


def _build_sam(
    encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
):
    """构建选定的 SAM 模型架构。"""
    prompt_embed_dim = 256  # 提示嵌入维度
    image_size = 1024  # 图像尺寸
    vit_patch_size = 16  # ViT 补丁大小
    image_embedding_size = image_size // vit_patch_size  # 图像嵌入大小
    # 创建图像编码器对象,根据条件选择不同的实现方式:TinyViT 或 ImageEncoderViT
    image_encoder = (
        TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,
            embed_dims=encoder_embed_dim,
            depths=encoder_depth,
            num_heads=encoder_num_heads,
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        )
        if mobile_sam  # if mobile_sam is truthy, use TinyViT
        else ImageEncoderViT(  # otherwise use ImageEncoderViT
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        )
    )
    
    # Create the SAM model object from the image encoder, prompt encoder and mask decoder
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],  # pixel means used for input normalization
        pixel_std=[58.395, 57.12, 57.375],  # pixel standard deviations used for input normalization
    )
    
    # If a checkpoint path was provided, load the model state dict
    if checkpoint is not None:
        checkpoint = attempt_download_asset(checkpoint)  # try to download the checkpoint file
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)  # load the checkpoint's state dict
        sam.load_state_dict(state_dict)  # apply the loaded state dict to the SAM model

    sam.eval()  # set the SAM model to evaluation mode

    # Return the configured SAM model
    return sam
# SAM model map: maps checkpoint filenames to their builder functions
sam_model_map = {
    "sam_h.pt": build_sam_vit_h,  # filenames ending in "sam_h.pt" use the build_sam_vit_h builder
    "sam_l.pt": build_sam_vit_l,  # filenames ending in "sam_l.pt" use the build_sam_vit_l builder
    "sam_b.pt": build_sam_vit_b,  # filenames ending in "sam_b.pt" use the build_sam_vit_b builder
    "mobile_sam.pt": build_mobile_sam,  # "mobile_sam.pt" uses the build_mobile_sam builder
}

# Build a SAM model, selecting the appropriate builder based on the given checkpoint (ckpt)
def build_sam(ckpt="sam_b.pt"):
    """Build a SAM model specified by ckpt."""
    model_builder = None
    ckpt = str(ckpt)  # to allow Path checkpoint types

    # Iterate over every key (model filename) in the SAM model map
    for k in sam_model_map.keys():
        # If the given checkpoint (ckpt) ends with the current model filename (k), select its builder
        if ckpt.endswith(k):
            model_builder = sam_model_map.get(k)

    # If no matching builder was found, raise a FileNotFoundError
    if not model_builder:
        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")

    # Build the model with the selected builder and return it
    return model_builder(ckpt)
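
A usage sketch for the dispatch above (running it downloads the checkpoint via `attempt_download_asset` if it is not present locally):

```python
# The filename suffix selects the builder: "mobile_sam.pt" -> build_mobile_sam
sam_model = build_sam("mobile_sam.pt")
print(type(sam_model).__name__)  # Sam

# A filename that matches no known suffix raises FileNotFoundError
try:
    build_sam("not_a_sam.pt")
except FileNotFoundError as e:
    print(e)
```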

.\yolov8\ultralytics\models\sam\model.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
SAM model interface.

This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image
segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
image distributions and tasks without prior knowledge.

Key Features:
    - Promptable segmentation
    - Real-time performance
    - Zero-shot transfer capabilities
    - Trained on SA-1B dataset
"""

from pathlib import Path

from ultralytics.engine.model import Model  # Import the Ultralytics Model base class
from ultralytics.utils.torch_utils import model_info  # Import the model-info utility function

from .build import build_sam  # Import the SAM model build function
from .predict import Predictor  # Import the Predictor class


class SAM(Model):
    """
    SAM (Segment Anything Model) interface class.

    SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
    bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
    dataset.
    """

    def __init__(self, model="sam_b.pt") -> None:
        """
        Initializes the SAM model with a pre-trained model file.

        Args:
            model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.

        Raises:
            NotImplementedError: If the model file extension is not .pt or .pth.
        """
        # Check that the model file ends with .pt or .pth; raise otherwise
        if model and Path(model).suffix not in {".pt", ".pth"}:
            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
        # Call the parent constructor to initialize the model with the 'segment' task
        super().__init__(model=model, task="segment")

    def _load(self, weights: str, task=None):
        """
        Loads the specified weights into the SAM model.

        Args:
            weights (str): Path to the weights file.
            task (str, optional): Task name. Defaults to None.
        """
        # Build the SAM model from the given weights file
        self.model = build_sam(weights)
    def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Performs segmentation prediction on the given image or video source.

        Args:
            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
            points (list, optional): List of points for prompted segmentation. Defaults to None.
            labels (list, optional): List of labels for prompted segmentation. Defaults to None.

        Returns:
            (list): The model predictions.
        """
        # Set the default argument overrides
        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
        kwargs.update(overrides)
        # Assemble the prompts dictionary
        prompts = dict(bboxes=bboxes, points=points, labels=labels)
        # Call the parent predict method with the prompts
        return super().predict(source, stream, prompts=prompts, **kwargs)

    def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
        """
        Alias for the 'predict' method.

        Args:
            source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
            stream (bool, optional): If True, enables real-time streaming. Defaults to False.
            bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
            points (list, optional): List of points for prompted segmentation. Defaults to None.
            labels (list, optional): List of labels for prompted segmentation. Defaults to None.

        Returns:
            (list): The model predictions.
        """
        # Delegate to the 'predict' method
        return self.predict(source, stream, bboxes, points, labels, **kwargs)

    def info(self, detailed=False, verbose=True):
        """
        Logs information about the SAM model.

        Args:
            detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
            verbose (bool, optional): If True, displays information on the console. Defaults to True.

        Returns:
            (tuple): A tuple containing the model's information.
        """
        # Use the model_info helper to gather the model information
        return model_info(self.model, detailed=detailed, verbose=verbose)

    @property
    def task_map(self):
        """
        Provides a mapping from the 'segment' task to its corresponding 'Predictor'.

        Returns:
            (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
        """
        # Return the mapping from the 'segment' task to the Predictor class
        return {"segment": {"predictor": Predictor}}

.\yolov8\ultralytics\models\sam\modules\decoders.py

# Import the required modules and classes
# Use type hints for List, Tuple and Type
from typing import List, Tuple, Type

# Import the PyTorch library
import torch
# Import the nn module from PyTorch
from torch import nn
# Import PyTorch's functional module, abbreviated as F
from torch.nn import functional as F

# Import the LayerNorm2d class from ultralytics.nn.modules
from ultralytics.nn.modules import LayerNorm2d

# Definition of the MaskDecoder class, an nn.Module
class MaskDecoder(nn.Module):
    """
    Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
    masks given image and prompt embeddings.

    Attributes:
        transformer_dim (int): Channel dimension for the transformer module.
        transformer (nn.Module): The transformer module used for mask prediction.
        num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
        iou_token (nn.Embedding): Embedding for the IoU token.
        num_mask_tokens (int): Number of mask tokens.
        mask_tokens (nn.Embedding): Embedding for the mask tokens.
        output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
        output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
        iou_prediction_head (nn.Module): MLP for predicting mask quality.
    """

    # Initializer that accepts several configuration arguments
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a transformer architecture.

        Args:
            transformer_dim (int): the channel dimension of the transformer module
            transformer (nn.Module): the transformer used to predict masks
            num_multimask_outputs (int): the number of masks to predict when disambiguating masks
            activation (nn.Module): the type of activation to use when upscaling masks
            iou_head_depth (int): the depth of the MLP used to predict mask quality
            iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer  # store the provided transformer module

        self.num_multimask_outputs = num_multimask_outputs  # number of multimask outputs

        self.iou_token = nn.Embedding(1, transformer_dim)  # embedding of size 1 x transformer_dim for the IoU token
        self.num_mask_tokens = num_multimask_outputs + 1  # total number of mask tokens
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)  # embedding of size num_mask_tokens x transformer_dim

        self.output_upscaling = nn.Sequential(  # upscaling network for the output
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),  # upsample from transformer_dim to transformer_dim // 4 channels
            LayerNorm2d(transformer_dim // 4),  # layer normalization over the previous output
            activation(),  # activation built from the provided activation class
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),  # further upsample to transformer_dim // 8 channels
            activation(),  # apply the activation again
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
        )  # a ModuleList containing num_mask_tokens MLPs (multi-layer perceptrons)

        self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)  # MLP that predicts mask IoU (Intersection over Union)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (torch.Tensor): the embeddings from the image encoder
            image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
            dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
            multimask_output (bool): Whether to return multiple masks or a single mask.

        Returns:
            torch.Tensor: batched predicted masks
            torch.Tensor: batched predictions of mask quality
        """
        # Predict masks using the provided embeddings and positional encoding
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output based on multimask_output flag
        mask_slice = slice(1, None) if multimask_output else slice(0, 1)
        # Slice the masks tensor to include only the desired masks
        masks = masks[:, mask_slice, :, :]
        # Slice the iou_pred tensor to include only the corresponding predictions
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output for batched masks and their quality predictions
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predicts masks.

        See 'forward' for more details.
        """
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = [
            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
        ]
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
# MLP (Multi-Layer Perceptron) model class
class MLP(nn.Module):
    """
    MLP (Multi-Layer Perceptron) model lightly adapted from
    https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        """
        初始化 MLP(多层感知机)模型。

        Args:
            input_dim (int): 输入特征的维度。
            hidden_dim (int): 隐藏层的维度。
            output_dim (int): 输出层的维度。
            num_layers (int): 隐藏层的数量。
            sigmoid_output (bool, optional): 是否对输出层应用 sigmoid 激活函数,默认为 False。
        """
        super().__init__()
        self.num_layers = num_layers
        # 构建隐藏层的结构,使用 ModuleList 存储多个线性层
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        """Executes the forward pass of the module and applies the activation functions."""
        # Apply each hidden layer's linear transform followed by ReLU; the final layer has no ReLU
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        # If sigmoid_output is set, apply a sigmoid activation to the final output
        if self.sigmoid_output:
            x = torch.sigmoid(x)
        return x
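
A quick sketch of the MLP defined above, configured the same way as the hypernetwork heads in `MaskDecoder` (`transformer_dim=256`, output `transformer_dim // 8 = 32`, three layers):

```python
import torch

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)

x = torch.randn(4, 256)   # a batch of 4 mask-token embeddings
out = mlp(x)
print(out.shape)          # torch.Size([4, 32])
print(len(mlp.layers))    # 3 linear layers: 256->256, 256->256, 256->32
```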

.\yolov8\ultralytics\models\sam\modules\encoders.py

# Import the required modules and classes, including common data types and model definitions
from typing import Any, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Import custom modules
from ultralytics.nn.modules import LayerNorm2d, MLPBlock

# Image encoder class based on the Vision Transformer (ViT) architecture
class ImageEncoderViT(nn.Module):
    """
    An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
    encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
    The encoded patches are then processed through a neck to generate the final encoded representation.

    This class and its supporting functions below lightly adapted from the ViTDet backbone available at
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.

    Attributes:
        img_size (int): Dimension of input images, assumed to be square.
        patch_embed (PatchEmbed): Module for patch embedding.
        pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
        blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
        neck (nn.Sequential): Neck module to further process the output.
    """

    def __init__(
        self,
        img_size: int = 1024,  # input image size, default 1024x1024 pixels
        patch_size: int = 16,  # size of each patch, default 16x16 pixels
        in_chans: int = 3,  # number of input image channels, default RGB (3)
        embed_dim: int = 768,  # embedding dimension of each patch, default 768
        depth: int = 12,  # depth (number of transformer blocks), default 12
        num_heads: int = 12,  # number of attention heads, default 12
        mlp_ratio: float = 4.0,  # expansion ratio of the MLP hidden dimension, default 4.0
        out_chans: int = 256,  # number of output channels, default 256
        qkv_bias: bool = True,  # whether to add a learnable bias to query/key/value, default True
        norm_layer: Type[nn.Module] = nn.LayerNorm,  # normalization layer type, default LayerNorm
        act_layer: Type[nn.Module] = nn.GELU,  # activation layer type, default GELU
        use_abs_pos: bool = True,  # whether to use absolute positional embeddings, default True
        use_rel_pos: bool = False,  # whether to use relative positional embeddings, default False
        rel_pos_zero_init: bool = True,  # whether to zero-initialize relative positional parameters, default True
        window_size: int = 0,  # window size for local attention; 0 means global attention
        global_attn_indexes: Tuple[int, ...] = (),  # indexes of blocks using global attention, default empty
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size  # store the input image size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),  # patch size
            stride=(patch_size, patch_size),  # patch stride
            in_chans=in_chans,  # number of input image channels
            embed_dim=embed_dim,  # patch embedding dimension
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # With absolute positional encoding, initialize the embedding at pretrain image size divided by patch size
            self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,  # block dimension
                num_heads=num_heads,  # number of attention heads in the block
                mlp_ratio=mlp_ratio,  # ratio of MLP hidden dimension to embedding dimension
                qkv_bias=qkv_bias,  # whether to add a learnable bias to query/key/value
                norm_layer=norm_layer,  # normalization layer
                act_layer=act_layer,  # activation layer
                use_rel_pos=use_rel_pos,  # whether to use relative positional encoding
                rel_pos_zero_init=rel_pos_zero_init,  # whether to zero-initialize relative positional parameters
                window_size=window_size if i not in global_attn_indexes else 0,  # window size for windowed-attention blocks
                input_size=(img_size // patch_size, img_size // patch_size),  # input size of the block
            )
            self.blocks.append(block)  # append the block to the module list

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,  # input channels: the embedding dimension
                out_chans,  # output channels
                kernel_size=1,  # 1x1 convolution
                bias=False,  # no bias
            ),
            LayerNorm2d(out_chans),  # layer normalization over the output channels
            nn.Conv2d(
                out_chans,  # input channels: output of the previous layer
                out_chans,  # output channels
                kernel_size=3,  # 3x3 convolution
                padding=1,  # padding of 1
                bias=False,  # no bias
            ),
            LayerNorm2d(out_chans),  # layer normalization over the output channels
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
        and neck.
        """
        # Apply patch embedding to the input x, producing a (B, H', W', C) tensor
        x = self.patch_embed(x)

        # Add the positional embedding if it exists
        if self.pos_embed is not None:
            x = x + self.pos_embed

        # Apply each transformer block in turn
        for blk in self.blocks:
            x = blk(x)

        # Permute to channels-first (B, C, H', W') before passing through the neck
        return self.neck(x.permute(0, 3, 1, 2))
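
For orientation, a small sketch of the tensor shapes flowing through `forward` with the default ViT-B style settings above (batch size 1; the numbers are derived from the constants, not from running the model):

```python
# img_size=1024, patch_size=16, embed_dim=768, out_chans=256
img_size, patch_size, embed_dim, out_chans = 1024, 16, 768, 256

tokens_per_side = img_size // patch_size                    # 64
print((1, tokens_per_side, tokens_per_side, embed_dim))     # after patch_embed: (1, 64, 64, 768)
print((1, tokens_per_side, tokens_per_side, embed_dim))     # unchanged through the transformer blocks
print((1, out_chans, tokens_per_side, tokens_per_side))     # after permute + neck: (1, 256, 64, 64)
```
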
class PromptEncoder(nn.Module):
    """
    Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
    produces both sparse and dense embeddings for the input prompts.

    Attributes:
        embed_dim (int): Dimension of the embeddings.
        input_image_size (Tuple[int, int]): Size of the input image as (H, W).
        image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
        pe_layer (PositionEmbeddingRandom): Module for random position embedding.
        num_point_embeddings (int): Number of point embeddings for different types of points.
        point_embeddings (nn.ModuleList): List of point embeddings.
        not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
        mask_input_size (Tuple[int, int]): Size of the input mask.
        mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
        no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
    """

    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Args:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim  # embedding dimension
        self.input_image_size = input_image_size  # padded input image size (H, W)
        self.image_embedding_size = image_embedding_size  # spatial size of the image embedding (H, W)
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)  # random positional-encoding module

        self.num_point_embeddings: int = 4  # positive/negative points plus the two box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)  # list of point embeddings
        self.not_a_point_embed = nn.Embedding(1, embed_dim)  # embedding for points that carry no label

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])  # expected input mask size
        self.mask_downscaling = nn.Sequential(  # network that downscales the input mask
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),  # first conv: 1 -> mask_in_chans // 4 channels
            LayerNorm2d(mask_in_chans // 4),  # channel-wise layer normalization
            activation(),  # activation
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),  # second conv: expand to mask_in_chans channels
            LayerNorm2d(mask_in_chans),  # channel-wise layer normalization
            activation(),  # activation
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),  # final conv: project to the embedding dimension
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)  # embedding used when no mask is provided

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
        image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        # Ask the positional-encoding layer for an encoding of the image-embedding grid and add a batch dimension
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """Embeds point prompts."""
        # Shift point coordinates by 0.5 so they refer to pixel centers
        points = points + 0.5  # Shift to center of pixel
        if pad:
            # When padding is requested, append a zero point with label -1
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            # Concatenate the padding point and label along the prompt dimension
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        # Positionally encode the coordinates relative to the input image size
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        # Zero out embeddings of padding points (label -1)
        point_embedding[labels == -1] = 0.0
        # Add the not-a-point embedding to padding points
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        # Add the negative-point embedding to points with label 0
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        # Add the positive-point embedding to points with label 1
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        # Shift box coordinates by 0.5 so they refer to pixel centers
        boxes = boxes + 0.5  # Shift to center of pixel
        # Reshape the boxes into (-1, 2, 2): two corner points per box
        coords = boxes.reshape(-1, 2, 2)
        # Positionally encode the corner coordinates relative to the input image size
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        # Add the top-left corner embedding to the first corner of each box
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        # Add the bottom-right corner embedding to the second corner of each box
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        # Embed the input masks with the mask_downscaling network
        return self.mask_downscaling(masks)

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """Gets the batch size of the output given the batch size of the input prompts."""
        if points is not None:
            # If point prompts are given, use their batch size
            return points[0].shape[0]
        elif boxes is not None:
            # Otherwise use the batch size of the boxes
            return boxes.shape[0]
        elif masks is not None:
            # Otherwise use the batch size of the masks
            return masks.shape[0]
        else:
            # With no prompts at all, default to a batch size of 1
            return 1

    def _get_device(self) -> torch.device:
        """Returns the device of the first point embedding's weight tensor."""
        # Device of the first point embedding's weight tensor
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
          points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
              如果不为 None,则包含点的坐标和标签的元组
          boxes (torch.Tensor, None): boxes to embed
              如果不为 None,则包含要嵌入的框的张量
          masks (torch.Tensor, None): masks to embed
              如果不为 None,则包含要嵌入的掩码的张量

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
            by the number of input points and boxes.
              稀疏嵌入(sparse embeddings)用于点和框,形状为 BxNx(embed_dim),其中 N 取决于输入点和框的数量。
          torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
              密集嵌入(dense embeddings)用于掩码,形状为 Bx(embed_dim)x(embed_H)x(embed_W),其中 B 是批大小,embed_dim 是嵌入维度,embed_H 和 embed_W 是图像嵌入的高度和宽度。
        """
        # Determine the output batch size from whichever prompts were provided
        bs = self._get_batch_size(points, boxes, masks)
        # Start with an empty sparse-embedding tensor of shape (bs, 0, embed_dim) on the right device
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())

        # Embed point prompts if present
        if points is not None:
            coords, labels = points
            # Pad with a dummy point when no boxes accompany the points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            # Append the point embeddings along the prompt dimension
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

        # Embed box prompts if present
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            # Append the box embeddings along the prompt dimension
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        # Embed mask prompts if present, otherwise fall back to the learned no-mask embedding
        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # Broadcast the no-mask embedding to the spatial size of the image embedding
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        # Return the sparse and dense embeddings
        return sparse_embeddings, dense_embeddings
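

A standalone usage sketch (not from the original file) showing the sparse and dense embedding shapes for point, box, and mask prompts; the constructor arguments mirror the parameters documented above:

import torch

# Hypothetical sketch of the three prompt paths.
pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024), mask_in_chans=16)
pts = (torch.tensor([[[500.0, 375.0]]]), torch.ones(1, 1))  # one positive point (label 1)
box = torch.tensor([[100.0, 100.0, 400.0, 300.0]])          # one box in xyxy pixel coordinates
sparse, dense = pe(points=pts, boxes=box, masks=None)
print(sparse.shape, dense.shape)  # torch.Size([1, 3, 256]) torch.Size([1, 256, 64, 64])

mask = torch.randn(1, 1, 256, 256)  # mask_input_size = (4 * 64, 4 * 64)
_, dense_from_mask = pe(points=None, boxes=None, masks=mask)
print(dense_from_mask.shape)  # torch.Size([1, 256, 64, 64])
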
class PositionEmbeddingRandom(nn.Module):
    """Positional encoding using random spatial frequencies."""

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        """Initializes a position embedding using random spatial frequencies."""
        super().__init__()
        # Fall back to a scale of 1.0 if none is given or it is non-positive
        if scale is None or scale <= 0.0:
            scale = 1.0
        # Register a random Gaussian matrix, scaled by the requested factor
        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))

        # Allow non-deterministic algorithms; otherwise the forward pass can raise
        # 'cumsum_cuda_kernel does not have a deterministic implementation'
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # Assumes coords lie in [0, 1]^2 with shape d_1 x ... x d_n x 2
        coords = 2 * coords - 1  # map to [-1, 1]^2
        coords = coords @ self.positional_encoding_gaussian_matrix  # project with the random Gaussian matrix
        coords = 2 * np.pi * coords  # scale to the frequency domain
        # Output shape: d_1 x ... x d_n x C
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)  # concatenate sine and cosine features

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)  # all-ones grid of shape h x w
        y_embed = grid.cumsum(dim=0) - 0.5  # cumulative sum down the rows, shifted to pixel centers
        x_embed = grid.cumsum(dim=1) - 0.5  # cumulative sum across the columns, shifted to pixel centers
        y_embed = y_embed / h  # normalize the vertical coordinates to [0, 1]
        x_embed = x_embed / w  # normalize the horizontal coordinates to [0, 1]

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))  # encode the stacked (x, y) grid
        return pe.permute(2, 0, 1)  # return a C x H x W encoding

    def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]  # normalize x coordinates to [0, 1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]  # normalize y coordinates to [0, 1]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C positional encoding
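
A small standalone sketch (not part of the original file) of both entry points; the output channel count is 2 * num_pos_feats because sine and cosine features are concatenated:

import torch

# Hypothetical sketch with num_pos_feats=4, so C = 8.
pe_layer = PositionEmbeddingRandom(num_pos_feats=4)
print(pe_layer((2, 3)).shape)  # torch.Size([8, 2, 3]): C x H x W grid encoding

coords = torch.tensor([[[100.0, 200.0], [300.0, 50.0]]])  # (B=1, N=2, 2) pixel coordinates
print(pe_layer.forward_with_coords(coords, image_size=(480, 640)).shape)  # torch.Size([1, 2, 8])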


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int), None): Input resolution for calculating the relative
                positional parameter size.
        """
        # Initialize the transformer block with parameters and layers
        super().__init__()
        # Layer normalization for the input data
        self.norm1 = norm_layer(dim)
        # Attention mechanism initialization
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )
        # Layer normalization after the attention mechanism
        self.norm2 = norm_layer(dim)
        # Multi-layer perceptron (MLP) block initialization
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
        # Store the window size parameter
        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
        # Store the input tensor for residual connection
        shortcut = x
        # Layer normalization on the input tensor
        x = self.norm1(x)
        # Perform window partitioning if the window size is greater than 0
        if self.window_size > 0:
            # Retrieve dimensions H (height) and W (width) from the input tensor
            H, W = x.shape[1], x.shape[2]
            # Partition the input tensor into windows and calculate padding
            x, pad_hw = window_partition(x, self.window_size)

        # Apply attention mechanism on the input tensor
        x = self.attn(x)

        # Reverse window partitioning if the window size is greater than 0
        if self.window_size > 0:
            # Unpartition the tensor, using stored parameters
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        # Add the shortcut connection to the transformed tensor
        x = shortcut + x
        # Apply layer normalization, MLP block, and return the transformed tensor
        return x + self.mlp(self.norm2(x))
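

A standalone sketch (not from the original file) of a single block in global-attention mode (window_size=0), which avoids the window partition/unpartition path; it assumes MLPBlock and the Attention class defined just below are available, as they are in the full module:

import torch

# Hypothetical sketch: tokens laid out as (B, H, W, C).
blk = Block(dim=32, num_heads=4, window_size=0, input_size=(8, 8))
tokens = torch.randn(2, 8, 8, 32)
print(blk(tokens).shape)  # torch.Size([2, 8, 8, 32]); the shape is preserved
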
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Initialize Attention module.

        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int), None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # Linear transformation for queries, keys, and values
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Linear transformation for projecting output
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert input_size is not None, "Input size must be provided if using relative positional encoding."
            # Initialize relative positional embeddings for attention mechanism
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
        B, H, W, _ = x.shape
        # Linear transformation for queries, keys, and values with reshaping and permutation
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Separate queries, keys, and values
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        # Compute attention scores
        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            # Incorporate relative positional embeddings into attention scores
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        # Apply softmax to compute attention weights
        attn = attn.softmax(dim=-1)
        # Compute weighted sum of values based on attention weights
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        # Project output back to original dimension
        return self.proj(x)
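

A standalone sketch (not from the original file) of the attention block with relative position embeddings enabled; it assumes get_rel_pos and add_decomposed_rel_pos, defined later in this file, are available. Input and output are both (B, H, W, C):

import torch

# Hypothetical sketch on a 4x4 token grid.
attn_layer = Attention(dim=64, num_heads=8, use_rel_pos=True, input_size=(4, 4))
x = torch.randn(1, 4, 4, 64)
print(attn_layer(x).shape)  # torch.Size([1, 4, 4, 64])
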


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Calculate padding required to make dimensions divisible by window_size
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size

    if pad_h > 0 or pad_w > 0:
        # Pad the input tensor along height and width dimensions
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    # Height and width after padding
    Hp, Wp = H + pad_h, W + pad_w

    # Reshape into a grid of non-overlapping window_size x window_size windows
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)

    # Permute so each window becomes a contiguous (window_size, window_size, C) block, then flatten the window axes
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    # Return the windows together with the padded height and width
    return windows, (Hp, Wp)
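

Block.forward above also calls window_unpartition, which does not appear in this excerpt. A sketch consistent with window_partition (and with the reference SAM implementation) is:

def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """Reverse window_partition: merge windows back into a (B, H, W, C) map and strip any padding."""
    Hp, Wp = pad_hw
    H, W = hw
    # Recover the batch size from the number of windows per padded image
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    # Rearrange the windows back into the padded spatial grid
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    # Remove the padding added by window_partition, if any
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
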
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Extract relative positional embeddings for the given query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        torch.Tensor: positional embeddings selected according to the relative positions, with shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
    https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        torch.Tensor: Updated attention map with decomposed relative positional embeddings.
    """
    # Unpack the query and key spatial sizes
    q_h, q_w = q_size
    k_h, k_w = k_size

    # Look up the relative positional embeddings for each axis
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    # Batch size (times heads) and channel dimension of the query
    B, _, dim = q.shape

    # Reshape the query to (B, q_h, q_w, dim) for the einsum below
    r_q = q.reshape(B, q_h, q_w, dim)

    # Compute the height-axis and width-axis relative position terms via einsum
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    # Add both terms to the attention map and flatten back to (B, q_h*q_w, k_h*k_w)
    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
        B, q_h * q_w, k_h * k_w
    )

    # Return the attention map with the relative positional terms added
    return attn
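

A standalone shape sketch (not from the original file) for a 4x4 query/key grid with head dimension 16:

import torch

# Hypothetical sketch of the tensors fed to add_decomposed_rel_pos.
rel_h = torch.zeros(2 * 4 - 1, 16)  # (2*H - 1, head_dim)
rel_w = torch.zeros(2 * 4 - 1, 16)  # (2*W - 1, head_dim)
q = torch.randn(2, 16, 16)          # (B * num_heads, H*W, head_dim)
attn = torch.randn(2, 16, 16)       # (B * num_heads, H*W, H*W)
print(add_decomposed_rel_pos(attn, q, rel_h, rel_w, (4, 4), (4, 4)).shape)  # torch.Size([2, 16, 16])
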
class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),  # projection kernel size
        stride: Tuple[int, int] = (16, 16),  # projection stride
        padding: Tuple[int, int] = (0, 0),  # projection padding
        in_chans: int = 3,  # number of input image channels
        embed_dim: int = 768,  # patch embedding dimension
    ) -> None:
        """
        Initialize PatchEmbed module.

        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        # A single convolution projects the image into non-overlapping patch embeddings
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes patch embedding by applying convolution and transposing resulting tensor."""
        # Apply the projection and move the channel axis last
        return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C
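

A standalone sketch (not from the original file) of the patch embedding output layout:

import torch

# Hypothetical sketch: 16x16 patches of a 64x64 RGB image.
patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=96)
img = torch.randn(1, 3, 64, 64)
print(patch_embed(img).shape)  # torch.Size([1, 4, 4, 96]): B x H' x W' x C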