Yolov8-源码解析-三十四-

Yolov8 源码解析(三十四)

.\yolov8\ultralytics\models\sam\modules\sam.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# 导入需要的模块和类
from typing import List

import torch
from torch import nn

# 导入本模块中的子模块
from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder

# 定义 Sam 类,用于对象分割任务
class Sam(nn.Module):
    """
    Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
    embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
    decoder to predict object masks.

    Attributes:
        mask_threshold (float): Threshold value for mask prediction.
        image_format (str): Format of the input image, default is 'RGB'.
        image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
        prompt_encoder (PromptEncoder): Encodes various types of input prompts.
        mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
        pixel_mean (List[float]): Mean pixel values for image normalization.
        pixel_std (List[float]): Standard deviation values for image normalization.
    """

    # 默认的阈值用于掩模预测
    mask_threshold: float = 0.0
    # 默认的输入图像格式为 RGB
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = (123.675, 116.28, 103.53),
        pixel_std: List[float] = (58.395, 57.12, 57.375),
    ) -> None:
        """
        Initialize the Sam class to predict object masks from an image and input prompts.

        Note:
            All forward() operations moved to SAMPredictor.

        Args:
            image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
            mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
            pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
                (123.675, 116.28, 103.53).
            pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
                (58.395, 57.12, 57.375).
        """
        super().__init__()
        # 设置图像编码器、提示编码器和掩模解码器
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # 注册用于图像归一化的均值和标准差作为缓冲区
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

.\yolov8\ultralytics\models\sam\modules\tiny_encoder.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
#   LeViT: (https://github.com/facebookresearch/levit)
#   Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------

import itertools  # 导入 itertools 库,用于迭代操作
from typing import Tuple  # 导入 Tuple 类型提示,用于指定元组类型

import torch  # 导入 PyTorch 深度学习库
import torch.nn as nn  # 导入 PyTorch 神经网络模块
import torch.nn.functional as F  # 导入 PyTorch 神经网络函数模块
import torch.utils.checkpoint as checkpoint  # 导入 PyTorch 检查点模块,用于内存优化

from ultralytics.utils.instance import to_2tuple  # 从 ultralytics.utils.instance 模块中导入 to_2tuple 函数


class Conv2d_BN(torch.nn.Sequential):
    """A sequential container that performs 2D convolution followed by batch normalization."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
        drop path.
        """
        super().__init__()
        # 添加 2D 卷积层,不使用偏置参数
        self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
        # 添加批归一化层,并初始化权重为 bn_weight_init,偏置为 0
        bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module("bn", bn)


class PatchEmbed(nn.Module):
    """Embeds images into patches and projects them into a specified embedding dimension."""

    def __init__(self, in_chans, embed_dim, resolution, activation):
        """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
        function.
        """
        super().__init__()
        img_size: Tuple[int, int] = to_2tuple(resolution)
        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        n = embed_dim
        # 构建序列模型,包含两个 Conv2d_BN 层和激活函数
        self.seq = nn.Sequential(
            Conv2d_BN(in_chans, n // 2, 3, 2, 1),  # 第一个卷积 + 批归一化层
            activation(),  # 激活函数
            Conv2d_BN(n // 2, n, 3, 2, 1),  # 第二个卷积 + 批归一化层
        )

    def forward(self, x):
        """Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
        return self.seq(x)


class MBConv(nn.Module):
    """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
        """
        Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
        function.
        """
        super().__init__()
        
        # 设置输入通道数
        self.in_chans = in_chans
        # 计算隐藏层通道数,根据扩展比例
        self.hidden_chans = int(in_chans * expand_ratio)
        # 设置输出通道数
        self.out_chans = out_chans

        # 第一个卷积层,包括卷积和批归一化
        self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
        # 第一个激活函数,根据给定的激活函数类实例化
        self.act1 = activation()

        # 第二个卷积层,包括卷积、批归一化和分组卷积(根据隐藏通道数)
        self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
        # 第二个激活函数,同样根据给定的激活函数类实例化
        self.act2 = activation()

        # 第三个卷积层,包括卷积、批归一化
        self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
        # 第三个激活函数,使用给定的激活函数类实例化
        self.act3 = activation()

        # 在训练时,根据是否需要进行 DropPath 操作来决定是否使用 DropPath 层
        # NOTE: `DropPath` is needed only for training.
        self.drop_path = nn.Identity()  # 如果 drop_path <= 0,使用恒等映射作为 drop_path
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """
        Implements the forward pass for the model architecture.
        """
        # 将输入作为快捷连接(shortcut)
        shortcut = x
        # 第一层卷积操作
        x = self.conv1(x)
        # 第一层激活函数
        x = self.act1(x)
        # 第二层卷积操作
        x = self.conv2(x)
        # 第二层激活函数
        x = self.act2(x)
        # 第三层卷积操作
        x = self.conv3(x)
        # DropPath 操作(在训练时可能会对 x 进行操作)
        x = self.drop_path(x)
        # 加上快捷连接
        x += shortcut
        # 最后一层激活函数
        return self.act3(x)
class PatchMerging(nn.Module):
    """Merges neighboring patches in the feature map and projects to a new dimension."""

    def __init__(self, input_resolution, dim, out_dim, activation):
        """Initializes the PatchMerging module with specified parameters.

        Args:
            input_resolution (tuple): Resolution of the input feature map (H, W).
            dim (int): Input dimensionality of the feature map.
            out_dim (int): Output dimensionality after merging and projection.
            activation (torch.nn.Module): Activation function instance.
        """
        super().__init__()

        self.input_resolution = input_resolution  # Store input resolution (H, W)
        self.dim = dim  # Store input dimensionality
        self.out_dim = out_dim  # Store output dimensionality
        self.act = activation()  # Initialize activation function instance
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)  # 1x1 convolution layer
        stride_c = 1 if out_dim in {320, 448, 576} else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)  # Depthwise separable convolution
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)  # 1x1 convolution layer

    def forward(self, x):
        """Performs forward pass through the PatchMerging module.

        Args:
            x (torch.Tensor): Input tensor, expected to have dimensions (B, C, H, W) or (B, H, W, C).

        Returns:
            torch.Tensor: Flattened and transposed tensor after convolution operations.
        """
        if x.ndim == 3:
            H, W = self.input_resolution
            B = len(x)
            # Reshape input tensor to (B, C, H, W) format if initially in (B, H, W, C)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.conv1(x)  # Apply first convolution layer
        x = self.act(x)  # Apply activation function

        x = self.conv2(x)  # Apply second convolution layer
        x = self.act(x)  # Apply activation function
        x = self.conv3(x)  # Apply third convolution layer

        return x.flatten(2).transpose(1, 2)  # Flatten and transpose output tensor


class ConvLayer(nn.Module):
    """
    Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).

    Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        activation,
        drop_path=0.0,
        downsample=None,
        use_checkpoint=False,
        out_dim=None,
        conv_expand_ratio=4.0,
    ):
        """Initializes the ConvLayer module with specified parameters.

        Args:
            dim (int): Input dimensionality for the convolutional layer.
            input_resolution (tuple): Resolution of the input feature map (H, W).
            depth (int): Depth of the convolutional layer.
            activation (torch.nn.Module): Activation function instance.
            drop_path (float, optional): Dropout probability. Defaults to 0.0.
            downsample (str or None, optional): Downsample operation type. Defaults to None.
            use_checkpoint (bool, optional): Flag to use gradient checkpointing. Defaults to False.
            out_dim (int or None, optional): Output dimensionality. Defaults to None.
            conv_expand_ratio (float, optional): Expansion ratio for convolution layers. Defaults to 4.0.
        """
        super().__init__()

        # Initialize module attributes
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.activation = activation
        self.drop_path = drop_path
        self.downsample = downsample
        self.use_checkpoint = use_checkpoint
        self.out_dim = out_dim
        self.conv_expand_ratio = conv_expand_ratio
    ):
        """
        Initializes the ConvLayer with the given dimensions and settings.

        Args:
            dim (int): The dimensionality of the input and output.
            input_resolution (Tuple[int, int]): The resolution of the input image.
            depth (int): The number of MBConv layers in the block.
            activation (Callable): Activation function applied after each convolution.
            drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
            downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
            out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
            conv_expand_ratio (float): Expansion ratio for the MBConv layers.
        """
        super().__init__()  # 调用父类的初始化方法

        self.dim = dim  # 设置输入和输出的维度
        self.input_resolution = input_resolution  # 设置输入图像的分辨率
        self.depth = depth  # 设置 MBConv 块中的层数
        self.use_checkpoint = use_checkpoint  # 设置是否使用梯度检查点来节省内存

        # 构建块
        self.blocks = nn.ModuleList(
            [
                MBConv(
                    dim,
                    dim,
                    conv_expand_ratio,
                    activation,
                    drop_path[i] if isinstance(drop_path, list) else drop_path,
                )
                for i in range(depth)
            ]
        )

        # Patch merging layer
        self.downsample = (
            None  # 如果没有指定 downsample 函数,则设置为 None
            if downsample is None
            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)  # 否则调用 downsample 函数进行下采样
        )

    def forward(self, x):
        """Processes the input through a series of convolutional layers and returns the activated output."""
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # 依次对输入 x 应用每个 MBConv 块
        return x if self.downsample is None else self.downsample(x)  # 如果定义了 downsample 函数,则对最终输出 x 进行下采样处理
        """
        Initializes the Multi-head Attention module with given parameters.

        Args:
            dim (int): Dimensionality of input embeddings.
            key_dim (int): Dimensionality of key and query vectors.
            num_heads (int, optional): Number of attention heads. Default is 8.
            attn_ratio (int, optional): Ratio of total spatial positions to use as attention biases. Default is 4.
            resolution (tuple, optional): Spatial resolution of the input. Default is (14, 14).
        """
        super().__init__()
        # Calculate the size of each head in the attention mechanism
        head_dim = key_dim // num_heads
        # Initialize the linear transformation of input into query, key, and value
        self.to_qkv = nn.Linear(dim, 3 * key_dim)
        # Cache the number of attention heads
        self.num_heads = num_heads
        # Set the spatial bias ratio for attention mechanism
        self.attn_ratio = attn_ratio
        # Determine the size of spatial grid for the attention biases
        self.resolution = resolution
        # Initialize cached attention biases for inference, to be deleted during training
        self.ab = None
    ):
        """
        Initializes the Attention module.

        Args:
            dim (int): The dimensionality of the input and output.
            key_dim (int): The dimensionality of the keys and queries.
            num_heads (int, optional): Number of attention heads. Default is 8.
            attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
            resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).

        Raises:
            AssertionError: If `resolution` is not a tuple of length 2.
        """
        super().__init__()

        # 检查并确保 `resolution` 是长度为 2 的元组
        assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
        
        # 设置模块的属性
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        
        # 计算 `h`,作为后续线性层的输入维度
        h = self.dh + nh_kd * 2

        # Layer normalization 层
        self.norm = nn.LayerNorm(dim)
        
        # 线性变换层,将输入转换为 `h` 维度
        self.qkv = nn.Linear(dim, h)
        
        # 输出投影层,将注意力头的结果投影回 `dim` 维度
        self.proj = nn.Linear(self.dh, dim)

        # 生成所有空间位置的偏移量对应的索引
        points = list(itertools.product(range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        
        # 初始化注意力偏置参数,并注册为模型的可学习参数
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Sets the module in training mode and handles attribute 'ab' based on the mode."""
        # 调用父类的 `train` 方法,设置模块的训练模式
        super().train(mode)
        
        # 根据训练模式处理 `ab` 属性
        if mode and hasattr(self, "ab"):
            del self.ab  # 如果是训练模式且存在 `ab` 属性,则删除它
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
            # 如果是测试模式或 `ab` 属性不存在,则将 `attention_biases` 与 `attention_bias_idxs` 结合起来存储在 `ab` 中
    def forward(self, x):  # x
        """对输入张量 'x' 执行前向传播,包括归一化和查询键/值操作。"""
        B, N, _ = x.shape  # B, N, C

        # 归一化处理
        x = self.norm(x)

        # 查询键值对
        qkv = self.qkv(x)
        # 将结果重塑为 (B, N, num_heads, d),并分割为 q, k, v
        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
        # 将维度重新排列为 (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # 将 attention_biases 转移到合适的设备上
        self.ab = self.ab.to(self.attention_biases.device)

        # 计算注意力权重,包括缩放和偏置项
        attn = (q @ k.transpose(-2, -1)) * self.scale + (
            self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
        )
        attn = attn.softmax(dim=-1)
        
        # 计算加权后的值
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        
        # 应用投影层并返回结果
        return self.proj(x)
class TinyViTBlock(nn.Module):
    """TinyViT Block that applies self-attention and a local convolution to the input."""

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        local_conv_size=3,
        activation=nn.GELU,
    ):
        """
        Initializes the TinyViTBlock.

        Args:
            dim (int): The dimensionality of the input and output.
            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
            num_heads (int): Number of attention heads.
            window_size (int, optional): Window size for attention. Default is 7.
            mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
            drop (float, optional): Dropout rate. Default is 0.
            drop_path (float, optional): Stochastic depth rate. Default is 0.
            local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
            activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.

        Raises:
            AssertionError: If `window_size` is not greater than 0.
            AssertionError: If `dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.dim = dim  # 设置输入输出的维度
        self.input_resolution = input_resolution  # 设置输入特征图的空间分辨率
        self.num_heads = num_heads  # 设置注意力头的数量
        assert window_size > 0, "window_size must be greater than 0"  # 断言窗口大小必须大于0
        self.window_size = window_size  # 设置注意力机制的窗口大小
        self.mlp_ratio = mlp_ratio  # 设置MLP隐藏层维度与嵌入维度的比例

        # NOTE: `DropPath` is needed only for training.
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()  # 设置DropPath层,用于训练时的随机深度(如果drop_path大于0)

        assert dim % num_heads == 0, "dim must be divisible by num_heads"  # 断言维度必须能够被注意力头数整除
        head_dim = dim // num_heads  # 计算每个注意力头的维度

        window_resolution = (window_size, window_size)
        self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
        # 初始化注意力层,传入维度、头部维度、头部数量、注意力比例和窗口分辨率

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_activation = activation
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
        # 初始化MLP层,传入输入特征维度、隐藏层特征维度、激活函数和dropout率

        pad = local_conv_size // 2
        self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
        # 初始化本地卷积层,传入输入和输出特征维度、卷积核大小、步长、填充、分组数
    def forward(self, x):
        """对输入的 'x' 进行基于注意力的转换或填充,然后通过本地卷积传递。

        Args:
            x (tensor): 输入张量,形状为 [batch, height*width, channels]。

        Returns:
            tensor: 经过处理后的张量,形状为 [batch, height*width, channels]。
        """
        h, w = self.input_resolution
        b, hw, c = x.shape  # batch, height*width, channels
        assert hw == h * w, "input feature has wrong size"  # 断言输入特征的尺寸是否正确
        res_x = x  # 保留原始输入张量

        # 如果输入分辨率等于窗口尺寸,则直接应用注意力模块
        if h == self.window_size and w == self.window_size:
            x = self.attn(x)
        else:
            # 否则,对输入进行重塑以便进行填充
            x = x.view(b, h, w, c)
            pad_b = (self.window_size - h % self.window_size) % self.window_size
            pad_r = (self.window_size - w % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0  # 检查是否需要填充

            # 如果需要填充,则对输入进行填充操作
            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = h + pad_b, w + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size

            # 窗口分割
            x = (
                x.view(b, nH, self.window_size, nW, self.window_size, c)
                .transpose(2, 3)
                .reshape(b * nH * nW, self.window_size * self.window_size, c)
            )
            x = self.attn(x)  # 应用注意力模块

            # 窗口重组
            x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
            if padding:
                x = x[:, :h, :w].contiguous()  # 移除填充部分

            x = x.view(b, hw, c)  # 恢复原始形状

        x = res_x + self.drop_path(x)  # 加入残差连接和DropPath操作
        x = x.transpose(1, 2).reshape(b, c, h, w)  # 转置和重塑张量形状
        x = self.local_conv(x)  # 应用本地卷积
        x = x.view(b, c, hw).transpose(1, 2)  # 重塑张量形状

        return x + self.drop_path(self.mlp(x))  # 加入残差连接和MLP操作

    def extra_repr(self) -> str:
        """返回一个格式化的字符串,表示TinyViTBlock的参数:维度、输入分辨率、注意力头数、窗口尺寸和MLP比例。

        Returns:
            str: 格式化后的参数信息字符串。
        """
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
        )
# 定义一个名为 BasicLayer 的类,用于 TinyViT 架构中的一个阶段的基本层次
class BasicLayer(nn.Module):
    """A basic TinyViT layer for one stage in a TinyViT architecture."""

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        drop=0.0,
        drop_path=0.0,
        downsample=None,
        use_checkpoint=False,
        local_conv_size=3,
        activation=nn.GELU,
        out_dim=None,
    ):
        """
        Initializes the BasicLayer.

        Args:
            dim (int): The dimensionality of the input and output.
            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
            depth (int): Number of TinyViT blocks.
            num_heads (int): Number of attention heads.
            window_size (int): Local window size.
            mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
            drop (float, optional): Dropout rate. Default is 0.
            drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
            downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
            use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
            local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
            activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
            out_dim (int | None, optional): The output dimension of the layer. Default is None.

        Raises:
            ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
        """
        # 调用父类的初始化方法
        super().__init__()
        # 设置类的属性
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # 构建 TinyViTBlock 组成的模块列表
        self.blocks = nn.ModuleList(
            [
                TinyViTBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    local_conv_size=local_conv_size,
                    activation=activation,
                )
                for i in range(depth)  # 根据 depth 参数循环创建 TinyViTBlock
            ]
        )

        # 如果指定了 downsample 参数,则创建对应的下采样层
        self.downsample = (
            None
            if downsample is None
            else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        )
    # 执行输入张量的前向传播,并返回一个规范化的张量
    def forward(self, x):
        # 遍历网络中的每个块进行前向传播
        for blk in self.blocks:
            # 如果使用了检查点技术,则通过检查点执行块的前向传播,否则直接调用块的前向传播
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        # 如果存在下采样函数,则对输出张量进行下采样操作
        return x if self.downsample is None else self.downsample(x)

    # 返回一个描述层参数的字符串表示形式
    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
    class LayerNorm2d(nn.Module):
        """A PyTorch implementation of Layer Normalization in 2D."""

        def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
            """Initialize LayerNorm2d with the number of channels and an optional epsilon."""
            super().__init__()
            # Define learnable parameters for scaling and shifting
            self.weight = nn.Parameter(torch.ones(num_channels))
            self.bias = nn.Parameter(torch.zeros(num_channels))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Perform a forward pass, normalizing the input tensor."""
            # Compute mean and standard deviation across channels
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            # Normalize the input tensor
            x = (x - u) / torch.sqrt(s + self.eps)
            # Scale and shift the normalized tensor
            return self.weight[:, None, None] * x + self.bias[:, None, None]


    class TinyViT(nn.Module):
        """
        The TinyViT architecture for vision tasks.

        Attributes:
            img_size (int): Input image size.
            in_chans (int): Number of input channels.
            num_classes (int): Number of classification classes.
            embed_dims (List[int]): List of embedding dimensions for each layer.
            depths (List[int]): List of depths for each layer.
            num_heads (List[int]): List of number of attention heads for each layer.
            window_sizes (List[int]): List of window sizes for each layer.
            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
            drop_rate (float): Dropout rate for drop layers.
            drop_path_rate (float): Drop path rate for stochastic depth.
            use_checkpoint (bool): Use checkpointing for efficient memory usage.
            mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
            local_conv_size (int): Local convolution kernel size.
            layer_lr_decay (float): Layer-wise learning rate decay.

        Note:
            This implementation is generalized to accept a list of depths, attention heads,
            embedding dimensions and window sizes, which allows you to create a
            "stack" of TinyViT models of varying configurations.
        """

        def __init__(
            self,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            embed_dims=(96, 192, 384, 768),
            depths=(2, 2, 6, 2),
            num_heads=(3, 6, 12, 24),
            window_sizes=(7, 7, 14, 7),
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.1,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=1.0,
    def set_layer_lr_decay(self, layer_lr_decay):
        """Sets the learning rate decay for each layer in the TinyViT model."""
        decay_rate = layer_lr_decay  # 设置每个层的学习率衰减率

        # Layers -> blocks (depth)
        depth = sum(self.depths)  # 计算总的层深度
        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]  # 计算每个层的学习率缩放比例

        def _set_lr_scale(m, scale):
            """Sets the learning rate scale for each layer in the model based on the layer's depth."""
            for p in m.parameters():
                p.lr_scale = scale  # 设置每个模型层的学习率缩放比例

        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))  # 对输入嵌入层设置学习率缩放比例
        i = 0
        for layer in self.layers:
            for block in layer.blocks:
                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))  # 对每个层块设置学习率缩放比例
                i += 1
            if layer.downsample is not None:
                layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))  # 对下采样层设置学习率缩放比例
        assert i == depth  # 确保设置了所有层的学习率缩放比例
        for m in [self.norm_head, self.head]:
            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))  # 对归一化头部和头部设置学习率缩放比例

        for k, p in self.named_parameters():
            p.param_name = k  # 为每个参数设置参数名属性

        def _check_lr_scale(m):
            """Checks if the learning rate scale attribute is present in module's parameters."""
            for p in m.parameters():
                assert hasattr(p, "lr_scale"), p.param_name  # 检查模块参数中是否存在学习率缩放属性

        self.apply(_check_lr_scale)  # 应用检查学习率缩放属性的函数到模型中的所有模块

    def _init_weights(self, m):
        """Initializes weights for linear layers and layer normalization in the given module."""
        if isinstance(m, nn.Linear):
            # NOTE: This initialization is needed only for training.
            # trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)  # 初始化线性层的偏置为常数0
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)  # 初始化层归一化的偏置为常数0
            nn.init.constant_(m.weight, 1.0)  # 初始化层归一化的权重为常数1.0

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Returns a dictionary of parameter names where weight decay should not be applied."""
        return {"attention_biases"}  # 返回不应用权重衰减的参数名称字典

    def forward_features(self, x):
        """Runs the input through the model layers and returns the transformed output."""
        x = self.patch_embed(x)  # x input is (N, C, H, W)

        x = self.layers[0](x)  # 对输入应用第一个层
        start_i = 1

        for i in range(start_i, len(self.layers)):
            layer = self.layers[i]
            x = layer(x)  # 依次对每个层应用输入
        batch, _, channel = x.shape
        x = x.view(batch, 64, 64, channel)  # 调整输出的形状
        x = x.permute(0, 3, 1, 2)  # 调整输出的维度顺序
        return self.neck(x)  # 返回经过颈部处理后的输出

    def forward(self, x):
        """Executes a forward pass on the input tensor through the constructed model layers."""
        return self.forward_features(x)  # 执行输入张量通过构建模型层的前向传播

.\yolov8\ultralytics\models\sam\modules\transformer.py

# 导入所需的库
import math
from typing import Tuple, Type

import torch
from torch import Tensor, nn

# 导入自定义模块MLPBlock
from ultralytics.nn.modules import MLPBlock

# 定义一个名为TwoWayTransformer的神经网络模块,继承自nn.Module类
class TwoWayTransformer(nn.Module):
    """
    A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
    serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
    is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
    processing.

    Attributes:
        depth (int): The number of layers in the transformer.
        embedding_dim (int): The channel dimension for the input embeddings.
        num_heads (int): The number of heads for multihead attention.
        mlp_dim (int): The internal channel dimension for the MLP block.
        layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
        final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
        norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
          attention_downsample_rate (int): downsample rate for attention
        """
        super().__init__()
        # 初始化模块参数
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        # 逐层添加TwoWayAttentionBlock,构建多层Transformer
        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        # 定义最终的注意力层,从查询到图像的注意力
        self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        # 最终注意力层的层归一化
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
        point_pe: Tensor,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        """
        Defines the forward pass for the TwoWayTransformer module.

        Args:
            image_embedding (Tensor): The input image embedding.
            image_pe (Tensor): Positional encoding for the image.
            point_embedding (Tensor): The input point/query embedding.
            point_pe (Tensor): Positional encoding for the point/query.
            **kwargs: Additional arguments.

        Returns:
            Tuple[Tensor, Tensor]: Output of the transformer module.
        """
        # 略去了具体的前向传播过程,因为还未完整给出
        pass
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          (torch.Tensor): the processed point_embedding
          (torch.Tensor): the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        # 获取图像嵌入的维度信息
        bs, c, h, w = image_embedding.shape
        # 将图像嵌入展平成 B x (H * W) x C,并置换维度顺序为 B x (H * W) x C
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        # 将图像位置编码展平成 B x (H * W) x C,并置换维度顺序为 B x (H * W) x C
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # 准备查询点
        queries = point_embedding
        keys = image_embedding

        # 应用 Transformer 块和最终的 LayerNorm
        for layer in self.layers:
            # 通过每个层处理查询和键
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # 应用从点到图像的最终注意力层
        q = queries + point_embedding
        k = keys + image_pe
        # 对点到图像的注意力计算输出
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        # 将查询更新为原查询加上注意力输出
        queries = queries + attn_out
        # 对最终的注意力查询应用 LayerNorm
        queries = self.norm_final_attn(queries)

        return queries, keys
# 定义一个名为TwoWayAttentionBlock的新的神经网络模块,用于实现自注意力和交叉注意力的操作,
# 其中包括从查询到键和从键到查询的两个方向。该模块由四个主要层组成:
# (1) 自注意力层在稀疏输入上操作,
# (2) 将稀疏输入的交叉注意力应用于密集输入,
# (3) 对稀疏输入执行MLP块,
# (4) 将密集输入的交叉注意力应用于稀疏输入。

class TwoWayAttentionBlock(nn.Module):
    """
    An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
    keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
    of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
    sparse inputs.

    Attributes:
        self_attn (Attention): The self-attention layer for the queries.
        norm1 (nn.LayerNorm): Layer normalization following the first attention block.
        cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
        norm2 (nn.LayerNorm): Layer normalization following the second attention block.
        mlp (MLPBlock): MLP block that transforms the query embeddings.
        norm3 (nn.LayerNorm): Layer normalization following the MLP block.
        norm4 (nn.LayerNorm): Layer normalization following the third attention block.
        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
        skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        Initialize the transformer block with four layers: 
        (1) self-attention on sparse inputs,
        (2) cross attention of sparse inputs to dense inputs,
        (3) mlp block on sparse inputs,
        (4) cross attention of dense inputs to sparse inputs.

        Args:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        # 第一层:自注意力层,用于处理查询
        self.self_attn = Attention(embedding_dim, num_heads)
        # 第一层后的层归一化
        self.norm1 = nn.LayerNorm(embedding_dim)

        # 第二层:从查询到键的交叉注意力层
        self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
        # 第二层后的层归一化
        self.norm2 = nn.LayerNorm(embedding_dim)

        # 第三层:MLP块,用于转换查询的嵌入
        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        # 第三层后的层归一化
        self.norm3 = nn.LayerNorm(embedding_dim)

        # 第四层:从密集输入到稀疏输入的交叉注意力层
        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)

        # 是否跳过第一层的位置编码
        self.skip_first_layer_pe = skip_first_layer_pe
    # 定义一个方法 `forward`,接收四个张量参数 `queries`、`keys`、`query_pe`、`key_pe`,返回两个张量。
    def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
        """Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
        
        # Self attention block
        # 如果设置了跳过第一层的位置编码,直接对 queries 执行自注意力机制
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            # 否则,将位置编码 query_pe 加到 queries 上,然后执行自注意力机制
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        # 对经过注意力机制后的 queries 执行 Layer Normalization
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        # 将位置编码 query_pe 加到 queries 上,同时将位置编码 key_pe 加到 keys 上,然后执行跨注意力机制
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        # 对经过注意力机制后的 queries 执行 Layer Normalization
        queries = self.norm2(queries)

        # MLP block
        # 将 queries 输入到 MLP 网络中进行处理
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        # 对 MLP 输出后的 queries 执行 Layer Normalization
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        # 将位置编码 query_pe 加到 queries 上,同时将位置编码 key_pe 加到 keys 上,然后执行反向的跨注意力机制
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        # 对经过反向注意力机制后的 keys 执行 Layer Normalization
        keys = self.norm4(keys)

        # 返回处理后的 queries 和 keys
        return queries, keys
class Attention(nn.Module):
    """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        """
        Initializes the Attention model with the given dimensions and settings.

        Args:
            embedding_dim (int): The dimensionality of the input embeddings.
            num_heads (int): The number of attention heads.
            downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.

        Raises:
            AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
        """
        super().__init__()
        # 设置嵌入维度和内部维度
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        # 检查 num_heads 是否能够整除内部维度
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        # 初始化线性投影层
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    @staticmethod
    def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
        """Separate the input tensor into the specified number of attention heads."""
        # 获取输入张量的形状信息
        b, n, c = x.shape
        # 重塑张量,将通道维度分为 num_heads 份
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    @staticmethod
    def _recombine_heads(x: Tensor) -> Tensor:
        """Recombine the separated attention heads into a single tensor."""
        # 获取张量的形状信息
        b, n_heads, n_tokens, c_per_head = x.shape
        # 转置张量,重新组合注意力头部
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Compute the attention output given the input query, key, and value tensors."""
        
        # 输入的投影
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # 分割为注意力头部
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # 注意力计算
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # 获取输出
        out = attn @ v
        out = self._recombine_heads(out)
        return self.out_proj(out)

.\yolov8\ultralytics\models\sam\modules\__init__.py

# 项目标题和许可证声明,表明这段代码实现了Ultralytics YOLO算法,并遵循AGPL-3.0许可证
# 这里的注释是项目的头部注释,通常用于声明项目名称、作者、许可证等重要信息
# 在代码中通常用来作为版权声明和许可证声明,提供给使用者和阅读者必要的信息
# 在此处可能包含有关项目功能、作者信息、项目版本、许可证类型等内容
# 头部注释通常是开源项目的一部分,用于遵循开源许可证的要求,并为项目提供法律保护和使用说明

.\yolov8\ultralytics\models\sam\predict.py

# 导入必要的库和模块
import numpy as np
import torch
import torch.nn.functional as F

# 导入 Ultralytics 自定义的数据增强模块 LetterBox
from ultralytics.data.augment import LetterBox
# 导入 Ultralytics 自定义的预测器基类 BasePredictor
from ultralytics.engine.predictor import BasePredictor
# 导入 Ultralytics 自定义的结果处理模块 Results
from ultralytics.engine.results import Results
# 导入 Ultralytics 的一些实用函数和操作
from ultralytics.utils import DEFAULT_CFG, ops
# 导入选择设备的函数
from ultralytics.utils.torch_utils import select_device

# 导入局部模块中的函数和类
from .amg import (
    batch_iterator,
    batched_mask_to_box,
    build_all_layer_point_grids,
    calculate_stability_score,
    generate_crop_boxes,
    is_box_near_crop_edge,
    remove_small_regions,
    uncrop_boxes_xyxy,
    uncrop_masks,
)
# 导入局部模块中的构建 SAM 模型的函数
from .build import build_sam

# 定义 SAM 模型的预测器类,继承自 BasePredictor 类
class Predictor(BasePredictor):
    """
    SAM 模型的预测器类,继承自 BasePredictor 类。

    该类提供了用于图像分割任务的模型推断接口。
    具有高级架构和可提示分割功能,支持灵活和实时的掩模生成。
    该类能够处理多种类型的提示,如边界框、点和低分辨率掩模。

    Attributes:
        cfg (dict): 模型和任务相关参数的配置字典。
        overrides (dict): 包含覆盖默认配置的值的字典。
        _callbacks (dict): 用户定义的回调函数字典,用于增强行为。
        args (namespace): 保存命令行参数或其他操作变量的命名空间。
        im (torch.Tensor): 预处理后的输入图像张量。
        features (torch.Tensor): 用于推断的提取图像特征。
        prompts (dict): 包含各种提示类型的集合,如边界框和点。
        segment_all (bool): 控制是否对图像中的所有对象进行分割或仅对指定对象进行分割的标志。
    """
    # 初始化 Predictor 对象,使用默认配置 cfg,如果提供了 overrides,则将其合并到配置中
    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the Predictor with configuration, overrides, and callbacks.

        The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
        initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.

        Args:
            cfg (dict): Configuration dictionary.
            overrides (dict, optional): Dictionary of values to override default configuration.
            _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
        """
        # 如果 overrides 为 None,则初始化为空字典
        if overrides is None:
            overrides = {}
        # 更新 overrides 字典,设置任务为 "segment",模式为 "predict",图像大小为 1024
        overrides.update(dict(task="segment", mode="predict", imgsz=1024))
        # 调用父类的初始化方法,传入 cfg、overrides 和 _callbacks
        super().__init__(cfg, overrides, _callbacks)
        # 设置 self.args.retina_masks 为 True,针对 SAM 模型的特定设置
        self.args.retina_masks = True
        # 初始化 self.im 为 None,用于存储输入图像
        self.im = None
        # 初始化 self.features 为 None,用于存储特征
        self.features = None
        # 初始化 self.prompts 为空字典,用于存储提示信息
        self.prompts = {}
        # 初始化 self.segment_all 为 False,用于控制是否对所有数据进行分割
        self.segment_all = False

    # 对输入图像进行预处理,以供模型推断使用
    def preprocess(self, im):
        """
        Preprocess the input image for model inference.

        The method prepares the input image by applying transformations and normalization.
        It supports both torch.Tensor and list of np.ndarray as input formats.

        Args:
            im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.

        Returns:
            (torch.Tensor): The preprocessed image tensor.
        """
        # 如果 self.im 不为 None,直接返回已经处理好的图像
        if self.im is not None:
            return self.im
        # 判断 im 是否为 torch.Tensor 类型之外的类型
        not_tensor = not isinstance(im, torch.Tensor)
        # 如果 im 不是 torch.Tensor 类型,则进行以下转换和预处理步骤
        if not_tensor:
            # 将输入的 HWC 格式的 numpy 数组堆叠成 BCHW 格式的 numpy 数组
            im = np.stack(self.pre_transform(im))
            # 对图像进行颜色通道反转(RGB to BGR)和维度转置,以符合模型输入要求
            im = im[..., ::-1].transpose((0, 3, 1, 2))
            # 转换为连续存储的数组
            im = np.ascontiguousarray(im)
            # 将 numpy 数组转换为 torch.Tensor
            im = torch.from_numpy(im)

        # 将处理好的图像数据移到设备(GPU 或 CPU)上
        im = im.to(self.device)
        # 如果模型使用 FP16 运算,则将图像数据类型转换为半精度(half),否则转换为单精度(float)
        im = im.half() if self.model.fp16 else im.float()
        # 如果输入图像不是 torch.Tensor 类型,则进行均值和标准差归一化处理
        if not_tensor:
            im = (im - self.mean) / self.std
        # 返回预处理后的图像 tensor
        return im

    # 对输入图像执行初始转换,以进行进一步的预处理
    def pre_transform(self, im):
        """
        Perform initial transformations on the input image for preprocessing.

        The method applies transformations such as resizing to prepare the image for further preprocessing.
        Currently, batched inference is not supported; hence the list length should be 1.

        Args:
            im (List[np.ndarray]): List containing images in HWC numpy array format.

        Returns:
            (List[np.ndarray]): List of transformed images.
        """
        # 断言输入的图像列表长度为 1,因为 SAM 模型不支持批量推断
        assert len(im) == 1, "SAM model does not currently support batched inference"
        # 创建 LetterBox 转换器,用于将输入图像调整为模型需要的大小
        letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
        # 对输入图像列表中的每张图像应用 LetterBox 转换
        return [letterbox(image=x) for x in im]
    def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
        """
        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
        mask decoder for real-time and promptable segmentation tasks.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.

        Returns:
            (tuple): Contains the following three elements.
                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
        """
        # Override prompts if any stored in self.prompts
        # 从 self.prompts 中取出存储的提示信息(如果有),覆盖函数参数中的对应项
        bboxes = self.prompts.pop("bboxes", bboxes)
        points = self.prompts.pop("points", points)
        masks = self.prompts.pop("masks", masks)

        # 如果所有的提示信息都是 None,则调用 generate 方法生成输出
        if all(i is None for i in [bboxes, points, masks]):
            return self.generate(im, *args, **kwargs)

        # 否则,调用 prompt_inference 方法进行推断
        return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)

    def generate(
        self,
        im,
        crop_n_layers=0,
        crop_overlap_ratio=512 / 1500,
        crop_downscale_factor=1,
        point_grids=None,
        points_stride=32,
        points_batch_size=64,
        conf_thres=0.88,
        stability_score_thresh=0.95,
        stability_score_offset=0.95,
        crop_nms_thresh=0.7,
    ):
        """
        Generate segmentation masks based on the input image and various parameters.

        Args:
            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
            crop_n_layers (int, optional): Number of layers to crop.
            crop_overlap_ratio (float, optional): Ratio of overlap in cropping.
            crop_downscale_factor (int, optional): Factor by which to downscale crops.
            point_grids (np.ndarray, optional): Grids of points for segmentation.
            points_stride (int, optional): Stride for points.
            points_batch_size (int, optional): Batch size for points processing.
            conf_thres (float, optional): Confidence threshold.
            stability_score_thresh (float, optional): Stability score threshold.
            stability_score_offset (float, optional): Stability score offset.
            crop_nms_thresh (float, optional): NMS threshold for cropping.

        Returns:
            None
        """
        # 实现生成分割 mask 的具体逻辑,根据参数设置生成相应的输出
        pass
    def setup_model(self, model, verbose=True):
        """
        Initializes the Segment Anything Model (SAM) for inference.

        This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary
        parameters for image normalization and other Ultralytics compatibility settings.

        Args:
            model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
            verbose (bool): If True, prints selected device information.

        Attributes:
            model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
            device (torch.device): The device to which the model and tensors are allocated.
            mean (torch.Tensor): The mean values for image normalization.
            std (torch.Tensor): The standard deviation values for image normalization.
        """
        # 选择设备并打印设备信息(如果 verbose=True)
        device = select_device(self.args.device, verbose=verbose)
        
        # 如果未提供预训练的模型,则根据配置构建 SAM 模型
        if model is None:
            model = build_sam(self.args.model)
        
        # 将模型设置为评估模式(不进行梯度更新)
        model.eval()
        
        # 将模型移动到指定的设备上
        self.model = model.to(device)
        self.device = device
        
        # 设置图像归一化所需的均值和标准差,并移动到指定的设备上
        self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
        self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)

        # 设置 Ultralytics 兼容性选项
        self.model.pt = False
        self.model.triton = False
        self.model.stride = 32
        self.model.fp16 = False
        
        # 标记初始化已完成
        self.done_warmup = True
    # 对 SAM 模型推断输出进行后处理,生成目标检测的掩码和边界框
    
    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
    
        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
        The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
    
        Args:
            preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
            img (torch.Tensor): The processed input image tensor.
            orig_imgs (list | torch.Tensor): The original, unprocessed images.
    
        Returns:
            (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
        """
    
        # 获取预测的掩码和得分
        pred_masks, pred_scores = preds[:2]
        # 如果需要分割所有类别,则获取预测的边界框
        pred_bboxes = preds[2] if self.segment_all else None
        # 生成掩码名称字典
        names = dict(enumerate(str(i) for i in range(len(pred_masks))))
    
        # 如果原始图像不是列表而是张量,则转换为 numpy 数组的批处理
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
    
        results = []
        # 遍历每个预测的掩码、原始图像和图像路径
        for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
            # 如果存在预测的边界框,则调整边界框大小
            if pred_bboxes is not None:
                pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
                cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
                pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
    
            # 调整预测的掩码大小
            masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
            # 应用掩码阈值,转换为布尔值
            masks = masks > self.model.mask_threshold
            # 添加处理后的结果到结果列表
            results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
    
        # 重置分割所有类别的模式标志
        self.segment_all = False
        # 返回结果列表
        return results
    
    
    def setup_source(self, source):
        """
        Sets up the data source for inference.
    
        This method configures the data source from which images will be fetched for inference. The source could be a
        directory, a video file, or other types of image data sources.
    
        Args:
            source (str | Path): The path to the image data source for inference.
        """
        
        # 如果源路径不为 None,则调用父类方法设置数据源
        if source is not None:
            super().setup_source(source)
    def set_image(self, image):
        """
        Preprocesses and sets a single image for inference.

        This function sets up the model if not already initialized, configures the data source to the specified image,
        and preprocesses the image for feature extraction. Only one image can be set at a time.

        Args:
            image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.

        Raises:
            AssertionError: If more than one image is set.
        """
        # 如果模型尚未初始化,则根据给定的模型参数构建 SAM 模型
        if self.model is None:
            model = build_sam(self.args.model)
            # 初始化模型
            self.setup_model(model)
        
        # 配置数据源为指定的图像
        self.setup_source(image)
        
        # 检查数据集中是否只有一个图像,否则引发断言错误
        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
        
        # 遍历数据集,预处理图像并提取特征,仅处理第一个 batch
        for batch in self.dataset:
            im = self.preprocess(batch[1])  # 对图像进行预处理
            self.features = self.model.image_encoder(im)  # 提取图像特征
            self.im = im  # 保存原始图像
            break

    def set_prompts(self, prompts):
        """Set prompts in advance."""
        # 设置预定义的提示语句
        self.prompts = prompts

    def reset_image(self):
        """Resets the image and its features to None."""
        # 重置图像和特征为 None
        self.im = None
        self.features = None

    @staticmethod
    def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
        """
        Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
        function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
        Suppression (NMS) to eliminate any newly created duplicate boxes.

        Args:
            masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
                                  the number of masks, H is height, and W is width.
            min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
            nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.

        Returns:
            (tuple([torch.Tensor, List[int]])):
                - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
                - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
        """
        import torchvision  # import statement needed for using torchvision

        if len(masks) == 0:
            return masks  # Return the input masks if empty

        # Filter small disconnected regions and holes
        new_masks = []  # Initialize an empty list for storing processed masks
        scores = []  # Initialize an empty list for storing scores of masks

        for mask in masks:
            mask = mask.cpu().numpy().astype(np.uint8)  # Convert mask tensor to numpy array of type uint8
            mask, changed = remove_small_regions(mask, min_area, mode="holes")  # Remove small holes
            unchanged = not changed  # Check if changes occurred in holes removal
            mask, changed = remove_small_regions(mask, min_area, mode="islands")  # Remove small islands
            unchanged = unchanged and not changed  # Check if changes occurred in islands removal

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))  # Convert processed mask back to tensor and append
            scores.append(float(unchanged))  # Append the score (0 or 1) indicating if mask was unchanged

        # Recalculate boxes and remove any new duplicates
        new_masks = torch.cat(new_masks, dim=0)  # Concatenate all masks into a single tensor
        boxes = batched_mask_to_box(new_masks)  # Convert masks to bounding boxes
        keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)  # Perform NMS using scores

        return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep  # Return filtered masks and indices

.\yolov8\ultralytics\models\sam\__init__.py

# 导入模块 `SAM` 和 `Predictor`,它们来自当前目录下的 `model.py` 和 `predict.py`
from .model import SAM
from .predict import Predictor

# 定义公开接口 `__all__`,包含字符串 "SAM" 和 "Predictor" 的元组或列表
__all__ = "SAM", "Predictor"  # tuple or list

.\yolov8\ultralytics\models\utils\loss.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.loss import FocalLoss, VarifocalLoss  # 导入 FocalLoss 和 VarifocalLoss 损失函数
from ultralytics.utils.metrics import bbox_iou  # 导入 bbox_iou 函数

from .ops import HungarianMatcher  # 导入匈牙利匹配器

class DETRLoss(nn.Module):
    """
    DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
    DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
    losses.

    Attributes:
        nc (int): The number of classes.
        loss_gain (dict): Coefficients for different loss components.
        aux_loss (bool): Whether to compute auxiliary losses.
        use_fl (bool): Use FocalLoss or not.
        use_vfl (bool): Use VarifocalLoss or not.
        use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
        uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
        matcher (HungarianMatcher): Object to compute matching cost and indices.
        fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
        vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
        device (torch.device): Device on which tensors are stored.
    """

    def __init__(
        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
    ):
        """
        DETR loss function.

        Args:
            nc (int): The number of classes.
            loss_gain (dict): The coefficient of loss.
            aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
            use_vfl (bool): Use VarifocalLoss or not.
            use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
            uni_match_ind (int): The fixed indices of a layer.
        """
        super().__init__()

        # 如果 loss_gain 为 None,则使用默认的损失系数
        if loss_gain is None:
            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
        
        # 设置类别数目 nc,初始化匈牙利匹配器 matcher
        self.nc = nc
        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
        
        # 设置损失系数 loss_gain,是否计算辅助损失 aux_loss,以及是否使用 FocalLoss 和 VarifocalLoss
        self.loss_gain = loss_gain
        self.aux_loss = aux_loss
        self.fl = FocalLoss() if use_fl else None
        self.vfl = VarifocalLoss() if use_vfl else None

        # 是否使用固定层来分配辅助分支的标签,以及固定层的索引
        self.use_uni_match = use_uni_match
        self.uni_match_ind = uni_match_ind
        self.device = None
    # 计算分类损失,基于预测值、目标值和实际得分
    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
        """Computes the classification loss based on predictions, target values, and ground truth scores."""
        # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
        name_class = f"loss_class{postfix}"
        bs, nq = pred_scores.shape[:2]
        
        # 创建一个全零的张量,形状为(bs, nq, self.nc + 1),类型为int64,存储在与targets相同的设备上
        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
        # 使用scatter_函数将one_hot张量中的指定位置设为1,形成one-hot编码,排除最后一个类别
        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
        one_hot = one_hot[..., :-1]
        # 将每个类别的得分乘以对应的one-hot编码,得到每个样本每个查询点的分类得分
        gt_scores = gt_scores.view(bs, nq, 1) * one_hot

        # 计算分类损失
        if self.fl:
            if num_gts and self.vfl:
                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
            else:
                loss_cls = self.fl(pred_scores, one_hot.float())
            loss_cls /= max(num_gts, 1) / nq
        else:
            # 使用二分类交叉熵损失函数计算损失,mean(1)对每个查询点求均值,sum()对所有查询点求和
            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss

        # 返回损失值,乘以分类损失增益系数
        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

    # 计算边界框损失,包括预测边界框和实际边界框的L1损失和GIoU损失
    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
        """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
        boxes.
        """
        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = f"loss_bbox{postfix}"
        name_giou = f"loss_giou{postfix}"

        loss = {}
        if len(gt_bboxes) == 0:
            # 如果没有实际边界框,损失为0
            loss[name_bbox] = torch.tensor(0.0, device=self.device)
            loss[name_giou] = torch.tensor(0.0, device=self.device)
            return loss

        # 计算边界框损失
        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
        # 计算GIoU损失
        loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
        loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
        loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]

        # 返回损失字典,将值展平为标量
        return {k: v.squeeze() for k, v in loss.items()}
    #     loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
    #                                                                     torch.tensor([num_gts], dtype=torch.float32))
    #     loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
    #     return loss

    # This function is for future RT-DETR Segment models
    # @staticmethod
    # def _dice_loss(inputs, targets, num_gts):
    #     inputs = F.sigmoid(inputs).flatten(1)
    #     targets = targets.flatten(1)
    #     numerator = 2 * (inputs * targets).sum(1)
    #     denominator = inputs.sum(-1) + targets.sum(-1)
    #     loss = 1 - (numerator + 1) / (denominator + 1)
    #     return loss.sum() / num_gts

    def _get_loss_aux(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        match_indices=None,
        postfix="",
        masks=None,
        gt_mask=None,
    ):
        """Get auxiliary losses."""
        # NOTE: loss class, bbox, giou, mask, dice
        # Initialize a tensor to hold different types of losses
        loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
        
        # If match_indices is not provided and using uni_match, compute match_indices using matcher
        if match_indices is None and self.use_uni_match:
            match_indices = self.matcher(
                pred_bboxes[self.uni_match_ind],
                pred_scores[self.uni_match_ind],
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=masks[self.uni_match_ind] if masks is not None else None,
                gt_mask=gt_mask,
            )
        
        # Iterate over predicted boxes and scores to compute losses
        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
            aux_masks = masks[i] if masks is not None else None
            # Compute losses using _get_loss function
            loss_ = self._get_loss(
                aux_bboxes,
                aux_scores,
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=aux_masks,
                gt_mask=gt_mask,
                postfix=postfix,
                match_indices=match_indices,
            )
            # Accumulate class, bbox, and giou losses
            loss[0] += loss_[f"loss_class{postfix}"]
            loss[1] += loss_[f"loss_bbox{postfix}"]
            loss[2] += loss_[f"loss_giou{postfix}"]
            
            # Uncomment below section if handling mask and dice losses
            # if masks is not None and gt_mask is not None:
            #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
            #     loss[3] += loss_[f'loss_mask{postfix}']
            #     loss[4] += loss_[f'loss_dice{postfix}']

        # Construct a dictionary with computed losses
        loss = {
            f"loss_class_aux{postfix}": loss[0],
            f"loss_bbox_aux{postfix}": loss[1],
            f"loss_giou_aux{postfix}": loss[2],
        }
        
        # Uncomment below section if handling mask and dice losses
        # if masks is not None and gt_mask is not None:
        #     loss[f'loss_mask_aux{postfix}'] = loss[3]
        #     loss[f'loss_dice_aux{postfix}'] = loss[4]

        # Return the dictionary of computed losses
        return loss

    @staticmethod
    def _get_index(match_indices):
        """Returns batch indices, source indices, and destination indices from provided match indices."""
        # 生成一个批次索引,源索引和目标索引,从匹配索引中提取
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
        src_idx = torch.cat([src for (src, _) in match_indices])
        dst_idx = torch.cat([dst for (_, dst) in match_indices])
        return (batch_idx, src_idx), dst_idx

    def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
        """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
        # 根据匹配索引将预测的边界框分配给真实边界框
        pred_assigned = torch.cat(
            [
                t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (i, _) in zip(pred_bboxes, match_indices)
            ]
        )
        gt_assigned = torch.cat(
            [
                t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (_, j) in zip(gt_bboxes, match_indices)
            ]
        )
        return pred_assigned, gt_assigned

    def _get_loss(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        masks=None,
        gt_mask=None,
        postfix="",
        match_indices=None,
    ):
        """Get losses."""
        # 如果没有提供匹配索引,则调用self.matcher计算匹配索引
        if match_indices is None:
            match_indices = self.matcher(
                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
            )

        # 调用_get_index函数获取索引和对应的真实边界框索引
        idx, gt_idx = self._get_index(match_indices)
        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]

        bs, nq = pred_scores.shape[:2]
        # 创建一个全是self.nc的张量,作为分类目标
        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
        targets[idx] = gt_cls[gt_idx]

        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
        # 如果gt_bboxes非空,则计算预测边界框与真实边界框的IoU作为得分
        if len(gt_bboxes):
            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

        loss = {}
        # 调用_get_loss_class和_get_loss_bbox计算分类损失和边界框损失
        loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
        loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
        # 如果masks和gt_mask都不为None,则调用_get_loss_mask计算掩码损失
        # if masks is not None and gt_mask is not None:
        #     loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
        return loss
    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
        """
        Args:
            pred_bboxes (torch.Tensor): [l, b, query, 4]
            pred_scores (torch.Tensor): [l, b, query, num_classes]
            batch (dict): A dict includes:
                gt_cls (torch.Tensor) with shape [num_gts, ],
                gt_bboxes (torch.Tensor): [num_gts, 4],
                gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
            postfix (str): postfix of loss name.
        """
        # 设定当前设备为预测边界框的设备
        self.device = pred_bboxes.device
        # 获取匹配索引,如果未提供则为 None
        match_indices = kwargs.get("match_indices", None)
        # 从批次中获取真实类别、边界框和分组信息
        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]

        # 计算总损失,传入最后一个预测结果的边界框和分数,真实边界框、类别、分组信息及后缀
        total_loss = self._get_loss(
            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
        )

        # 如果有辅助损失,则添加到总损失中
        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(
                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
                )
            )

        # 返回计算得到的总损失
        return total_loss
    # 定义了一个 RT-DETR 检测损失类,继承自 DETRLoss 类
    """
    Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss.

    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
    an additional denoising training loss when provided with denoising metadata.
    """

    # 前向传播函数,用于计算检测损失
    def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
        """
        Forward pass to compute the detection loss.

        Args:
            preds (tuple): Predicted bounding boxes and scores.
            batch (dict): Batch data containing ground truth information.
            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
            dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
            dn_meta (dict, optional): Metadata for denoising. Default is None.

        Returns:
            (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
        """
        # 解析预测的边界框和分数
        pred_bboxes, pred_scores = preds
        # 计算标准检测损失
        total_loss = super().forward(pred_bboxes, pred_scores, batch)

        # 检查是否提供了去噪元数据以计算去噪训练损失
        if dn_meta is not None:
            # 提取去噪正样本索引和组数
            dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
            # 断言批数据中的组数量与去噪正样本索引长度相等
            assert len(batch["gt_groups"]) == len(dn_pos_idx)

            # 获取用于去噪的匹配索引
            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])

            # 计算去噪训练损失
            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
            total_loss.update(dn_loss)
        else:
            # 如果没有提供去噪元数据,则将去噪损失设置为零
            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})

        # 返回总损失字典
        return total_loss

    @staticmethod
    def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
        """
        Get the match indices for denoising.

        Args:
            dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
            dn_num_group (int): Number of denoising groups.
            gt_groups (List[int]): List of integers representing the number of ground truths for each image.

        Returns:
            (List[tuple]): List of tuples containing matched indices for denoising.
        """
        # 初始化一个空列表,用于存储匹配的索引
        dn_match_indices = []
        
        # 计算每个图像的累积索引组
        idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        
        # 遍历每个图像及其对应的ground truth数目
        for i, num_gt in enumerate(gt_groups):
            if num_gt > 0:
                # 生成包含所有ground truth索引的张量
                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                # 将ground truth索引重复dn_num_group次,以匹配denoising组的数目
                gt_idx = gt_idx.repeat(dn_num_group)
                # 断言:确保dn_pos_idx[i]和gt_idx长度相同
                assert len(dn_pos_idx[i]) == len(gt_idx), f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
                # 将匹配的索引对加入到dn_match_indices列表中
                dn_match_indices.append((dn_pos_idx[i], gt_idx))
            else:
                # 如果ground truth数目为0,则创建一个空的张量对
                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
        
        # 返回包含所有匹配索引对的列表
        return dn_match_indices

.\yolov8\ultralytics\models\utils\ops.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch神经网络模块
import torch.nn.functional as F  # 导入PyTorch函数模块
from scipy.optimize import linear_sum_assignment  # 导入SciPy库中的linear_sum_assignment函数

from ultralytics.utils.metrics import bbox_iou  # 导入Ultralytics工具包中的bbox_iou函数
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh  # 导入Ultralytics工具包中的坐标转换函数


class HungarianMatcher(nn.Module):
    """
    A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an
    end-to-end fashion.

    HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost
    function that considers classification scores, bounding box coordinates, and optionally, mask predictions.

    Attributes:
        cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'.
        use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation.
        with_mask (bool): Indicates whether the model makes mask predictions.
        num_sample_points (int): The number of sample points used in mask cost calculation.
        alpha (float): The alpha factor in Focal Loss calculation.
        gamma (float): The gamma factor in Focal Loss calculation.

    Methods:
        forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the
            assignment between predictions and ground truths for a batch.
        _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted.
    """

    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
        """Initializes HungarianMatcher with cost coefficients, Focal Loss, mask prediction, sample points, and alpha
        gamma factors.
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
        self.cost_gain = cost_gain  # 设置成本系数字典,包括'class', 'bbox', 'giou', 'mask', 'dice'
        self.use_fl = use_fl  # 是否使用Focal Loss进行分类成本计算
        self.with_mask = with_mask  # 模型是否进行了掩模预测
        self.num_sample_points = num_sample_points  # 掩模成本计算中使用的样本点数目
        self.alpha = alpha  # Focal Loss计算中的alpha系数
        self.gamma = gamma  # Focal Loss计算中的gamma系数

    # This function is for future RT-DETR Segment models
    # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
    #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
    #     # all masks share the same set of points for efficient matching
    #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
    #     sample_points = 2.0 * sample_points - 1.0
    #
    #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
    #     out_mask = out_mask.flatten(0, 1)
    #
    #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
    #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
    #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
    #
    # 使用 torch.amp 自动混合精度,禁用 CUDA
    with torch.amp.autocast("cuda", enabled=False):
        # 计算二进制交叉熵损失
        pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
        neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
        cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
        cost_mask /= self.num_sample_points
        
        # 计算 Dice 损失
        out_mask = F.sigmoid(out_mask)
        numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
        denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
        cost_dice = 1 - (numerator + 1) / (denominator + 1)
        
        # 计算最终的损失函数 C,结合二进制交叉熵损失和 Dice 损失,根据设定的权重
        C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
    # 返回最终的损失 C
    return C
def get_cdn_group(
    batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False
):
    """
    Get contrastive denoising training group. This function creates a contrastive denoising training group with positive
    and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates,
    and returns the modified labels, bounding boxes, attention mask and meta information.

    Args:
        batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes'
            (torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length
            indicating the number of gts of each image.
        num_classes (int): Number of classes.
        num_queries (int): Number of queries.
        class_embed (torch.Tensor): Embedding weights to map class labels to embedding space.
        num_dn (int, optional): Number of denoising. Defaults to 100.
        cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5.
        box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0.
        training (bool, optional): If it's in training mode. Defaults to False.

    Returns:
        (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings,
            bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
            is less than or equal to 0, the function returns None for all elements in the tuple.
    """

    # 如果不处于训练模式或者 num_dn 小于等于 0,则返回 None
    if (not training) or num_dn <= 0:
        return None, None, None, None

    # 从 batch 中获取 gt_groups,即每张图像中 gt 的数量列表
    gt_groups = batch["gt_groups"]
    # 计算总的 gt 数量
    total_num = sum(gt_groups)
    # 获取一个 batch 中最大的 gt 数量
    max_nums = max(gt_groups)
    
    # 如果最大的 gt 数量为 0,则返回 None
    if max_nums == 0:
        return None, None, None, None
    
    # 计算每个 group 中的数量,确保至少为 1
    num_group = num_dn // max_nums
    num_group = 1 if num_group == 0 else num_group
    
    # 获取 batch 的大小
    bs = len(gt_groups)
    
    # 从 batch 中获取 gt_cls 和 gt_bbox
    gt_cls = batch["cls"]  # (bs*num, )
    gt_bbox = batch["bboxes"]  # bs*num, 4
    b_idx = batch["batch_idx"]
    
    # 每个 group 包含正负样本
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )
    
    # 创建负样本的索引
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
    
    # 如果 cls_noise_ratio 大于 0,则对 dn_cls 应用噪声
    if cls_noise_ratio > 0:
        # 生成一个掩码,以半概率应用于 bbox
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        # 随机生成新的标签
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label
    # 如果盒子噪声比例大于0,则进行以下操作
    known_bbox = xywh2xyxy(dn_bbox)  # 将相对坐标转换为绝对坐标格式

    # 计算随机扰动的大小
    diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4

    # 生成随机符号和随机部分
    rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
    rand_part = torch.rand_like(dn_bbox)
    rand_part[neg_idx] += 1.0
    rand_part *= rand_sign

    # 添加随机扰动到已知的边界框
    known_bbox += rand_part * diff

    # 将边界框裁剪到0到1的范围内
    known_bbox.clip_(min=0.0, max=1.0)

    # 将绝对坐标格式的边界框转换回相对坐标格式
    dn_bbox = xyxy2xywh(known_bbox)

    # 对相对坐标进行逆sigmoid变换
    dn_bbox = torch.logit(dn_bbox, eps=1e-6)

num_dn = int(max_nums * 2 * num_group)  # 计算总的去噪查询数

# 创建填充的类别嵌入和边界框
dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

# 构建映射索引用于对齐去噪后的查询与原始查询
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])

# 将类别嵌入和边界框填充到填充张量中
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

tgt_size = num_dn + num_queries  # 计算目标的总大小
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)  # 创建注意力掩码

# 设定查询与重构之间的匹配不能看到
attn_mask[num_dn:, :num_dn] = True

# 设定重构之间相互不能看到
for i in range(num_group):
    if i == 0:
        attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
    if i == num_group - 1:
        attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
    else:
        attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
        attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True

# 构建去噪任务的元信息字典
dn_meta = {
    "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
    "dn_num_group": num_group,
    "dn_num_split": [num_dn, num_queries],
}

# 返回结果
return (
    padding_cls.to(class_embed.device),
    padding_bbox.to(class_embed.device),
    attn_mask.to(class_embed.device),
    dn_meta,
)
posted @ 2024-09-05 12:00  绝不原创的飞龙  阅读(2)  评论(0编辑  收藏  举报