
Transformers 源码解析(九十九)

.\models\sam\image_processing_sam.py

# 指定编码格式为 UTF-8
# 版权声明和许可信息
# 根据 Apache 许可证 2.0 版本,除非符合许可证规定,否则不得使用此文件
# 获取许可证的详细信息,请访问指定的 URL
# 如果适用法律要求或书面同意,本软件按“原样”分发,不提供任何明示或暗示的担保或条件
# 请查阅许可证了解具体语言和限制条款
"""SAM 的图像处理类。"""
# 导入所需的库和模块
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 NumPy 库,并使用 np 别名
import numpy as np

# 导入 Hugging Face 库中的图像处理相关工具和函数
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
# 导入 Hugging Face 库中的图像转换函数和工具
from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
# 导入 Hugging Face 库中的图像处理工具函数
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_kwargs,
    validate_preprocess_arguments,
)
# 导入 Hugging Face 库中的通用工具和函数
from ...utils import (
    TensorType,
    is_tf_available,
    is_torch_available,
    is_torchvision_available,
    logging,
    requires_backends,
)

# 如果 Torch 可用,则导入 Torch 库和 Torch 的功能模块
if is_torch_available():
    import torch
    import torch.nn.functional as F

# 如果 TorchVision 可用,则从 TorchVision 中导入批量 NMS 函数
if is_torchvision_available():
    from torchvision.ops.boxes import batched_nms

# 如果 TensorFlow 可用,则导入 TensorFlow 库和实验性的 NumPy 模块
if is_tf_available():
    import tensorflow as tf
    from tensorflow.experimental import numpy as tnp
    # 导入 Hugging Face 库中的 TensorFlow 相关工具函数
    from ...tf_utils import flatten, shape_list

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 定义 SAM 图像处理器类,继承自 BaseImageProcessor 类
class SamImageProcessor(BaseImageProcessor):
    r"""
    构造 SAM 图像处理器。
    """

    # 模型输入名称列表,此处仅包含 'pixel_values'
    model_input_names = ["pixel_values"]

    # SAM 图像处理器的初始化方法
    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        mask_size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = True,
        pad_size: int = None,
        mask_pad_size: int = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        # 调用父类初始化方法,传入任意关键字参数
        super().__init__(**kwargs)
        # 如果 size 为 None,则设置默认最长边为 1024 的大小字典
        size = size if size is not None else {"longest_edge": 1024}
        # 如果 size 不是字典,则调用函数获取大小字典,不默认转换为正方形
        size = get_size_dict(max_size=size, default_to_square=False) if not isinstance(size, dict) else size

        # 如果 pad_size 为 None,则设置默认高度和宽度都为 1024 的大小字典
        pad_size = pad_size if pad_size is not None else {"height": 1024, "width": 1024}
        # 调用函数获取大小字典,默认转换为正方形
        pad_size = get_size_dict(pad_size, default_to_square=True)

        # 如果 mask_size 为 None,则设置默认最长边为 256 的大小字典
        mask_size = mask_size if mask_size is not None else {"longest_edge": 256}
        # 如果 mask_size 不是字典,则调用函数获取大小字典,不默认转换为正方形
        mask_size = (
            get_size_dict(max_size=mask_size, default_to_square=False)
            if not isinstance(mask_size, dict)
            else mask_size
        )

        # 如果 mask_pad_size 为 None,则设置默认高度和宽度都为 256 的大小字典
        mask_pad_size = mask_pad_size if mask_pad_size is not None else {"height": 256, "width": 256}
        # 调用函数获取大小字典,默认转换为正方形
        mask_pad_size = get_size_dict(mask_pad_size, default_to_square=True)

        # 设置属性,是否进行 resize 操作
        self.do_resize = do_resize
        # 设置属性,图片大小的大小字典
        self.size = size
        # 设置属性,mask 的大小字典
        self.mask_size = mask_size
        # 设置属性,重采样方式
        self.resample = resample
        # 设置属性,是否进行 rescale 操作
        self.do_rescale = do_rescale
        # 设置属性,rescale 的因子
        self.rescale_factor = rescale_factor
        # 设置属性,是否进行 normalize 操作
        self.do_normalize = do_normalize
        # 设置属性,图片的均值,默认为 IMAGENET 的默认均值
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        # 设置属性,图片的标准差,默认为 IMAGENET 的默认标准差
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        # 设置属性,是否进行 pad 操作
        self.do_pad = do_pad
        # 设置属性,pad 的大小字典
        self.pad_size = pad_size
        # 设置属性,mask 的 pad 的大小字典
        self.mask_pad_size = mask_pad_size
        # 设置属性,是否进行 RGB 转换
        self.do_convert_rgb = do_convert_rgb
        # 设置属性,有效的处理器键列表
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "mask_size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_pad",
            "pad_size",
            "mask_pad_size",
            "do_convert_rgb",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]
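
As a quick illustration of how these defaults behave, here is a short usage sketch (not part of the source file; it assumes `transformers` and `torch` are installed). A dummy NumPy image is resized so that its longest edge becomes 1024 and is then zero-padded to 1024x1024:

import numpy as np
from transformers import SamImageProcessor

processor = SamImageProcessor()  # size={"longest_edge": 1024}, pad_size={"height": 1024, "width": 1024}

image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)  # dummy (height, width, channels) image
inputs = processor(images=image, return_tensors="pt")

print(inputs["pixel_values"].shape)    # expected: torch.Size([1, 3, 1024, 1024]) after resize + pad
print(inputs["original_sizes"])        # expected: tensor([[480, 640]])
print(inputs["reshaped_input_sizes"])  # expected: tensor([[768, 1024]]), longest edge scaled to 1024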

    def pad_image(
        self,
        image: np.ndarray,
        pad_size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Pad an image to `(pad_size["height"], pad_size["width"])` with zeros to the right and bottom.

        Args:
            image (`np.ndarray`):
                Image to pad.
            pad_size (`Dict[str, int]`):
                Size of the output image after padding.
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
                `data_format` of the `image` will be used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # 获取目标填充后的输出高度和宽度
        output_height, output_width = pad_size["height"], pad_size["width"]
        # 获取输入图像的高度和宽度
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        # 计算需要填充的宽度和高度
        pad_width = output_width - input_width
        pad_height = output_height - input_height

        # 使用零填充图像到目标大小
        padded_image = pad(
            image,
            ((0, pad_height), (0, pad_width)),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return padded_image
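
The padding is applied only to the bottom and right, so the original pixels keep their coordinates. A minimal NumPy sketch of the same idea (an illustration, not part of the source; channels-last layout assumed):

import numpy as np

image = np.ones((3, 4, 3), dtype=np.float32)  # (height=3, width=4, channels=3)
pad_height, pad_width = 1024 - 3, 1024 - 4
padded = np.pad(image, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0)
print(padded.shape)          # (1024, 1024, 3)
print(padded[:3, :4].sum())  # 36.0 -> the original content stays untouched in the top-left corner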

    def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
        """
        Compute the output size given input size and target long side length.
        """
        # 获取输入形状的高度和宽度
        oldh, oldw = old_shape
        # 计算缩放比例,确保最长边达到目标长度
        scale = longest_edge * 1.0 / max(oldh, oldw)
        # 计算新的高度和宽度
        newh, neww = oldh * scale, oldw * scale
        # 四舍五入并转换为整数
        newh = int(newh + 0.5)
        neww = int(neww + 0.5)
        return (newh, neww)
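
For example, with the default `longest_edge=1024`, a 480x640 input maps to (768, 1024): the scale factor is 1024/640 = 1.6 and both sides are rounded to the nearest integer. A standalone copy of the computation (illustration only):

def get_preprocess_shape(old_shape, longest_edge):
    oldh, oldw = old_shape
    scale = longest_edge * 1.0 / max(oldh, oldw)
    return int(oldh * scale + 0.5), int(oldw * scale + 0.5)

assert get_preprocess_shape((480, 640), 1024) == (768, 1024)
assert get_preprocess_shape((640, 480), 1024) == (1024, 768)
assert get_preprocess_shape((1024, 1024), 1024) == (1024, 1024)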

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"longest_edge": int}` specifying the size of the output image. The longest
                edge of the image will be resized to the specified size, while the other edge will be resized to
                maintain the aspect ratio.
            resample:
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)  # 调用函数 get_size_dict 将 size 转换为标准格式的尺寸字典
        if "longest_edge" not in size:
            raise ValueError(f"The `size` dictionary must contain the key `longest_edge`. Got {size.keys()}")
        input_size = get_image_size(image, channel_dim=input_data_format)  # 获取输入图片的尺寸信息
        output_height, output_width = self._get_preprocess_shape(input_size, size["longest_edge"])  # 根据输入和输出尺寸计算预处理后的图像尺寸
        return resize(
            image,
            size=(output_height, output_width),  # 调整图像大小为指定的输出尺寸
            resample=resample,  # 使用指定的重采样方法
            data_format=data_format,  # 输出图像的通道顺序格式
            input_data_format=input_data_format,  # 输入图像的通道顺序格式
            **kwargs,  # 其它可选参数
        )

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[Dict[str, int]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        # 如果需要调整大小,则调用 resize 方法调整图像大小
        if do_resize:
            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
        # 获取调整后图像的大小
        reshaped_input_size = get_image_size(image, channel_dim=input_data_format)

        # 如果需要重新缩放,则调用 rescale 方法重新缩放图像
        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        # 如果需要归一化,则调用 normalize 方法对图像进行归一化处理
        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        # 如果需要填充,则调用 pad_image 方法对图像进行填充处理
        if do_pad:
            image = self.pad_image(image=image, pad_size=pad_size, input_data_format=input_data_format)

        # 返回预处理后的图像及其调整前后的大小信息
        return image, reshaped_input_size

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[Dict[str, int]] = None,
        do_convert_rgb: Optional[bool] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
        # 将图像转换为 numpy 数组
        image = to_numpy_array(image)

        # 如果需要将 PIL RGBA 图像转换为 RGB 格式
        if do_convert_rgb:
            image = convert_to_rgb(image)

        # 所有的转换操作都期望输入为 numpy 数组
        image = to_numpy_array(image)

        # 如果输入图像已经进行了缩放并且需要重新缩放,则发出警告
        if is_scaled_image(image) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        # 推断输入数据的通道维度格式
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        # 获取原始图像的大小
        original_size = get_image_size(image, channel_dim=input_data_format)

        # 对图像进行预处理,并获取预处理后的图像及其调整后的大小信息
        image, reshaped_input_size = self._preprocess(
            image=image,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            pad_size=pad_size,
            input_data_format=input_data_format,
        )

        # 如果指定了输出数据格式,则将图像转换为该格式
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

        # 返回最终的预处理结果,包括图像及其原始大小和调整后的大小信息
        return image, original_size, reshaped_input_size
    # 对分割地图进行预处理,返回处理后的分割地图和原始尺寸
    def _preprocess_mask(
        self,
        segmentation_map: ImageInput,
        do_resize: Optional[bool] = None,
        mask_size: Dict[str, int] = None,
        do_pad: Optional[bool] = None,
        mask_pad_size: Optional[Dict[str, int]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        # 将分割地图转换为 NumPy 数组
        segmentation_map = to_numpy_array(segmentation_map)

        # 如果分割地图是二维的,则添加通道维度,某些转换需要此维度
        if segmentation_map.ndim == 2:
            added_channel_dim = True
            segmentation_map = segmentation_map[None, ...]  # 添加通道维度
            input_data_format = ChannelDimension.FIRST  # 设置数据格式为通道维度在最前面
        else:
            added_channel_dim = False
            if input_data_format is None:
                # 推断通道维度格式,确保一维通道格式
                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)

        # 获取原始图像尺寸
        original_size = get_image_size(segmentation_map, channel_dim=input_data_format)

        # 对分割地图进行预处理,包括调整大小、填充等操作
        segmentation_map, _ = self._preprocess(
            image=segmentation_map,
            do_resize=do_resize,
            size=mask_size,
            resample=PILImageResampling.NEAREST,
            do_rescale=False,
            do_normalize=False,
            do_pad=do_pad,
            pad_size=mask_pad_size,
            input_data_format=input_data_format,
        )

        # 如果之前添加了额外的通道维度,则在处理完成后去除
        if added_channel_dim:
            segmentation_map = segmentation_map.squeeze(0)  # 去除添加的通道维度
        segmentation_map = segmentation_map.astype(np.int64)  # 将分割地图转换为整型

        # 返回处理后的分割地图和原始尺寸
        return segmentation_map, original_size

    # 对图像及其分割地图进行预处理,支持多种处理选项
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        mask_size: Optional[Dict[str, int]] = None,
        resample: Optional["PILImageResampling"] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[Union[int, float]] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        pad_size: Optional[Dict[str, int]] = None,
        mask_pad_size: Optional[Dict[str, int]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        # Preprocesses images and optional segmentation maps with the configured resize / rescale / normalize / pad
        # options (the full implementation is omitted in this excerpt)
        pass

    # 对预处理后的掩模进行后处理,包括阈值处理和二值化等
    def post_process_masks(
        self,
        masks,
        original_sizes,
        reshaped_input_sizes,
        mask_threshold=0.0,
        binarize=True,
        pad_size=None,
        return_tensors="pt",
        **kwargs,
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray], List[tf.Tensor]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            reshaped_input_sizes (`Union[torch.Tensor, tf.Tensor, List[Tuple[int,int]]]`):
                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for binarizing the masks.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            pad_size (`int`, *optional*, defaults to `self.pad_size`):
                The target size the images were padded to before being passed to the model. If None, the target size is
                assumed to be the processor's `pad_size`.
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                If `"pt"`, return PyTorch tensors. If `"tf"`, return TensorFlow tensors.
        Returns:
            (`Union[torch.Tensor, tf.Tensor]`): Batched masks in batch_size, num_channels, height, width) format, where
            (height, width) is given by original_size.
        """
        # 根据 return_tensors 参数选择使用 PyTorch 或 TensorFlow 的后处理函数
        if return_tensors == "pt":
            return self._post_process_masks_pt(
                masks=masks,
                original_sizes=original_sizes,
                reshaped_input_sizes=reshaped_input_sizes,
                mask_threshold=mask_threshold,
                binarize=binarize,
                pad_size=pad_size,
            )
        elif return_tensors == "tf":
            return self._post_process_masks_tf(
                masks=masks,
                original_sizes=original_sizes,
                reshaped_input_sizes=reshaped_input_sizes,
                mask_threshold=mask_threshold,
                binarize=binarize,
                pad_size=pad_size,
            )
        else:
            # 如果 return_tensors 参数既不是 "pt" 也不是 "tf",抛出数值错误异常
            raise ValueError("return_tensors must be either 'pt' or 'tf'")

    def _post_process_masks_pt(
        self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`Union[List[torch.Tensor], List[np.ndarray]]`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The original sizes of each image before it was resized to the model's expected input shape, in (height,
                width) format.
            reshaped_input_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for binarizing the masks.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            pad_size (`int`, *optional*, defaults to `self.pad_size`):
                The target size the images were padded to before being passed to the model. If None, the target size is
                assumed to be the processor's `pad_size`.
        Returns:
            (`torch.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width)
            is given by original_size.
        """
        requires_backends(self, ["torch"])  # 确保当前环境支持 torch
        pad_size = self.pad_size if pad_size is None else pad_size  # 如果未指定 pad_size,则使用类的默认值
        target_image_size = (pad_size["height"], pad_size["width"])  # 获取目标图像尺寸 (height, width)

        if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
            original_sizes = original_sizes.tolist()  # 将 original_sizes 转换为列表形式
        if isinstance(reshaped_input_sizes, (torch.Tensor, np.ndarray)):
            reshaped_input_sizes = reshaped_input_sizes.tolist()  # 将 reshaped_input_sizes 转换为列表形式

        output_masks = []  # 初始化空列表,用于存储输出的 masks

        for i, original_size in enumerate(original_sizes):
            if isinstance(masks[i], np.ndarray):
                masks[i] = torch.from_numpy(masks[i])  # 如果 masks[i] 是 np.ndarray,则转换为 torch.Tensor
            elif not isinstance(masks[i], torch.Tensor):
                raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")

            # 插值操作,将 masks[i] 缩放到 target_image_size 大小
            interpolated_mask = F.interpolate(masks[i], target_image_size, mode="bilinear", align_corners=False)
            # 截取插值后的结果,保留至 reshaped_input_sizes[i] 大小
            interpolated_mask = interpolated_mask[..., : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1]]
            # 再次插值,将 interpolated_mask 缩放至 original_size 大小
            interpolated_mask = F.interpolate(interpolated_mask, original_size, mode="bilinear", align_corners=False)

            if binarize:
                interpolated_mask = interpolated_mask > mask_threshold  # 根据阈值进行二值化处理

            output_masks.append(interpolated_mask)  # 将处理后的 mask 添加到输出列表中

        return output_masks  # 返回处理后的输出 masks
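
A usage sketch for the PyTorch path (illustration only, assuming `transformers` and `torch` are installed): the mask decoder produces low-resolution logits of shape (batch, point_batch, num_masks, 256, 256), which are upscaled back to each original image size.

import torch
from transformers import SamImageProcessor

processor = SamImageProcessor()
low_res_masks = torch.randn(1, 1, 3, 256, 256)  # dummy decoder output
masks = processor.post_process_masks(
    low_res_masks, original_sizes=[(480, 640)], reshaped_input_sizes=[(768, 1024)]
)
print(masks[0].shape)  # expected: torch.Size([1, 3, 480, 640]), boolean masks after thresholding at 0.0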

    def _post_process_masks_tf(
        self, masks, original_sizes, reshaped_input_sizes, mask_threshold=0.0, binarize=True, pad_size=None
    ):
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (`tf.Tensor`):
                Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
            original_sizes (`tf.Tensor`):
                The original size of the images before resizing for input to the model, in (height, width) format.
            reshaped_input_sizes (`tf.Tensor`):
                The size of the image input to the model, in (height, width) format. Used to remove padding.
            mask_threshold (`float`, *optional*, defaults to 0.0):
                The threshold to use for binarizing the masks.
            binarize (`bool`, *optional*, defaults to `True`):
                Whether to binarize the masks.
            pad_size (`int`, *optional*, defaults to `self.pad_size`):
                The target size the images were padded to before being passed to the model. If None, the target size is
                assumed to be the processor's `pad_size`.
        Returns:
            (`tf.Tensor`): Batched masks in batch_size, num_channels, height, width) format, where (height, width) is
            given by original_size.
        """
        # Ensure the necessary backend for operations
        requires_backends(self, ["tf"])
        
        # Determine the padding size to use
        pad_size = self.pad_size if pad_size is None else pad_size
        target_image_size = (pad_size["height"], pad_size["width"])

        output_masks = []
        for i, original_size in enumerate(original_sizes):
            # Transpose masks to NHWC format as required by tf.image functions
            mask = tf.transpose(masks[i], perm=[0, 2, 3, 1])
            
            # Resize masks to match target_image_size
            interpolated_mask = tf.image.resize(mask, target_image_size, method="bilinear")
            
            # Remove padding from resized masks based on reshaped_input_sizes
            interpolated_mask = interpolated_mask[:, : reshaped_input_sizes[i][0], : reshaped_input_sizes[i][1], :]
            
            # Resize masks to original_size
            interpolated_mask = tf.image.resize(interpolated_mask, original_size, method="bilinear")
            
            # Binarize masks if specified
            if binarize:
                interpolated_mask = interpolated_mask > mask_threshold
            
            # Transpose masks back to original NCHW format
            output_masks.append(tf.transpose(interpolated_mask, perm=[0, 3, 1, 2]))

        return output_masks

    def post_process_for_mask_generation(
        self, all_masks, all_scores, all_boxes, crops_nms_thresh, return_tensors="pt"
    ):
        """
        Post processes mask that are generated by calling the Non Maximum Suppression algorithm on the predicted masks.

        Args:
            all_masks (`Union[List[torch.Tensor], List[tf.Tensor]]`):
                List of all predicted segmentation masks
            all_scores (`Union[List[torch.Tensor], List[tf.Tensor]]`):
                List of all predicted iou scores
            all_boxes (`Union[List[torch.Tensor], List[tf.Tensor]]`):
                List of all bounding boxes of the predicted masks
            crops_nms_thresh (`float`):
                Threshold for NMS (Non Maximum Suppression) algorithm.
            return_tensors (`str`, *optional*, defaults to `pt`):
                If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
        """
        if return_tensors == "pt":
            # 如果返回类型是 `pt`,调用 _postprocess_for_mg 函数对预测的mask进行后处理
            return _postprocess_for_mg(all_masks, all_scores, all_boxes, crops_nms_thresh)
        elif return_tensors == "tf":
            # 如果返回类型是 `tf`,调用 _postprocess_for_mg_tf 函数对预测的mask进行后处理
            return _postprocess_for_mg_tf(all_masks, all_scores, all_boxes, crops_nms_thresh)

    def generate_crop_boxes(
        self,
        image,
        target_size,
        crop_n_layers: int = 0,
        overlap_ratio: float = 512 / 1500,
        points_per_crop: Optional[int] = 32,
        crop_n_points_downscale_factor: Optional[List[int]] = 1,
        device: Optional["torch.device"] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        return_tensors: str = "pt",
    ):
        """
        Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.

        Args:
            image (`np.array`):
                Input original image
            target_size (`int`):
                Target size of the resized image
            crop_n_layers (`int`, *optional*, defaults to 0):
                If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where
                each layer has 2**i_layer number of image crops.
            overlap_ratio (`float`, *optional*, defaults to 512/1500):
                Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
                the image length. Later layers with more crops scale down this overlap.
            points_per_crop (`int`, *optional*, defaults to 32):
                Number of points to sample from each crop.
            crop_n_points_downscale_factor (`List[int]`, *optional*, defaults to 1):
                The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
            device (`torch.device`, *optional*, defaults to None):
                Device to use for the computation. If None, cpu will be used.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
            return_tensors (`str`, *optional*, defaults to `pt`):
                If `pt`, returns `torch.Tensor`. If `tf`, returns `tf.Tensor`.
        """
        # Generate crop boxes, sample points, cropped images, and labels from the input image
        crop_boxes, points_per_crop, cropped_images, input_labels = _generate_crop_boxes(
            image,
            target_size,
            crop_n_layers,
            overlap_ratio,
            points_per_crop,
            crop_n_points_downscale_factor,
            input_data_format,
        )
        # Convert outputs to PyTorch tensors if return_tensors is 'pt'
        if return_tensors == "pt":
            # If device is not specified, default to CPU
            if device is None:
                device = torch.device("cpu")
            # Convert crop boxes, points_per_crop, and input_labels to PyTorch tensors
            crop_boxes = torch.tensor(crop_boxes, device=device)
            points_per_crop = torch.tensor(points_per_crop, device=device)
            # cropped_images remains as NumPy array
            input_labels = torch.tensor(input_labels, device=device)

        # Convert outputs to TensorFlow tensors if return_tensors is 'tf'
        elif return_tensors == "tf":
            # TensorFlow does not support device specification in this context
            if device is not None:
                raise ValueError("device is not a supported argument when return_tensors is tf!")
            # Convert crop boxes, points_per_crop, and input_labels to TensorFlow tensors
            crop_boxes = tf.convert_to_tensor(crop_boxes)
            points_per_crop = tf.convert_to_tensor(points_per_crop)
            # cropped_images remains as NumPy array
            input_labels = tf.convert_to_tensor(input_labels)
        else:
            # Raise an error if return_tensors is neither 'pt' nor 'tf'
            raise ValueError("return_tensors must be either 'pt' or 'tf'.")
        # Return generated crop boxes, points per crop, cropped images, and input labels
        return crop_boxes, points_per_crop, cropped_images, input_labels
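
A usage sketch for crop generation (illustration only, assuming `transformers` and `torch` are installed): with `crop_n_layers=1`, the full image plus a 2x2 grid of overlapping crops is produced, together with a grid of point prompts for each crop.

import numpy as np
from transformers import SamImageProcessor

processor = SamImageProcessor()
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)

crop_boxes, points_per_crop, cropped_images, input_labels = processor.generate_crop_boxes(
    image, target_size=1024, crop_n_layers=1, points_per_crop=32, return_tensors="pt"
)
print(crop_boxes.shape[0])  # expected: 5 crops, i.e. the full image plus 2**2 crops for layer 1
print(len(cropped_images))  # expected: 5 cropped numpy arrays
print(input_labels.shape)   # dummy point labels (all ones), one per sampled point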
        """
        根据给定的条件过滤预测的掩码,并执行必要的转换和填充操作。

        Args:
            masks (`Union[torch.Tensor, tf.Tensor]`):
                输入的掩码张量。
            iou_scores (`Union[torch.Tensor, tf.Tensor]`):
                IoU(Intersection over Union)分数的列表。
            original_size (`Tuple[int,int]`):
                原始图像的尺寸。
            cropped_box_image (`np.array`):
                裁剪后的图像数组。
            pred_iou_thresh (`float`, *optional*, 默认为 0.88):
                IoU 分数的阈值。
            stability_score_thresh (`float`, *optional*, 默认为 0.95):
                稳定性分数的阈值。
            mask_threshold (`float`, *optional*, 默认为 0):
                预测掩码的阈值。
            stability_score_offset (`float`, *optional*, 默认为 1):
                在 `_compute_stability_score` 方法中使用的稳定性分数的偏移量。
            return_tensors (`str`, *optional*, 默认为 `pt`):
                如果是 `pt`,返回 `torch.Tensor`;如果是 `tf`,返回 `tf.Tensor`。
        """
        if return_tensors == "pt":
            # 调用基于 PyTorch 的掩码过滤方法
            return self._filter_masks_pt(
                masks=masks,
                iou_scores=iou_scores,
                original_size=original_size,
                cropped_box_image=cropped_box_image,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                mask_threshold=mask_threshold,
                stability_score_offset=stability_score_offset,
            )
        elif return_tensors == "tf":
            # 调用基于 TensorFlow 的掩码过滤方法
            return self._filter_masks_tf(
                masks=masks,
                iou_scores=iou_scores,
                original_size=original_size,
                cropped_box_image=cropped_box_image,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                mask_threshold=mask_threshold,
                stability_score_offset=stability_score_offset,
            )
        """
        Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
        that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
        bounding boxes and pad the predicted masks if necessary.

        Args:
            masks (`torch.Tensor`):
                Input masks.
            iou_scores (`torch.Tensor`):
                List of IoU scores.
            original_size (`Tuple[int,int]`):
                Size of the original image.
            cropped_box_image (`np.array`):
                The cropped image.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                The threshold for the IoU scores.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                The threshold for the stability score.
            mask_threshold (`float`, *optional*, defaults to 0):
                The threshold for the predicted masks.
            stability_score_offset (`float`, *optional*, defaults to 1):
                The offset for the stability score used in the `_compute_stability_score` method.

        """
        # Ensure the torch backend is available
        requires_backends(self, ["torch"])
        
        # Extract dimensions of the original image
        original_height, original_width = original_size
        
        # Flatten masks and IoU scores for easier manipulation
        iou_scores = iou_scores.flatten(0, 1)
        masks = masks.flatten(0, 1)
        
        # Check if the number of masks matches the number of IoU scores
        if masks.shape[0] != iou_scores.shape[0]:
            raise ValueError("masks and iou_scores must have the same batch size.")
        
        # Ensure masks and IoU scores are on the same device
        if masks.device != iou_scores.device:
            iou_scores = iou_scores.to(masks.device)
        
        # Determine batch size from the flattened masks
        batch_size = masks.shape[0]
        
        # Initialize a mask to keep all masks (defaulting to True)
        keep_mask = torch.ones(batch_size, dtype=torch.bool, device=masks.device)
        
        # Apply filtering based on IoU threshold
        if pred_iou_thresh > 0.0:
            keep_mask = keep_mask & (iou_scores > pred_iou_thresh)
        
        # Compute stability scores and filter based on stability score threshold
        if stability_score_thresh > 0.0:
            stability_scores = _compute_stability_score_pt(masks, mask_threshold, stability_score_offset)
            keep_mask = keep_mask & (stability_scores > stability_score_thresh)
        
        # Select scores and masks that meet the criteria
        scores = iou_scores[keep_mask]
        masks = masks[keep_mask]
        
        # Binarize masks and convert them to bounding boxes
        masks = masks > mask_threshold
        converted_boxes = _batched_mask_to_box(masks)
        
        # Check if boxes are near the cropped image edges and filter accordingly
        keep_mask = ~_is_box_near_crop_edge(
            converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
        )
        
        # Select final scores, masks, and converted boxes
        scores = scores[keep_mask]
        masks = masks[keep_mask]
        converted_boxes = converted_boxes[keep_mask]
        
        # Pad masks to original image dimensions and convert to RLE format
        masks = _pad_masks(masks, cropped_box_image, original_height, original_width)
        masks = _mask_to_rle_pytorch(masks)
        
        return masks, scores, converted_boxes

    def _filter_masks_tf(
        self,
        masks,
        iou_scores,
        original_size,
        cropped_box_image,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        mask_threshold=0,
        stability_score_offset=1,
    ):
        """
        Filters the predicted masks by selecting only the ones that meets several criteria. The first criterion being
        that the iou scores needs to be greater than `pred_iou_thresh`. The second criterion is that the stability
        score needs to be greater than `stability_score_thresh`. The method also converts the predicted masks to
        bounding boxes and pad the predicted masks if necessary.

        Args:
            masks (`tf.Tensor`):
                Input masks.
            iou_scores (`tf.Tensor`):
                List of IoU scores.
            original_size (`Tuple[int,int]`):
                Size of the original image.
            cropped_box_image (`np.array`):
                The cropped image.
            pred_iou_thresh (`float`, *optional*, defaults to 0.88):
                The threshold for the iou scores.
            stability_score_thresh (`float`, *optional*, defaults to 0.95):
                The threshold for the stability score.
            mask_threshold (`float`, *optional*, defaults to 0):
                The threshold for the predicted masks.
            stability_score_offset (`float`, *optional*, defaults to 1):
                The offset for the stability score used in the `_compute_stability_score` method.

        """
        # Ensure necessary backend support for TensorFlow
        requires_backends(self, ["tf"])
        # Extract dimensions of the original image
        original_height, original_width = original_size
        # Reshape IoU scores tensor for processing
        iou_scores = tf.reshape(iou_scores, [iou_scores.shape[0] * iou_scores.shape[1], iou_scores.shape[2:]])
        # Reshape masks tensor for processing
        masks = tf.reshape(masks, [masks.shape[0] * masks.shape[1], masks.shape[2:]])

        # Check if batch sizes of masks and IoU scores match
        if masks.shape[0] != iou_scores.shape[0]:
            raise ValueError("masks and iou_scores must have the same batch size.")

        # Retrieve batch size from masks tensor
        batch_size = masks.shape[0]

        # Initialize a mask to keep all elements
        keep_mask = tf.ones(batch_size, dtype=tf.bool)

        # Apply filter based on IoU threshold if specified
        if pred_iou_thresh > 0.0:
            keep_mask = keep_mask & (iou_scores > pred_iou_thresh)

        # Compute stability scores and apply filter based on stability score threshold if specified
        if stability_score_thresh > 0.0:
            stability_scores = _compute_stability_score_tf(masks, mask_threshold, stability_score_offset)
            keep_mask = keep_mask & (stability_scores > stability_score_thresh)

        # Filter out masks and scores based on the keep_mask
        scores = iou_scores[keep_mask]
        masks = masks[keep_mask]

        # Binarize masks
        masks = masks > mask_threshold

        # Convert masks to bounding boxes
        converted_boxes = _batched_mask_to_box_tf(masks)

        # Filter out boxes near the cropped image edges
        keep_mask = ~_is_box_near_crop_edge_tf(
            converted_boxes, cropped_box_image, [0, 0, original_width, original_height]
        )

        # Refilter masks, scores, and converted boxes based on the updated keep_mask
        scores = scores[keep_mask]
        masks = masks[keep_mask]
        converted_boxes = converted_boxes[keep_mask]

        # Pad masks to match original image dimensions
        masks = _pad_masks_tf(masks, cropped_box_image, original_height, original_width)

        # Convert masks to RLE format for non-maximum suppression
        masks = _mask_to_rle_tf(masks)

        # Return filtered masks, scores, and converted boxes
        return masks, scores, converted_boxes


def _compute_stability_score_pt(masks: "torch.Tensor", mask_threshold: float, stability_score_offset: int):
    # The stability score of a predicted mask is the IoU between its binarizations at (mask_threshold +
    # stability_score_offset) and (mask_threshold - stability_score_offset); the stricter mask is always contained
    # in the looser one. Intermediate int16/int32 sums avoid an unnecessary cast to torch.int64 and save memory.
    intersections = (
        (masks > (mask_threshold + stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    )
    
    # Union: number of pixels above the looser threshold, using the same intermediate dtypes to save memory.
    unions = (masks > (mask_threshold - stability_score_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
    
    # 使用交集和并集数量计算稳定性分数。
    stability_scores = intersections / unions
    
    # 返回稳定性分数。
    return stability_scores
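
A small worked example (illustration only): for a logit map that is confidently positive inside the object and confidently negative elsewhere, the binarizations at (threshold + offset) and (threshold - offset) select the same pixels, so the stability score is 1. The snippet below mirrors the computation with mask_threshold=0 and stability_score_offset=1.

import torch

masks = torch.full((1, 1, 4, 4), -10.0)
masks[..., 1:3, 1:3] = 10.0  # a confident 2x2 object in the middle
intersections = (masks > (0.0 + 1.0)).sum((-1, -2)).float()
unions = (masks > (0.0 - 1.0)).sum((-1, -2)).float()
print((intersections / unions).item())  # expected: 1.0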


# TensorFlow version of the same computation: the non-zero counts are accumulated as float32 so that the final
# division yields a floating-point IoU rather than an integer result.
def _compute_stability_score_tf(masks: "tf.Tensor", mask_threshold: float, stability_score_offset: int):
    intersections = tf.math.count_nonzero(
        masks > (mask_threshold + stability_score_offset), axis=[-1, -2], dtype=tf.float32
    )
    unions = tf.math.count_nonzero(masks > (mask_threshold - stability_score_offset), axis=[-1, -2], dtype=tf.float32)
    stability_scores = intersections / unions
    
    # 返回稳定性分数。
    return stability_scores


# 生成2D网格点列表,这些点在[0,1]x[0,1]区间用等间距插入。
def _build_point_grid(n_per_side: int) -> np.ndarray:
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points
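
For instance, `_build_point_grid(2)` (assuming the helper above is in scope) places four points at the centers of a 2x2 partition of the unit square:

print(_build_point_grid(2))
# [[0.25 0.25]
#  [0.75 0.25]
#  [0.25 0.75]
#  [0.75 0.75]]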


# 对坐标进行标准化,以适应给定的目标尺寸,考虑到原始尺寸和所需的乘法因子。
def _normalize_coordinates(
    target_size: int, coords: np.ndarray, original_size: Tuple[int, int], is_bounding_box=False
) -> np.ndarray:
    old_height, old_width = original_size
    scale = target_size * 1.0 / max(old_height, old_width)
    new_height, new_width = old_height * scale, old_width * scale
    new_width, new_height = int(new_width + 0.5), int(new_height + 0.5)

    # 复制输入数组并将其转换为float类型。
    coords = deepcopy(coords).astype(float)

    if is_bounding_box:
        coords = coords.reshape(-1, 2, 2)

    # 标准化坐标值。
    coords[..., 0] = coords[..., 0] * (new_width / old_width)
    coords[..., 1] = coords[..., 1] * (new_height / old_height)

    if is_bounding_box:
        coords = coords.reshape(-1, 4)

    # 返回标准化的坐标。
    return coords
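
A worked example (illustration only, assuming the helper above is in scope): with `target_size=1024` and an original 480x640 image, the image is rescaled to 768x1024, so a point at (x=320, y=240) in original pixel coordinates moves to (512, 384):

point = np.array([[320.0, 240.0]])                      # (x, y) in the original 480x640 image
print(_normalize_coordinates(1024, point, (480, 640)))  # expected: [[512. 384.]]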


# Generates crop boxes for automatic mask generation; the number of crop layers, the overlap ratio, the number of
# points sampled per crop and the per-layer points downscale factor are all configurable.
def _generate_crop_boxes(
    image,
    target_size: int,  # target size of the longest edge after preprocessing
    crop_n_layers: int = 0,
    overlap_ratio: float = 512 / 1500,
    points_per_crop: Optional[int] = 32,
    crop_n_points_downscale_factor: Optional[List[int]] = 1,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[List[List[int]], List[int]]:
    """
    根据指定的参数生成不同的大小的截取框列表,每个层级不同大小的包含 (2^i)^2 个框。
    """
    # 如果输入的图像是列表,则抛出数值错误,仅支持单张图像进行裁剪生成
    if isinstance(image, list):
        raise ValueError("Only one image is allowed for crop generation.")
    
    # 将图像转换为 numpy 数组格式,确保后续处理的统一性
    image = to_numpy_array(image)
    
    # 获取原始图像的尺寸,根据输入数据格式获取
    original_size = get_image_size(image, input_data_format)
    
    # 初始化一个空列表,用于存储各层次的点网格
    points_grid = []
    for i in range(crop_n_layers + 1):
        # 计算每个裁剪区域的采样点数,根据指定的下采样因子进行缩放
        n_points = int(points_per_crop / (crop_n_points_downscale_factor**i))
        # 构建当前层次的点网格并添加到列表中
        points_grid.append(_build_point_grid(n_points))
    
    # 生成裁剪框和层次索引,确定各个裁剪区域在原始图像中的位置和层次
    crop_boxes, layer_idxs = _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size)
    
    # 根据生成的裁剪框和点网格,裁剪原始图像并生成裁剪后的图像以及对应的点网格
    cropped_images, point_grid_per_crop = _generate_crop_images(
        crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format
    )
    
    # 将裁剪框转换为 numpy 数组格式,并将数据类型设置为 float32
    crop_boxes = np.array(crop_boxes)
    crop_boxes = crop_boxes.astype(np.float32)
    
    # 将每个裁剪区域的点网格转换为 numpy 数组格式,并调整维度顺序以匹配后续处理的要求
    points_per_crop = np.array([point_grid_per_crop])
    points_per_crop = np.transpose(points_per_crop, axes=(0, 2, 1, 3))
    
    # 生成输入标签,初始化为与点网格相同大小的全 1 数组,数据类型为 int64
    input_labels = np.ones_like(points_per_crop[:, :, :, 0], dtype=np.int64)
    
    # 返回生成的裁剪框、点网格、裁剪后的图像和对应的输入标签
    return crop_boxes, points_per_crop, cropped_images, input_labels
# 生成每层裁剪框,以XYWH格式表示。XYWH格式包含以下必需索引:
#   - X:边界框左上角的X坐标
#   - Y:边界框左上角的Y坐标
#   - W:边界框的宽度
#   - H:边界框的高度
def _generate_per_layer_crops(crop_n_layers, overlap_ratio, original_size):
    crop_boxes, layer_idxs = [], []  # 初始化裁剪框列表和层索引列表
    im_height, im_width = original_size  # 获取原始图像的高度和宽度
    short_side = min(im_height, im_width)  # 计算图像的较短边

    # 原始图像
    crop_boxes.append([0, 0, im_width, im_height])  # 将整个图像作为一个裁剪框添加到列表中
    layer_idxs.append(0)  # 第一层的索引为0

    # 对于每一层裁剪
    for i_layer in range(crop_n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)  # 计算每边的裁剪数量
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))  # 计算重叠区域大小

        # 计算裁剪框的宽度和高度
        crop_width = int(math.ceil((overlap * (n_crops_per_side - 1) + im_width) / n_crops_per_side))
        crop_height = int(math.ceil((overlap * (n_crops_per_side - 1) + im_height) / n_crops_per_side))

        # 计算每个裁剪框的左上角坐标
        crop_box_x0 = [int((crop_width - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_height - overlap) * i) for i in range(n_crops_per_side)]

        # 对每个左上角坐标组合进行裁剪框的生成
        for left, top in product(crop_box_x0, crop_box_y0):
            box = [left, top, min(left + crop_width, im_width), min(top + crop_height, im_height)]
            crop_boxes.append(box)  # 将裁剪框添加到裁剪框列表中
            layer_idxs.append(i_layer + 1)  # 添加相应层索引到层索引列表中

    return crop_boxes, layer_idxs  # 返回裁剪框列表和层索引列表
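
A worked example (illustration only, assuming the helper above is in scope): for a 480x640 image with `crop_n_layers=1` and the default overlap ratio, the first box covers the whole image and layer 1 adds a 2x2 grid of overlapping boxes in XYXY format:

boxes, layers = _generate_per_layer_crops(1, 512 / 1500, (480, 640))
print(layers)      # expected: [0, 1, 1, 1, 1]
print(boxes[0])    # expected: [0, 0, 640, 480] -> the full image
print(len(boxes))  # expected: 5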


# 生成裁剪图像
def _generate_crop_images(
    crop_boxes, image, points_grid, layer_idxs, target_size, original_size, input_data_format=None
):
    cropped_images = []  # 初始化裁剪后的图像列表
    total_points_per_crop = []  # 初始化每个裁剪中的总点数列表

    # 遍历所有裁剪框
    for i, crop_box in enumerate(crop_boxes):
        left, top, right, bottom = crop_box  # 获取裁剪框的左上角和右下角坐标

        # 推断通道维度格式
        channel_dim = infer_channel_dimension_format(image, input_data_format)
        if channel_dim == ChannelDimension.LAST:
            cropped_im = image[top:bottom, left:right, :]  # 切片裁剪图像(通道在最后)
        else:
            cropped_im = image[:, top:bottom, left:right]  # 切片裁剪图像(通道在最前)

        cropped_images.append(cropped_im)  # 将裁剪后的图像添加到列表中

        cropped_im_size = get_image_size(cropped_im, channel_dim)  # 获取裁剪后图像的大小
        points_scale = np.array(cropped_im_size)[None, ::-1]  # 计算点的比例缩放

        points = points_grid[layer_idxs[i]] * points_scale  # 缩放对应的点
        normalized_points = _normalize_coordinates(target_size, points, original_size)  # 标准化坐标
        total_points_per_crop.append(normalized_points)  # 添加总点数到列表中

    return cropped_images, total_points_per_crop  # 返回裁剪后的图像列表和总点数列表


# 填充掩模
def _pad_masks(masks, crop_box: List[int], orig_height: int, orig_width: int):
    left, top, right, bottom = crop_box  # 获取裁剪框的左上角和右下角坐标
    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
        return masks  # 如果裁剪框与原始图像大小相同,直接返回掩模

    # 坐标变换掩模
    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)  # 计算填充量
    pad = (left, pad_x - left, top, pad_y - top)  # 构建填充元组
    # 使用 PyTorch 的 nn.functional 模块中的 pad 函数对 masks 进行填充操作
    # 参数 masks:需要填充的张量
    # 参数 pad:填充的大小,可以是单个整数表示每个维度填充相同的量,或者是元组表示每个维度填充的前后数量
    # 参数 value:填充时使用的值,默认为 0
    return torch.nn.functional.pad(masks, pad, value=0)
# 对输入的 masks 进行填充,以适应给定的裁剪框大小
def _pad_masks_tf(masks, crop_box: List[int], orig_height: int, orig_width: int):
    left, top, right, bottom = crop_box
    # 如果裁剪框与原始图像大小一致,则直接返回 masks,无需填充
    if left == 0 and top == 0 and right == orig_width and bottom == orig_height:
        return masks
    # 计算需要填充的宽度和高度
    pad_x, pad_y = orig_width - (right - left), orig_height - (bottom - top)
    # 构建填充参数,格式为(left, right, top, bottom)
    pad = (left, pad_x - left, top, pad_y - top)
    # 使用 TensorFlow 的 pad 函数对 masks 进行填充,填充值为常数0
    return tf.pad(masks, pad, constant_values=0)


# 检查边界框是否接近裁剪边缘,但不接近原始图像边缘
def _is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0):
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    # 将裁剪框和原始框转换为 Torch 张量
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)

    left, top, _, _ = crop_box
    # 创建偏移量张量,并将其添加到 boxes 张量中
    offset = torch.tensor([[left, top, left, top]], device=boxes.device)
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    boxes = (boxes + offset).float()

    # 检查 boxes 是否接近裁剪边缘和图像边缘,使用 torch.isclose 函数进行比较
    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    # 检查是否有任何接近裁剪边缘的边界框,并返回结果
    return torch.any(near_crop_edge, dim=1)


# 检查边界框是否接近裁剪边缘,但不接近原始图像边缘(使用 TensorFlow)
def _is_box_near_crop_edge_tf(boxes, crop_box, orig_box, atol=20.0):
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    # 将裁剪框和原始框转换为 TensorFlow 张量
    crop_box_tf = tf.convert_to_tensor(crop_box, dtype=tf.float32)
    orig_box_tf = tf.convert_to_tensor(orig_box, dtype=tf.float32)

    left, top, _, _ = crop_box
    # 创建偏移量张量,并将其添加到 boxes 张量中
    offset = tf.convert_to_tensor([[left, top, left, top]])
    if len(boxes.shape) == 3:
        offset = tf.expand_dims(offset, 1)
    boxes = tf.cast(boxes + offset, tf.float32)

    # Check whether boxes are close to the crop edge and to the original image edge, using tnp.isclose
    near_crop_edge = tnp.isclose(boxes, crop_box_tf[None, :], atol=atol, rtol=0)
    near_image_edge = tnp.isclose(boxes, orig_box_tf[None, :], atol=atol, rtol=0)
    near_crop_edge = tf.math.logical_and(near_crop_edge, ~near_image_edge)
    # 检查是否有任何接近裁剪边缘的边界框,并返回结果
    return tf.reduce_any(near_crop_edge, axis=1)


# 将批量的 masks 转换为包围框(使用 Torch)
def _batched_mask_to_box(masks: "torch.Tensor"):
    """
    Computes the bounding boxes around the given input masks. The bounding boxes are in the XYXY format which
    corresponds the following required indices:
        - LEFT: left hand side of the bounding box
        - TOP: top of the bounding box
        - RIGHT: right of the bounding box
        - BOTTOM: bottom of the bounding box

    Return [0,0,0,0] for an empty mask. For input shape channel_1 x channel_2 x ... x height x width, the output shape
    is channel_1 x channel_2 x ... x 4.

    Args:
        - masks (`torch.Tensor` of shape `(batch, nb_mask, height, width)`)
    """
    # 如果 masks 张量为空,则返回形状相同的零张量
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
    # 将 masks 张量的形状规范化为 Cxheightxwidth 的格式
    shape = masks.shape
    height, width = shape[-2:]
    
    # 获取顶部和底部边界
    in_height, _ = torch.max(masks, dim=-1)
    # 创建高度坐标矩阵
    in_height_coords = in_height * torch.arange(height, device=in_height.device)[None, :]
    # 计算底部边界
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    # 更新高度坐标矩阵,将非边界位置的坐标置为零
    in_height_coords = in_height_coords + height * (~in_height)
    # 计算顶部边界
    top_edges, _ = torch.min(in_height_coords, dim=-1)
    
    # 获取左侧和右侧边界
    in_width, _ = torch.max(masks, dim=-2)
    # 创建宽度坐标矩阵
    in_width_coords = in_width * torch.arange(width, device=in_width.device)[None, :]
    # 计算右侧边界
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    # 更新宽度坐标矩阵,将非边界位置的坐标置为零
    in_width_coords = in_width_coords + width * (~in_width)
    # 计算左侧边界
    left_edges, _ = torch.min(in_width_coords, dim=-1)
    
    # 如果掩码为空,右边界将在左边界左侧。将这些框替换为 [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    # 构建边界框数组
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    # 将空框对应的边界框置为零
    out = out * (~empty_filter).unsqueeze(-1)
    
    # 恢复到原始形状
    out = out.reshape(*shape[:-2], 4)
    # 返回边界框张量
    return out
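
For a single binary mask the same box can be read off from its nonzero coordinates; a tiny per-mask sketch of the idea (illustration only, not the batched implementation above):

import torch

mask = torch.zeros(1, 1, 5, 5, dtype=torch.bool)
mask[0, 0, 1:3, 2:5] = True  # object spans rows 1-2 and columns 2-4
ys, xs = mask[0, 0].nonzero(as_tuple=True)
print([xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()])  # expected: [2, 1, 4, 2] in XYXY order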

def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
    """
    Encodes masks to the run-length encoding (RLE), in the format expected by pycoco tools.
    """
    # Put in fortran order and flatten height and width
    batch_size, height, width = input_mask.shape
    input_mask = input_mask.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(batch_size):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if input_mask[i, 0] == 0 else [0]
        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
        out.append({"size": [height, width], "counts": counts})
    return out


def _mask_to_rle_tf(input_mask: "tf.Tensor"):
    """
    Encodes masks to the run-length encoding (RLE), in the format expected by pycoco tools.
    """
    # Put in fortran order and flatten height and width
    batch_size, height, width = input_mask.shape
    input_mask = flatten(tf.transpose(input_mask, perm=(0, 2, 1)), 1)

    # 计算变化的索引位置
    diff = input_mask[:, 1:] ^ input_mask[:, :-1]
    # 找出发生变化的位置的索引
    change_indices = tf.where(diff)

    # 编码运行长度
    out = []
    for i in range(batch_size):
        # 找出当前批次中第 i 行发生变化的索引
        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
        # 计算变化点之间的距离
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        # 如果第一列的值为 0,counts 列表以空列表开始,否则以 [0] 开始
        counts = [] if input_mask[i, 0] == 0 else [0]
        # 添加第一个变化点前面的零以及变化点之间的距离
        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
        # 将结果添加到输出列表中
        out.append({"size": [height, width], "counts": counts})
    # 返回最终结果列表
    return out
# 将非压缩的 RLE(Run-Length Encoding)转换为二进制掩码(mask)
def _rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    # 获取尺寸信息
    height, width = rle["size"]
    # 创建一个空的布尔类型数组,用于存储掩码
    mask = np.empty(height * width, dtype=bool)
    idx = 0
    parity = False
    # 根据 RLE 中的 counts 数组填充掩码数组
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity = not parity
    # 将平铺的掩码数组重新形状化为原始尺寸的掩码
    mask = mask.reshape(width, height)
    return mask.transpose()  # 将掩码转置为原始形状
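
A small decoding example (illustration only, assuming `_rle_to_mask` above is in scope): the counts run down the columns because the mask is flattened in Fortran order before encoding, so `{"size": [2, 2], "counts": [0, 1, 3]}` describes a 2x2 mask with only the top-left pixel set:

rle = {"size": [2, 2], "counts": [0, 1, 3]}
print(_rle_to_mask(rle))
# expected:
# [[ True False]
#  [False False]]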


# 对于 TensorFlow 版本的后处理,执行非极大值抑制(NMS)
def _postprocess_for_mg_tf(rle_masks, iou_scores, mask_boxes, amg_crops_nms_thresh=0.7):
    """
    Perform NMS (Non Maximum Suppression) on the outputs.

    Args:
            rle_masks (`tf.Tensor`):
                binary masks in the RLE format
            iou_scores (`tf.Tensor` of shape (nb_masks, 1)):
                iou_scores predicted by the model
            mask_boxes (`tf.Tensor`):
                The bounding boxes corresponding to segmentation masks
            amg_crops_nms_thresh (`float`, *optional*, defaults to 0.7):
                NMS threshold.
    """
    # 使用 TensorFlow 提供的 combined_non_max_suppression 函数执行 NMS
    keep_by_nms = tf.image.combined_non_max_suppression(
        boxes=mask_boxes.float(),
        scores=iou_scores,
        idxs=torch.zeros(mask_boxes.shape[0]),  # 使用零填充,因为在 TensorFlow 中没有对应的 idxs 参数
        iou_threshold=amg_crops_nms_thresh,
    )

    # 根据 NMS 的结果进行筛选
    iou_scores = iou_scores[keep_by_nms]
    rle_masks = [rle_masks[i] for i in keep_by_nms]
    mask_boxes = mask_boxes[keep_by_nms]
    # 将每个 RLE 格式的掩码转换为二进制掩码
    masks = [_rle_to_mask(rle) for rle in rle_masks]

    return masks, iou_scores, rle_masks, mask_boxes
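# For reference, the corresponding PyTorch helper in this file relies on torchvision's `batched_nms`,
# where `idxs` assigns a category id per box (all zeros -> class-agnostic NMS). A minimal sketch with
# made-up boxes and scores:
import torch
from torchvision.ops.boxes import batched_nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.5, 0.5, 10.5, 10.5], [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
keep = batched_nms(boxes, scores, idxs=torch.zeros(len(boxes), dtype=torch.long), iou_threshold=0.7)
print(keep)  # tensor([0, 2]) -> the second box overlaps the first (IoU ~ 0.82) and is suppressed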

.\models\sam\modeling_sam.py

# coding=utf-8
# 设置编码方式为 UTF-8

# 版权声明及许可证,声明代码版权及使用许可
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch SAM model."""
# PyTorch SAM 模型的定义和实现

import collections  # 导入 collections 模块,用于高效数据容器类型
import math  # 导入 math 模块,提供数学运算函数
from dataclasses import dataclass  # 导入 dataclass 类装饰器,用于定义数据类
from typing import Dict, List, Optional, Tuple, Union  # 导入类型提示相关的库

import numpy as np  # 导入 NumPy 库,用于数值计算
import torch  # 导入 PyTorch 库,进行深度学习模型的构建和训练
import torch.nn.functional as F  # 导入 PyTorch 中的函数模块
import torch.utils.checkpoint  # 导入 PyTorch 的 checkpoint 模块,用于内存优化

from torch import Tensor, nn  # 导入 PyTorch 的张量类和神经网络模块

from ...activations import ACT2FN  # 导入激活函数
from ...modeling_outputs import BaseModelOutput  # 导入基础模型输出类
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging  # 导入工具函数和日志模块
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig  # 导入 SAM 模型的配置类

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 以下是文档化的字符串,用于生成文档

_CONFIG_FOR_DOC = "SamConfig"  # 配置文档化字符串,指定 SAM 的配置类
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"  # 检查点文档化字符串,指定 SAM 模型的预训练检查点

# SAM 模型的预训练模型存档列表
SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/sam-vit-huge",
    "facebook/sam-vit-large",
    "facebook/sam-vit-base",
    # 查看所有 SAM 模型的列表可访问 https://huggingface.co/models?filter=sam
]


@dataclass
class SamVisionEncoderOutput(ModelOutput):
    """
    Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.
    """
    # SAM 视觉编码器输出的基类,同时包含通过将投影层应用于池化输出获得的图像嵌入。
    # 可选参数:模型输出的图像嵌入向量,形状为(batch_size, output_dim)。仅在模型初始化时设置了 `with_projection=True` 时返回。
    image_embeds: Optional[torch.FloatTensor] = None
    
    # 必需参数:模型最后一层的隐藏状态输出,形状为(batch_size, sequence_length, hidden_size)。
    last_hidden_state: torch.FloatTensor = None
    
    # 可选参数:模型的隐藏状态输出,是一个元组,包含模型每一层的隐藏状态输出。当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回。
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    
    # 可选参数:模型的注意力权重输出,是一个元组,包含每个注意力头的注意力权重。当 `output_attentions=True` 或 `config.output_attentions=True` 时返回。
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class SamImageSegmentationOutput(ModelOutput):
    """
    Base class for Segment-Anything model's output

    Args:
        iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
            The iou scores of the predicted masks.
        pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low resolutions masks. Needs to be post-processed by the processor
        vision_hidden_states  (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
        vision_attentions  (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    iou_scores: torch.FloatTensor = None
    pred_masks: torch.FloatTensor = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None



class SamPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """
    # 初始化函数,用于初始化类实例
    def __init__(self, config):
        # 调用父类的初始化方法
        super().__init__()
        # 从配置对象中获取图像大小和patch大小
        image_size, patch_size = config.image_size, config.patch_size
        # 从配置对象中获取通道数和隐藏层大小
        num_channels, hidden_size = config.num_channels, config.hidden_size
        # 如果图像大小不是可迭代对象,则将其转换为元组
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        # 如果patch大小不是可迭代对象,则将其转换为元组
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # 计算图像被分成的patch数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        
        # 将计算得到的各个属性赋值给实例变量
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        
        # 创建卷积层,用于投影输入像素值到隐藏表示空间
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    # 前向传播函数,接收像素值作为输入,返回嵌入表示
    def forward(self, pixel_values):
        # 获取输入张量的维度信息
        batch_size, num_channels, height, width = pixel_values.shape
        
        # 检查通道数是否与配置中的一致
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        
        # 检查输入图像尺寸是否与配置中的一致
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        
        # 将输入图像通过投影层并进行维度转置,得到嵌入表示
        embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
        
        # 返回嵌入表示
        return embeddings
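# The projection above is just a strided convolution; a toy-sized sketch of the resulting shape
# (the real SAM checkpoints use image_size=1024 and patch_size=16, giving a 64x64 patch grid):
import torch
from torch import nn

toy_proj = nn.Conv2d(3, 32, kernel_size=16, stride=16)      # num_channels=3, hidden_size=32 (toy values)
pixel_values = torch.randn(2, 3, 64, 64)                     # a 64x64 toy "image"
print(toy_proj(pixel_values).permute(0, 2, 3, 1).shape)      # torch.Size([2, 4, 4, 32])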
# 定义一个用于SAM模型中MLP块的类
class SamMLPBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 创建一个线性层,将输入维度调整为mlp_dim
        self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
        # 创建另一个线性层,将mlp_dim维度调整回hidden_size
        self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
        # 选择激活函数,根据配置选择相应的激活函数
        self.act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # 第一个线性层的前向传播
        hidden_states = self.lin1(hidden_states)
        # 应用选择的激活函数
        hidden_states = self.act(hidden_states)
        # 第二个线性层的前向传播
        hidden_states = self.lin2(hidden_states)
        return hidden_states


# 从transformers.models.convnext.modeling_convnext.ConvNextLayerNorm复制并修改为SamLayerNorm
class SamLayerNorm(nn.Module):
    r"""支持两种数据格式(channels_last或channels_first)的LayerNorm。
    channels_last对应输入形状为(batch_size, height, width, channels),而channels_first对应输入形状为(batch_size, channels, height, width)。
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # 初始化可学习的权重和偏置参数
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        # 如果数据格式不是channels_last或channels_first,则抛出异常
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            # 对输入进行layer normalization,使用学习的权重和偏置参数
            x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # 如果数据格式为channels_first,则对输入进行自定义的layer normalization
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
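# Illustrative check: the channels_first branch matches the standard layer_norm applied after moving
# the channel axis to the last position (up to floating-point precision).
import torch

_ln = SamLayerNorm(8, data_format="channels_first")
_x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)
_ref = torch.nn.functional.layer_norm(_x.permute(0, 2, 3, 1), (8,), _ln.weight, _ln.bias, _ln.eps)
print(torch.allclose(_ln(_x), _ref.permute(0, 3, 1, 2), atol=1e-5))  # True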


class SamAttention(nn.Module):
    """
    SAM的注意力层,允许在将查询、键和值投影后缩小嵌入的大小。
    """
    # 初始化函数,接受配置参数和降采样率作为可选参数
    def __init__(self, config, downsample_rate=None):
        super().__init__()
        # 设置隐藏层大小为配置中的隐藏层大小
        self.hidden_size = config.hidden_size

        # 如果没有指定降采样率,则使用配置中的注意力降采样率
        downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate

        # 根据降采样率计算内部维度
        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads

        # 检查内部维度是否可以整除注意力头数,否则抛出数值错误
        if self.internal_dim % config.num_attention_heads != 0:
            raise ValueError("num_attention_heads must divide hidden_size.")

        # 初始化查询、键、值、输出的线性投影层
        self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)

    # 将隐藏状态按注意力头数分离的函数
    def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
        batch, point_batch_size, n_tokens, channel = hidden_states.shape
        c_per_head = channel // num_attention_heads
        # 重塑张量形状以便每个注意力头独立操作
        hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
        return hidden_states.transpose(1, 2)

    # 将分离的注意力头重新组合为隐藏状态的函数
    def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
        batch, n_heads, n_tokens, c_per_head = hidden_states.shape
        # 调整张量形状以将注意力头合并回原始形式
        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

    # 前向传播函数,接受查询、键、值张量,可选注意力相似性张量,并返回输出张量
    def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
        # 对输入进行投影
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # 获取点批次大小
        point_batch_size = query.shape[1]

        # 将查询、键、值张量分离成注意力头
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)

        # 计算注意力权重
        _, _, _, c_per_head = query.shape
        attn = query @ key.permute(0, 1, 3, 2)  # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # 如果提供了注意力相似性张量,则加到注意力权重上
        if attention_similarity is not None:
            attn = attn + attention_similarity
            attn = torch.softmax(attn, dim=-1)

        # 计算输出
        out = attn @ value
        out = self._recombine_heads(out, point_batch_size)
        out = self.out_proj(out)

        return out
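# Shape sketch for the head split/merge used above (illustrative sizes): the point batch is folded into
# the batch dimension before attention and unfolded again afterwards.
import torch

batch, point_batch, n_tokens, hidden, heads = 2, 3, 5, 16, 4
x = torch.randn(batch, point_batch, n_tokens, hidden)
x = x.reshape(batch * point_batch, n_tokens, heads, hidden // heads).transpose(1, 2)
print(x.shape)  # torch.Size([6, 4, 5, 4]) -> (batch * point_batch, num_heads, tokens, c_per_head)
x = x.transpose(1, 2).reshape(batch, point_batch, n_tokens, hidden)
print(x.shape)  # torch.Size([2, 3, 5, 16])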
class SamTwoWayAttentionBlock(nn.Module):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
        """
        A transformer block with four layers:
            (1) self-attention of sparse inputs
            (2) cross attention of sparse inputs -> dense inputs
            (3) MLP block on sparse inputs
            (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration file used to instantiate the block
            attention_downsample_rate (*optional*, int, defaults to 2):
                The downsample ratio of the block used to reduce the inner dim of the attention.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        """
        super().__init__()

        # Initialize hidden size and layer normalization epsilon from configuration
        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps

        # Self-attention layer for sparse inputs
        self.self_attn = SamAttention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        # Cross-attention from token to image (sparse to dense) inputs
        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
        self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        # MLP block on sparse inputs
        self.mlp = SamMLPBlock(config)
        self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

        # Layer normalization before cross-attention from image to token (dense to sparse)
        self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)

        # Option to skip adding query_point_embedding in the first layer
        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Tensor,
        output_attentions: bool = False,
    ):
        # Self attention block
        if self.skip_first_layer_pe:
            # 如果需要跳过第一个自注意力层,则使用 queries 对自注意力进行处理
            queries = self.self_attn(query=queries, key=queries, value=queries)
        else:
            # 否则,将 query_point_embedding 添加到 queries 中,然后进行自注意力计算
            query = queries + query_point_embedding
            attn_out = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        # 对 queries 进行 Layer Normalization 处理
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        # 将 query_point_embedding 添加到 queries,将 key_point_embedding 添加到 keys
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        # 使用 cross_attn_token_to_image 方法进行跨注意力计算,将结果添加到 queries 中
        attn_out = self.cross_attn_token_to_image(
            query=query, key=key, value=keys, attention_similarity=attention_similarity
        )
        queries = queries + attn_out

        # 对 queries 进行 Layer Normalization 处理
        queries = self.layer_norm2(queries)

        # MLP block
        # 使用 MLP 模块处理 queries
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        # 对 queries 进行 Layer Normalization 处理
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        # 将 query_point_embedding 添加到 queries,将 key_point_embedding 添加到 keys
        query = queries + query_point_embedding
        key = keys + key_point_embedding

        # 使用 cross_attn_image_to_token 方法进行跨注意力计算,将结果添加到 keys 中
        attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out

        # 对 keys 进行 Layer Normalization 处理
        keys = self.layer_norm4(keys)

        # 输出为 (queries, keys) 元组
        outputs = (queries, keys)

        # 如果需要输出注意力权重,则将注意力权重添加到输出元组中
        if output_attentions:
            outputs = outputs + (attn_out,)
        else:
            outputs = outputs + (None,)

        # 返回最终的输出元组
        return outputs
# 定义一个双向转换器模型,继承自 nn.Module
class SamTwoWayTransformer(nn.Module):
    # 初始化函数,接收一个 SamMaskDecoderConfig 类型的配置对象作为参数
    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        # 保存配置对象
        self.config = config

        # 从配置中获取隐藏层数量
        self.num_hidden_layers = config.num_hidden_layers
        # 初始化一个模块列表用于保存多个双向注意力块
        self.layers = nn.ModuleList()

        # 根据隐藏层数量循环创建双向注意力块并添加到模块列表中
        for i in range(self.num_hidden_layers):
            self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

        # 创建最终的注意力层对象和对应的 LayerNorm 层
        self.final_attn_token_to_image = SamAttention(config)
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    # 前向传播函数
    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Tensor,
        target_embedding=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # 如果未指定 output_attentions,则使用配置中的设定
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果未指定 output_hidden_states,则使用配置中的设定
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果未指定 return_dict,则使用配置中的设定
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 用于保存所有注意力值的元组
        all_attentions = ()

        # 如果 image_embeddings 为 None,则抛出 ValueError 异常
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        # 对 image_embeddings 和 image_positional_embeddings 进行形状变换和排列
        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)

        # 准备查询向量
        queries = point_embeddings
        keys = image_embeddings

        # 对每个双向注意力块执行变换操作并应用最终的 LayerNorm
        for layer in self.layers:
            # 如果存在 target_embedding,则将其加到 queries 中
            if target_embedding is not None:
                queries += target_embedding

            # 调用当前层的 forward 方法进行注意力计算
            queries, keys, attention_outputs = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                output_attentions=output_attentions,
            )

            # 如果 output_attentions 为 True,则将当前层的 attention_outputs 添加到 all_attentions 中
            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)

        # 应用从点到图像的最终注意力层
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        # 调用最终的注意力层进行计算
        attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)

        # 将计算得到的 attn_out 加到 queries 中,并应用 LayerNorm
        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)

        # 返回 queries, keys 和 all_attentions(如果有)
        return queries, keys, all_attentions
class SamFeedForward(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置神经网络的层数
        self.num_layers = num_layers
        # 指定激活函数为ReLU
        self.activation = nn.ReLU()
        # 创建输入层到隐藏层的线性映射
        self.proj_in = nn.Linear(input_dim, hidden_dim)
        # 创建隐藏层到输出层的线性映射
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        # 使用ModuleList创建隐藏层的线性映射列表
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
        # 根据需要设置输出是否经过sigmoid函数
        self.sigmoid_output = sigmoid_output

    def forward(self, hidden_states):
        # 输入数据经过输入层到隐藏层的线性映射
        hidden_states = self.proj_in(hidden_states)
        # 经过ReLU激活函数处理隐藏层的输出
        hidden_states = self.activation(hidden_states)
        # 遍历隐藏层列表,每一层经过线性映射和ReLU激活函数
        for layer in self.layers:
            hidden_states = self.activation(layer(hidden_states))

        # 经过隐藏层到输出层的线性映射
        hidden_states = self.proj_out(hidden_states)
        # 如果需要,对输出进行sigmoid函数处理
        if self.sigmoid_output:
            hidden_states = F.sigmoid(hidden_states)
        # 返回神经网络的输出
        return hidden_states
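# Quick usage sketch (assuming the constructor signature reconstructed above): `num_layers` counts
# proj_in, the hidden layers and proj_out, so num_layers=3 leaves exactly one hidden Linear layer.
import torch

_mlp = SamFeedForward(input_dim=16, hidden_dim=32, output_dim=8, num_layers=3)
print(len(_mlp.layers))                  # 1
print(_mlp(torch.randn(4, 16)).shape)    # torch.Size([4, 8])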
class SamMaskDecoder(nn.Module):
    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()

        self.hidden_size = config.hidden_size  # 从配置中获取隐藏层大小

        self.num_multimask_outputs = config.num_multimask_outputs  # 多重遮罩输出的数量
        self.num_mask_tokens = config.num_multimask_outputs + 1  # 遮罩标记的数量,包括一个IOU标记

        self.iou_token = nn.Embedding(1, self.hidden_size)  # 创建一个大小为1的嵌入层,用于IOU标记
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)  # 创建一个嵌入层,用于所有遮罩标记

        self.transformer = SamTwoWayTransformer(config)  # 创建一个SamTwoWayTransformer对象

        # 创建上采样卷积层和归一化层
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()  # GELU激活函数

        mlps_list = []
        for _ in range(self.num_mask_tokens):
            mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)  # 创建一个包含多个SamFeedForward层的模块列表

        self.iou_prediction_head = SamFeedForward(
            self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
        )  # 创建一个SamFeedForward对象,用于IOU预测头部

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
        attention_similarity: torch.Tensor = None,
        target_embedding: torch.Tensor = None,
    ):
        # The body of the forward pass (running the two-way transformer, upscaling the mask embeddings
        # and predicting the IoU scores) is omitted from this excerpt.
        pass


class SamPositionalEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.scale = config.hidden_size // 2  # 缩放因子为隐藏层大小的一半
        self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))  # 注册位置嵌入的缓冲区

    def forward(self, input_coords, input_shape=None):
        """Positionally encode points that are normalized to [0,1]."""
        coordinates = input_coords.clone()

        if input_shape is not None:
            coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]  # 归一化x坐标
            coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]  # 归一化y坐标

        # 假设坐标位于[0, 1]^2区域并具有d_1 x ... x d_n x 2形状
        coordinates = 2 * coordinates - 1  # 映射到[-1, 1]区间
        coordinates = coordinates.to(self.positional_embedding.dtype)  # 将坐标转换为位置嵌入的数据类型
        coordinates = coordinates @ self.positional_embedding  # 点乘位置嵌入
        coordinates = 2 * np.pi * coordinates  # 缩放角度
        # 输出d_1 x ... x d_n x 通道形状
        return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)  # 返回正弦和余弦编码的拼接
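# Sketch of the random-Fourier-feature encoding used above (illustrative sizes, without the scale factor):
# points in [0, 1]^2 are mapped to [-1, 1], projected through a fixed random Gaussian matrix and encoded
# as sin/cos pairs, giving 2 * num_pos_feats channels per point.
import torch

num_pos_feats = 64
random_matrix = torch.randn(2, num_pos_feats)
coords = torch.rand(1, 1, 5, 2)                        # 5 normalized (x, y) points
angles = 2 * torch.pi * ((2 * coords - 1) @ random_matrix)
print(torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).shape)  # torch.Size([1, 1, 5, 128])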


class SamMaskEmbedding(nn.Module):
    # 初始化方法,接受一个配置对象作为参数
    def __init__(self, config: SamPromptEncoderConfig):
        # 调用父类的初始化方法
        super().__init__()
        # 计算输入通道数的四分之一,并赋值给实例变量
        self.mask_input_channels = config.mask_input_channels // 4
        # 根据配置中的激活函数名称,获取对应的激活函数,并赋值给实例变量
        self.activation = ACT2FN[config.hidden_act]
        # 创建第一个二维卷积层,输入通道为1,输出通道为四分之一的输入通道数,核大小为2x2,步长为2
        self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
        # 创建第二个二维卷积层,输入通道数为四分之一的输入通道数,输出通道数为配置中的输入通道数,核大小为2x2,步长为2
        self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
        # 创建第三个二维卷积层,输入通道数为配置中的输入通道数,输出通道数为配置中的隐藏大小,核大小为1x1
        self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
        # 创建第一个 SAM 层归一化实例,输入通道数为四分之一的输入通道数,epsilon 使用配置中的值,数据格式为"channels_first"
        self.layer_norm1 = SamLayerNorm(
            self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
        )
        # 创建第二个 SAM 层归一化实例,输入通道数为四倍的输入通道数,epsilon 使用配置中的值,数据格式为"channels_first"
        self.layer_norm2 = SamLayerNorm(
            self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
        )

    # 前向传播方法,接收掩码作为输入,返回密集嵌入向量
    def forward(self, masks):
        # 第一次卷积操作,将掩码传入第一个卷积层,得到隐藏状态
        hidden_states = self.conv1(masks)
        # 对第一个卷积层的输出进行 SAM 层归一化
        hidden_states = self.layer_norm1(hidden_states)
        # 对归一化后的输出应用激活函数
        hidden_states = self.activation(hidden_states)

        # 第二次卷积操作,将上一步的输出传入第二个卷积层,得到隐藏状态
        hidden_states = self.conv2(hidden_states)
        # 对第二个卷积层的输出进行 SAM 层归一化
        hidden_states = self.layer_norm2(hidden_states)
        # 对归一化后的输出应用激活函数
        hidden_states = self.activation(hidden_states)

        # 第三次卷积操作,将上一步的输出传入第三个卷积层,得到密集嵌入向量
        dense_embeddings = self.conv3(hidden_states)
        # 返回密集嵌入向量作为前向传播的结果
        return dense_embeddings
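# Shape walk-through (a minimal sketch assuming the default SamPromptEncoderConfig values, i.e.
# mask_input_channels=16 and hidden_size=256): a 256x256 low-resolution input mask is downscaled twice
# by the stride-2 convolutions so that the dense embedding matches the 64x64 image-embedding grid.
import torch

_mask_embed = SamMaskEmbedding(SamPromptEncoderConfig())
print(_mask_embed(torch.zeros(1, 1, 256, 256)).shape)  # torch.Size([1, 256, 64, 64])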
# 定义 SamPromptEncoder 类,继承自 nn.Module,用于处理 SamPrompt 编码器相关操作
class SamPromptEncoder(nn.Module):
    # 初始化方法,接受 SamPromptEncoderConfig 类型的 config 和共享的 patch embedding
    def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding):
        super().__init__()
        # 共享的 patch embedding
        self.shared_embedding = shared_patch_embedding
        # 创建 SamMaskEmbedding 对象,用于处理 mask 相关操作
        self.mask_embed = SamMaskEmbedding(config)
        # 创建一个只包含一个元素的 nn.Embedding 对象,用于处理没有 mask 的情况
        self.no_mask_embed = nn.Embedding(1, config.hidden_size)

        # 设置图像嵌入大小和输入图像大小
        self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
        self.input_image_size = config.image_size

        # 创建一个 nn.ModuleList,包含多个 nn.Embedding 对象,用于处理点嵌入
        self.point_embed = nn.ModuleList(
            [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
        )
        # 隐藏状态的大小
        self.hidden_size = config.hidden_size
        # 创建一个只包含一个元素的 nn.Embedding 对象,用于处理不是点的情况
        self.not_a_point_embed = nn.Embedding(1, config.hidden_size)

    # 内部方法,用于嵌入点提示
    def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
        """Embeds point prompts."""
        # 将点位移 0.5,以将其移动到像素中心
        points = points + 0.5
        # 如果需要填充
        if pad:
            # 创建目标点形状和标签形状的张量
            target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
            target_labels_shape = (points.shape[0], points.shape[1], 1)
            # 创建填充点和标签的零张量和负一标签
            padding_point = torch.zeros(target_point_shape, device=points.device)
            padding_label = -torch.ones(target_labels_shape, device=labels.device)
            # 在维度 2 上拼接点和标签
            points = torch.cat([points, padding_point], dim=2)
            labels = torch.cat([labels, padding_label], dim=2)
        # 输入形状为 (self.input_image_size, self.input_image_size)
        input_shape = (self.input_image_size, self.input_image_size)
        # 使用共享的嵌入嵌入点
        point_embedding = self.shared_embedding(points, input_shape)

        # 根据标签是否为 -1,选择不是点的嵌入或点的嵌入
        point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)

        # 对于 ONNX 导出,需要使用 torch.where 扩展标签张量
        point_embedding = torch.where(
            labels[..., None] != -10,
            point_embedding,
            torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
        )

        # 根据标签是否为 0,加上第一个点嵌入的权重
        point_embedding = torch.where(
            (labels == 0)[:, :, :, None],
            point_embedding + self.point_embed[0].weight[None, None, :, :],
            point_embedding,
        )

        # 根据标签是否为 1,加上第二个点嵌入的权重
        point_embedding = torch.where(
            (labels == 1)[:, :, :, None],
            point_embedding + self.point_embed[1].weight[None, None, :, :],
            point_embedding,
        )

        # 返回点嵌入结果
        return point_embedding
    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        # 将框的坐标加上0.5,以将坐标中心移至像素中心
        boxes = boxes + 0.5  # Shift to center of pixel
        batch_size, nb_boxes = boxes.shape[:2]
        # 将框的坐标重塑为(batch_size, nb_boxes, 2, 2)的形状
        coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
        # 设置输入形状为(self.input_image_size, self.input_image_size)
        input_shape = (self.input_image_size, self.input_image_size)
        # 使用共享的嵌入层来嵌入角点的坐标
        corner_embedding = self.shared_embedding(coords, input_shape)
        # 将角点嵌入矩阵的第一个维度加上self.point_embed[2].weight
        corner_embedding[:, :, 0, :] += self.point_embed[2].weight
        # 将角点嵌入矩阵的第二个维度加上self.point_embed[3].weight
        corner_embedding[:, :, 1, :] += self.point_embed[3].weight
        return corner_embedding

    def forward(
        self,
        input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        input_labels: Optional[torch.Tensor],
        input_boxes: Optional[torch.Tensor],
        input_masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            points (`torch.Tensor`, *optional*):
                point coordinates and labels to embed.
            boxes (`torch.Tensor`, *optional*):
                boxes to embed
            masks (`torch.Tensor`, *optional*):
                masks to embed
        """
        sparse_embeddings = None
        batch_size = 1
        # 确定目标设备为self.shared_embedding.positional_embedding的设备
        target_device = self.shared_embedding.positional_embedding.device
        if input_points is not None:
            batch_size, point_batch_size = input_points.shape[:2]
            # 如果提供了points但未提供labels,则抛出异常
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            # 使用_embed_points方法嵌入points的坐标和标签
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = point_embeddings
        if input_boxes is not None:
            batch_size = input_boxes.shape[0]
            # 使用_embed_boxes方法嵌入boxes
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
        if input_masks is not None:
            # 使用mask_embed方法嵌入masks
            dense_embeddings = self.mask_embed(input_masks)
        else:
            # 使用no_mask_embed的权重初始化dense_embeddings
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        # 如果sparse_embeddings仍为None,则初始化为全零张量
        if sparse_embeddings is None:
            sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)

        return sparse_embeddings, dense_embeddings
        """
        Add decomposed relative positional embeddings to the attention scores.

        Args:
            attn (`torch.Tensor`):
                Attention scores between query and key.
            query (`torch.Tensor`):
                Query tensor.
            rel_pos_h (`torch.Tensor`):
                Relative positional embeddings along height.
            rel_pos_w (`torch.Tensor`):
                Relative positional embeddings along width.
            q_size (`Tuple[int, int]`):
                Size of the query tensor (height, width).
            k_size (`Tuple[int, int]`):
                Size of the key tensor (height, width).

        Returns:
            `torch.Tensor`: Attention scores modified by relative positional embeddings.
        """

        max_rel_dist = int(2 * max(q_size[0], k_size[0]) - 1)

        # Interpolate relative position embeddings
        rel_pos_h_resized = F.interpolate(rel_pos_h.unsqueeze(0), size=max_rel_dist, mode="linear")
        rel_pos_w_resized = F.interpolate(rel_pos_w.unsqueeze(0), size=max_rel_dist, mode="linear")

        rel_pos_h_resized = rel_pos_h_resized.squeeze(0)
        rel_pos_w_resized = rel_pos_w_resized.squeeze(0)

        # Scale coordinates with maximum length if query and key sizes differ
        q_coords = torch.arange(q_size[0]).unsqueeze(1) * max(k_size[0] / q_size[0], 1.0)
        k_coords = torch.arange(k_size[0]).unsqueeze(0) * max(q_size[0] / k_size[0], 1.0)
        relative_coords_h = (q_coords - k_coords) + (k_size[0] - 1) * max(q_size[0] / k_size[0], 1.0)

        q_coords = torch.arange(q_size[1]).unsqueeze(1) * max(k_size[1] / q_size[1], 1.0)
        k_coords = torch.arange(k_size[1]).unsqueeze(0) * max(q_size[1] / k_size[1], 1.0)
        relative_coords_w = (q_coords - k_coords) + (k_size[1] - 1) * max(q_size[1] / k_size[1], 1.0)

        # Gather relative positional embeddings
        rel_pos_h = rel_pos_h_resized[relative_coords_h.long()]
        rel_pos_w = rel_pos_w_resized[relative_coords_w.long()]

        # Combine relative positional embeddings
        rel_pos = rel_pos_h + rel_pos_w.unsqueeze(0)

        # Reshape and expand relative positional embeddings
        rel_pos = rel_pos.unsqueeze(0).expand(attn.size(0), -1, -1)

        # Add relative positional embeddings to attention scores
        attn = attn + rel_pos

        return attn
    def add_decomposed_rel_pos(
        self,
        attn: torch.Tensor,
        query: torch.Tensor,
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        """
        # 解包查询大小和键大小
        query_height, query_width = q_size
        key_height, key_width = k_size
        
        # 获取相对位置编码矩阵的高度和宽度方向上的影响
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)

        # 获取查询的批次大小、通道数和维度
        batch_size, _, dim = query.shape
        
        # 重塑查询张量为四维张量,以便进行后续的张量乘积操作
        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
        
        # 计算相对位置编码对高度和宽度的影响
        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        
        # 重塑注意力图张量以便添加相对位置编码
        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
        
        # 将注意力图重新展平为原始形状并返回
        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
        return attn
    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        # 获取隐藏状态的维度信息
        batch_size, height, width, _ = hidden_states.shape
        
        # 使用 qkv 网络处理隐藏状态,生成 qkv 张量,形状为 (3, batch_size, nHead, height * width, channel)
        qkv = (
            self.qkv(hidden_states)
            .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        
        # 将 qkv 张量按照 q, k, v 分开,形状为 (batch_size * nHead, height * width, channel)
        query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
        
        # 计算注意力权重,使用注意力机制中的 query 和 key 进行点积
        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        # 如果使用相对位置编码,则添加分解的相对位置信息
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
        
        # 对注意力权重进行 softmax 归一化处理
        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        
        # 对注意力权重进行 dropout 处理
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        
        # 计算注意力输出,将注意力概率与 value 进行加权求和
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        
        # 使用投影层进行输出转换
        attn_output = self.proj(attn_output)
        
        # 如果需要输出注意力权重,则将其包含在输出中
        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        
        return outputs
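# Shape sketch of the decomposed relative-position term added to the attention scores above
# (illustrative sizes; `rel_pos_h` / `rel_pos_w` stand for the already-gathered relative embeddings):
import torch

b, q_h, q_w, k_h, k_w, dim = 2, 4, 4, 4, 4, 8
reshaped_query = torch.randn(b, q_h, q_w, dim)
rel_pos_h = torch.randn(q_h, k_h, dim)
rel_pos_w = torch.randn(q_w, k_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_pos_h)   # (b, q_h, q_w, k_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_pos_w)   # (b, q_h, q_w, k_w)
attn = torch.zeros(b, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
print(attn.reshape(b, q_h * q_w, k_h * k_w).shape)                  # torch.Size([2, 16, 16])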
class SamVisionLayer(nn.Module):
    def __init__(self, config, window_size):
        """
        Initialize the SamVisionLayer module.

        Args:
            config (object): Configuration object containing parameters like hidden_size and layer_norm_eps.
            window_size (int): Size of the sliding window to partition the input.
        """
        # Call the parent class constructor to initialize it
        super().__init__()
        
        # Layer normalization applied to the input tensor
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        # Self-attention mechanism tailored for vision tasks
        self.attn = SamVisionAttention(config, window_size)
        
        # Layer normalization applied after self-attention
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        # Multilayer perceptron block for further processing
        self.mlp = SamMLPBlock(config)
        
        # Store the window size for later use
        self.window_size = window_size

    def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Partition input tensor into non-overlapping windows with padding if necessary.

        Args:
            hidden_states (torch.Tensor): Input tensor with shape [batch_size, height, width, channel].
            window_size (int): Size of the window.

        Returns:
            windows (torch.Tensor): Tensor of windows after partitioning with shape [batch_size * num_windows, window_size, window_size, channel].
            (pad_height, pad_width) (Tuple[int, int]): Padded height and width before partitioning.
        """
        # Extract dimensions from the input tensor
        batch_size, height, width, channel = hidden_states.shape
        
        # Calculate padding required to make dimensions divisible by window_size
        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        
        # Apply padding to the input tensor
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
        
        # Update height and width dimensions after padding
        pad_height, pad_width = height + pad_h, width + pad_w
        
        # Reshape the tensor into windows of the specified size
        hidden_states = hidden_states.reshape(
            batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
        )
        
        # Permute dimensions to arrange windows properly
        windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
        
        return windows, (pad_height, pad_width)

    def window_unpartition(
        self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
    ) -> torch.Tensor:
        """
        Reconstruct the original tensor from windows.

        Args:
            windows (torch.Tensor): Tensor of windows with shape [batch_size * num_windows, window_size, window_size, channel].
            window_size (int): Size of the window.
            padding_shape (Tuple[int, int]): Padded height and width before partitioning.
            original_shape (Tuple[int, int]): Original height and width of the input tensor before padding.

        Returns:
            Tensor of original shape [batch_size, height, width, channel].
        """
        # Note: this reverses window_partition above, reconstructing the original tensor from the windows.
        pad_height, pad_width = padding_shape
        height, width = original_shape
        batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
        hidden_states = windows.reshape(
            batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
        )
        hidden_states = (
            hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
        )

        hidden_states = hidden_states[:, :height, :width, :].contiguous()
        return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # 窗口分区
        if self.window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)

        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
        )
        # 反向窗口分区
        if self.window_size > 0:
            hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))

        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
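# Illustrative check of the window partition arithmetic above: a (1, 14, 14, C) feature map with
# window_size=8 is padded to 16x16 and split into 2*2 windows of shape (8, 8, C); unpartitioning simply
# reverses the reshape/permute and crops the padding away.
import torch
import torch.nn.functional as F

window_size, channel = 8, 32
hidden_states = torch.randn(1, 14, 14, channel)
pad_h = (window_size - 14 % window_size) % window_size   # 2
pad_w = (window_size - 14 % window_size) % window_size   # 2
padded = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
windows = (
    padded.reshape(1, 16 // window_size, window_size, 16 // window_size, window_size, channel)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(-1, window_size, window_size, channel)
)
print(windows.shape)  # torch.Size([4, 8, 8, 32])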
class SamVisionNeck(nn.Module):
    # SamVisionNeck 类,用于实现视觉模型的颈部结构
    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config

        # 第一个卷积层,将输入特征映射到输出通道数
        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        # 第一个层归一化层,对输出进行通道方向的标准化
        self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
        # 第二个卷积层,继续处理特征映射,增加网络的非线性能力
        self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
        # 第二个层归一化层,对输出进行通道方向的标准化
        self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        # 将输入特征的维度重新排列,以适应卷积层的输入要求
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        # 第一个卷积层的前向传播
        hidden_states = self.conv1(hidden_states)
        # 第一个层归一化层的前向传播
        hidden_states = self.layer_norm1(hidden_states)

        # 第二个卷积层的前向传播
        hidden_states = self.conv2(hidden_states)
        # 第二个层归一化层的前向传播
        hidden_states = self.layer_norm2(hidden_states)
        return hidden_states


class SamVisionEncoder(nn.Module):
    # SamVisionEncoder 类,用于实现视觉模型的编码器结构
    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config
        self.image_size = config.image_size

        # 图像分块嵌入层,将图像转换为序列数据
        self.patch_embed = SamPatchEmbeddings(config)

        self.pos_embed = None
        if config.use_abs_pos:
            # 如果使用绝对位置编码,则初始化绝对位置嵌入
            self.pos_embed = nn.Parameter(
                torch.zeros(
                    1,
                    config.image_size // config.patch_size,
                    config.image_size // config.patch_size,
                    config.hidden_size,
                )
            )

        # 编码器的层列表
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            # 创建并添加视觉层
            layer = SamVisionLayer(
                config,
                window_size=config.window_size if i not in config.global_attn_indexes else 0,
            )
            self.layers.append(layer)

        # 视觉模型的颈部结构
        self.neck = SamVisionNeck(config)

        # 是否使用梯度检查点
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.patch_embed

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # 省略部分:此处省略了 forward 方法的参数描述
    ) -> Union[Tuple, SamVisionEncoderOutput]:
        # 检查是否需要输出注意力权重,默认使用配置中的设置
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 检查是否需要输出隐藏状态,默认使用配置中的设置
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 检查是否使用返回字典格式,默认使用配置中的设置
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果像素值为 None,则抛出数值错误异常
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 将像素值传入到补丁嵌入层中
        hidden_states = self.patch_embed(pixel_values)
        # 如果存在位置编码,则将其加到隐藏状态中
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        # 初始化用于存储所有隐藏状态的变量,如果不需要输出隐藏状态则设为 None
        all_hidden_states = () if output_hidden_states else None
        # 初始化用于存储所有自注意力权重的变量,如果不需要输出注意力权重则设为 None
        all_self_attentions = () if output_attentions else None

        # 遍历所有编码器层
        for i, layer_module in enumerate(self.layers):
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到存储列表中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 如果启用梯度检查点并且处于训练模式,则使用梯度检查点函数调用层模块
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                )
            else:
                # 否则直接调用层模块,传入当前隐藏状态和是否需要输出注意力权重的标志
                layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

            # 更新隐藏状态为当前层模块的输出的第一个元素(通常是新的隐藏状态)
            hidden_states = layer_outputs[0]

            # 如果需要输出注意力权重,则将当前层的自注意力权重添加到存储列表中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最终的隐藏状态添加到存储列表中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 将最终的隐藏状态传入到“neck”层中进行处理
        hidden_states = self.neck(hidden_states)

        # 如果不使用返回字典格式,则构建返回的输出元组
        if not return_dict:
            outputs = (hidden_states,)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

        # 如果使用返回字典格式,则构建并返回 SamVisionEncoderOutput 对象
        return SamVisionEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class SamPreTrainedModel(PreTrainedModel):
    # 配置使用的配置类
    config_class = SamConfig
    # 模型中基础模型的前缀名称
    base_model_prefix = "sam"
    # 主输入名称为像素值
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        # 从配置中获取初始化范围的标准差
        std = self.config.initializer_range
        # 如果模块是线性层、二维卷积层或反卷积层
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            # 初始化权重为正态分布
            module.weight.data.normal_(mean=0.0, std=std)
            # 如果有偏置,则初始化为零
            if module.bias is not None:
                module.bias.data.zero_()
        # 如果模块是嵌入层
        elif isinstance(module, nn.Embedding):
            # 初始化权重为正态分布
            module.weight.data.normal_(mean=0.0, std=std)
            # 如果有填充索引,则将对应索引的权重初始化为零
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()



@add_start_docstrings(
    "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
    " optional 2D location and bounding boxes.",
    SAM_START_DOCSTRING,
)
class SamModel(SamPreTrainedModel):
    # 需要共享权重的键列表
    _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]

    def __init__(self, config):
        super().__init__(config)
        # 共享图像嵌入
        self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)

        # 视觉编码器
        self.vision_encoder = SamVisionEncoder(config.vision_config)
        # 提示编码器,使用共享的图像嵌入
        self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding)
        # 掩码解码器
        self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)

        # 进行初始化后处理
        self.post_init()

    def get_input_embeddings(self):
        # 返回视觉编码器的输入嵌入
        return self.vision_encoder.get_input_embeddings()


    def get_image_wide_positional_embeddings(self):
        # 获取图像嵌入的位置编码,使用配置中的图像嵌入大小
        size = self.config.prompt_encoder_config.image_embedding_size
        # 获取共享的图像嵌入位置编码的设备和数据类型
        target_device = self.shared_image_embedding.positional_embedding.device
        target_dtype = self.shared_image_embedding.positional_embedding.dtype
        # 创建一个全为1的张量作为网格,设备和数据类型与位置编码相同
        grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
        # 计算垂直方向的位置编码
        y_embed = grid.cumsum(dim=0) - 0.5
        # 计算水平方向的位置编码
        x_embed = grid.cumsum(dim=1) - 0.5
        # 将位置编码归一化到 [0, 1] 范围内
        y_embed = y_embed / size
        x_embed = x_embed / size

        # 使用共享的图像嵌入模型,将 x 和 y 的位置编码堆叠起来作为输入
        positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
        # 将通道维度放到最前面,返回的形状为 channel x height x width
        return positional_embedding.permute(2, 0, 1).unsqueeze(0)

    @torch.no_grad()
    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        返回通过视觉编码器处理像素值得到的图像嵌入。

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                输入的像素值
            output_attentions (`bool`, *optional*):
                是否返回所有注意力层的注意力张量。
            output_hidden_states (`bool`, *optional*):
                是否返回所有层的隐藏状态。
            return_dict (`bool`, *optional*):
                是否返回 [`~utils.ModelOutput`] 而不是简单的元组。

        """
        # 使用视觉编码器处理像素值,根据参数决定是否返回特定的信息
        vision_output = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 提取视觉编码器输出的图像嵌入张量
        image_embeddings = vision_output[0]
        return image_embeddings

    @torch.no_grad()
    def get_prompt_embeddings(
        self,
        input_points: Optional[torch.FloatTensor] = None,
        input_labels: Optional[torch.LongTensor] = None,
        input_boxes: Optional[torch.FloatTensor] = None,
        input_masks: Optional[torch.LongTensor] = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. users can also pass manually the input boxes.
            input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        # 使用 prompt_encoder 方法计算 prompt 的嵌入结果,传入参数包括输入的点、标签、框和掩码
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        # 返回 prompt 的嵌入结果
        return prompt_output

    @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_points: Optional[torch.FloatTensor] = None,
        input_labels: Optional[torch.LongTensor] = None,
        input_boxes: Optional[torch.FloatTensor] = None,
        input_masks: Optional[torch.LongTensor] = None,
        image_embeddings: Optional[torch.FloatTensor] = None,
        multimask_output: bool = True,
        attention_similarity: Optional[torch.FloatTensor] = None,
        target_embedding: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,

.\models\sam\modeling_tf_sam.py

# coding=utf-8
# 指定文件编码为 UTF-8

# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
# 版权声明,声明文件版权归 Meta AI 作者和 HuggingFace 团队所有

# Licensed under the Apache License, Version 2.0 (the "License");
# 依据 Apache 许可证 2.0 版本授权

# you may not use this file except in compliance with the License.
# 除非符合许可证规定,否则不得使用本文件

# You may obtain a copy of the License at
# 可以从以下网址获取许可证的副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 除非适用法律要求或书面同意,否则本软件按“原样”分发,不附任何明示或暗示的担保或条件

# See the License for the specific language governing permissions and
# limitations under the License.
# 请参阅许可证,了解权限和限制的具体条款

"""
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
"""
# TensorFlow SAM 模型文件,大部分由 PyTorch 原始文件自动翻译生成,如有不一致,请以原始文件为参考版本

# 导入必要的库和模块
from __future__ import annotations

import collections  # 导入 collections 库
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Optional, Tuple, Union  # 导入类型提示所需的类和联合类型

import numpy as np  # 导入 NumPy 库并简写为 np
import tensorflow as tf  # 导入 TensorFlow 库并简写为 tf

# 导入相对路径的模块
from ...activations_tf import ACT2FN  # 从活化函数模块导入 ACT2FN
from ...modeling_tf_outputs import TFBaseModelOutput  # 从 TensorFlow 输出模块导入 TFBaseModelOutput
from ...modeling_tf_utils import (  # 从 TensorFlow 实用工具模块导入以下功能:
    TFModelInputType, TFPreTrainedModel,  # TFModelInputType、TFPreTrainedModel 类
    keras, shape_list, unpack_inputs  # keras、shape_list、unpack_inputs 函数
)
from ...tf_utils import flatten, functional_layernorm  # 从 TensorFlow 实用工具模块导入 flatten、functional_layernorm 函数
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging  # 从工具模块导入 ModelOutput 类、若干函数、logging 模块
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig  # 从 sam 配置模块导入若干配置类

logger = logging.get_logger(__name__)  # 获取当前模块的 logger 对象

_CONFIG_FOR_DOC = "SamConfig"  # 用于文档的配置信息
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"  # 用于文档的检查点信息

TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST = [  # TensorFlow SAM 预训练模型存档列表
    "facebook/sam-vit-huge",  # Facebook SAM-ViT Huge 模型
    "facebook/sam-vit-large",  # Facebook SAM-ViT Large 模型
    "facebook/sam-vit-base",  # Facebook SAM-ViT Base 模型
    # 查看所有 SAM 模型,请访问 https://huggingface.co/models?filter=sam
]

@dataclass
class TFSamVisionEncoderOutput(ModelOutput):
    """
    Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.
    """
    # TFSamVisionEncoderOutput 类,用作 SAM 视觉模型输出的基类,还包含通过将投影层应用于 pooler_output 获得的图像嵌入
    """
    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # 可选参数:图像嵌入,形状为 `(batch_size, output_dim)`,当模型以 `with_projection=True` 初始化时返回
    image_embeds: tf.Tensor | None = None

    # 必需参数:最后一个隐藏层的隐藏状态,形状为 `(batch_size, sequence_length, hidden_size)`
    last_hidden_state: tf.Tensor = None

    # 可选参数:元组,包含隐藏状态的序列,形状为 `(batch_size, sequence_length, hidden_size)`
    # 当 `output_hidden_states=True` 或 `config.output_hidden_states=True` 时返回
    # 包括模型每一层的输出以及可选的初始嵌入输出
    hidden_states: Tuple[tf.Tensor, ...] | None = None

    # 可选参数:元组,包含注意力权重的序列,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`
    # 当 `output_attentions=True` 或 `config.output_attentions=True` 时返回
    # 用于计算自注意力头中加权平均值的注意力权重
    attentions: Tuple[tf.Tensor, ...] | None = None
# 定义一个数据类,表示Segment-Anything模型的输出结果,继承自ModelOutput类
@dataclass
class TFSamImageSegmentationOutput(ModelOutput):
    """
    Base class for Segment-Anything model's output

    Args:
        iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
            The iou scores of the predicted masks.
            预测掩膜的IoU分数。

        pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low resolutions masks. Needs to be post-processed by the processor.
            预测的低分辨率掩膜,需要由处理器进行后处理。

        vision_hidden_states  (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
            视觉模型在每一层输出的隐藏状态,以及可选的初始嵌入输出。

        vision_attentions  (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            注意力权重,在注意力softmax后计算的,用于计算自注意力头中的加权平均值。

        mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
            注意力权重,在注意力softmax后计算的,用于计算自注意力头中的加权平均值。
    """

    # IoU分数的张量,形状为(batch_size, num_masks)
    iou_scores: tf.Tensor = None

    # 预测掩膜的张量,形状为(batch_size, num_masks, height, width)
    pred_masks: tf.Tensor = None

    # 视觉隐藏状态的元组,每个元素是一个形状为(batch_size, sequence_length, hidden_size)的张量
    vision_hidden_states: Tuple[tf.Tensor, ...] | None = None

    # 视觉注意力的元组,每个元素是一个形状为(batch_size, num_heads, sequence_length, sequence_length)的张量
    vision_attentions: Tuple[tf.Tensor, ...] | None = None

    # 掩膜解码器注意力的元组,每个元素是一个形状为(batch_size, num_heads, sequence_length, sequence_length)的张量
    mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None


class TFSamPatchEmbeddings(keras.layers.Layer):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """
    # 初始化方法,用于初始化类实例
    def __init__(self, config, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)
        # 从配置中获取图像大小和补丁大小
        image_size, patch_size = config.image_size, config.patch_size
        # 从配置中获取通道数和隐藏层大小
        num_channels, hidden_size = config.num_channels, config.hidden_size
        # 如果图像大小和补丁大小不是可迭代对象,则转换为元组
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # 计算图像中的补丁数量
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        # 将计算得到的各种参数保存在类实例中
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # 创建投影层,使用 Conv2D 卷积层
        self.projection = keras.layers.Conv2D(
            hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
        )

    # 调用方法,用于执行前向传播
    def call(self, pixel_values):
        # 获取输入张量的形状信息
        batch_size, num_channels, height, width = shape_list(pixel_values)
        # 如果输入张量的通道数与配置中的不匹配,则引发值错误
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # 如果输入图像的高度或宽度与配置中的不匹配,则引发值错误
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        # 对输入张量进行转置,然后通过投影层进行嵌入处理
        embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
        # 返回嵌入结果
        return embeddings

    # 构建方法,用于构建模型
    def build(self, input_shape=None):
        # 如果模型已经构建,则直接返回
        if self.built:
            return
        # 标记模型为已构建
        self.built = True
        # 如果投影层已经存在,则使用 TensorFlow 的 name_scope 来构建投影层
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
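
# 补充示例(非源码,仅作说明):补丁嵌入的形状变化。
# 假设 image_size=1024、patch_size=16、hidden_size=768(具体数值随模型配置而定):
#
#     import tensorflow as tf
#
#     pixel_values = tf.random.normal((2, 3, 1024, 1024))                      # channels-first 输入
#     projection = tf.keras.layers.Conv2D(768, kernel_size=16, strides=16)
#     embeddings = projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))   # 先转为 channels-last
#     print(embeddings.shape)                                                  # (2, 64, 64, 768),64 = 1024 // 16
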
class TFSamMLPBlock(keras.layers.Layer):
    # 初始化方法,用于创建 TFSamMLPBlock 实例
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # 创建第一个全连接层,设置输出维度为 config.mlp_dim,命名为 "lin1"
        self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
        # 创建第二个全连接层,设置输出维度为 config.hidden_size,命名为 "lin2"
        self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
        # 获取激活函数,根据配置从全局变量 ACT2FN 中选择对应的函数
        self.act = ACT2FN[config.hidden_act]
        # 保存配置信息到实例变量中
        self.config = config

    # 前向传播方法,用于定义层的计算逻辑
    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # 使用第一个全连接层进行计算
        hidden_states = self.lin1(hidden_states)
        # 应用激活函数
        hidden_states = self.act(hidden_states)
        # 使用第二个全连接层进行计算
        hidden_states = self.lin2(hidden_states)
        # 返回计算结果
        return hidden_states

    # 构建方法,用于构建层的权重
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果 lin1 层存在,则为其构建权重
        if getattr(self, "lin1", None) is not None:
            with tf.name_scope(self.lin1.name):
                self.lin1.build([None, None, self.config.hidden_size])
        # 如果 lin2 层存在,则为其构建权重
        if getattr(self, "lin2", None) is not None:
            with tf.name_scope(self.lin2.name):
                self.lin2.build([None, None, self.config.mlp_dim])


class TFSamLayerNorm(keras.layers.Layer):
    # LayerNorm 层支持 channels_last 或 channels_first 两种数据格式
    # 默认使用 channels_last
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = normalized_shape
        # 如果数据格式不在支持的列表中,则抛出异常
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")

    # 构建方法,用于构建层的权重
    def build(self, input_shape):
        # 添加权重:标准化尺寸对应的权重,初始化为全1;偏置初始化为全0
        self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
        self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
        super().build(input_shape)

    # 前向传播方法,用于定义层的计算逻辑
    def call(self, x: tf.Tensor) -> tf.Tensor:
        # 根据数据格式选择不同的 LayerNorm 函数进行计算
        if self.data_format == "channels_last":
            x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
        elif self.data_format == "channels_first":
            x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
        return x


class TFSamAttention(keras.layers.Layer):
    """
    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """
    # 初始化方法,接受配置和可选的下采样率作为参数
    def __init__(self, config, downsample_rate=None, **kwargs):
        # 调用父类初始化方法
        super().__init__(**kwargs)
        # 设置隐藏层大小
        self.hidden_size = config.hidden_size

        # 如果未提供下采样率,则使用配置中的注意力下采样率
        downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate

        # 计算内部维度,即隐藏层大小除以下采样率
        self.internal_dim = config.hidden_size // downsample_rate
        # 设置注意力头的数量
        self.num_attention_heads = config.num_attention_heads
        # 检查内部维度是否可以被注意力头数量整除
        if self.internal_dim % config.num_attention_heads != 0:
            raise ValueError("num_attention_heads must divide hidden_size.")

        # 初始化查询投影层
        self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
        # 初始化键投影层
        self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
        # 初始化值投影层
        self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
        # 初始化输出投影层
        self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")

    # 将隐藏状态张量分离为多个注意力头
    def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
        batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
        # 计算每个注意力头的通道数
        c_per_head = channel // num_attention_heads
        # 重塑张量形状以分离头
        hidden_states = tf.reshape(
            hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
        )
        return tf.transpose(hidden_states, perm=[0, 2, 1, 3])

    # 将分离的注意力头重新组合为隐藏状态张量
    def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
        batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
        # 调换张量的维度顺序
        hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
        # 重塑张量形状以重新组合头
        return tf.reshape(
            hidden_states,
            (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
        )

    # 模型的调用方法,接受查询、键和值张量,并返回输出张量
    def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
        # 对输入进行投影
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # 获取点批处理大小
        point_batch_size = shape_list(query)[1]
        
        # 将投影后的张量分离为多个注意力头
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)

        # 计算自注意力
        _, _, _, c_per_head = shape_list(query)
        attn = tf.matmul(
            query, tf.transpose(key, perm=[0, 1, 3, 2])
        )  # batch_size * point_batch_size  x N_heads x N_tokens x N_tokens
        # 缩放注意力权重
        attn = attn / tf.math.sqrt(float(c_per_head))
        # 应用 softmax 函数获得归一化的注意力权重
        attn = tf.nn.softmax(attn, axis=-1)

        # 计算输出张量
        out = tf.matmul(attn, value)
        # 将重新组合后的注意力头合并
        out = self._recombine_heads(out, point_batch_size)
        # 对输出应用输出投影
        out = self.out_proj(out)

        return out
    # 构建函数,用于构建模型的输入形状
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 设置标记为已构建
        self.built = True
        
        # 如果存在查询投影层,执行以下操作
        if getattr(self, "q_proj", None) is not None:
            # 在 TensorFlow 中创建名称作用域,命名为 self.q_proj.name
            with tf.name_scope(self.q_proj.name):
                # 使用 [None, None, self.hidden_size] 的形状构建查询投影层
                self.q_proj.build([None, None, self.hidden_size])
        
        # 如果存在键投影层,执行以下操作
        if getattr(self, "k_proj", None) is not None:
            # 在 TensorFlow 中创建名称作用域,命名为 self.k_proj.name
            with tf.name_scope(self.k_proj.name):
                # 使用 [None, None, self.hidden_size] 的形状构建键投影层
                self.k_proj.build([None, None, self.hidden_size])
        
        # 如果存在值投影层,执行以下操作
        if getattr(self, "v_proj", None) is not None:
            # 在 TensorFlow 中创建名称作用域,命名为 self.v_proj.name
            with tf.name_scope(self.v_proj.name):
                # 使用 [None, None, self.hidden_size] 的形状构建值投影层
                self.v_proj.build([None, None, self.hidden_size])
        
        # 如果存在输出投影层,执行以下操作
        if getattr(self, "out_proj", None) is not None:
            # 在 TensorFlow 中创建名称作用域,命名为 self.out_proj.name
            with tf.name_scope(self.out_proj.name):
                # 使用 [None, None, self.internal_dim] 的形状构建输出投影层
                self.out_proj.build([None, None, self.internal_dim])
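
# 补充示例(非源码,仅作说明):_separate_heads / _recombine_heads 的形状变换。
# 假设 batch=2、point_batch_size=3、token 数=5、通道数=256、8 个注意力头:
#
#     import tensorflow as tf
#
#     batch, point_batch_size, n_tokens, channel, num_heads = 2, 3, 5, 256, 8
#     c_per_head = channel // num_heads
#     hidden_states = tf.random.normal((batch, point_batch_size, n_tokens, channel))
#
#     separated = tf.transpose(
#         tf.reshape(hidden_states, (batch * point_batch_size, n_tokens, num_heads, c_per_head)),
#         perm=[0, 2, 1, 3],
#     )
#     print(separated.shape)   # (6, 8, 5, 32):注意力按 (batch * point_batch_size) 并行计算
#
#     recombined = tf.reshape(
#         tf.transpose(separated, perm=[0, 2, 1, 3]),
#         (batch, point_batch_size, n_tokens, num_heads * c_per_head),
#     )
#     print(recombined.shape)  # (2, 3, 5, 256):恢复为原始形状
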
# 定义了一个基于 Transformer 的自定义层,用于处理两种不同数据类型之间的注意力机制和多层感知机操作

class TFSamTwoWayAttentionBlock(keras.layers.Layer):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
        """
        初始化函数,配置了四个层次的 Transformer 模块:
            (1) 自注意力层,处理稀疏输入
            (2) 从稀疏输入到密集输入的交叉注意力层
            (3) 在稀疏输入上的多层感知机块
            (4) 从密集输入到稀疏输入的交叉注意力层

        Arguments:
            config (`SamMaskDecoderConfig`):
                用于实例化该块的配置文件
            attention_downsample_rate (*optional*, int, defaults to 2):
                用于减少注意力内部维度的下采样比率
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                是否跳过在第一层添加 query_point_embedding 的步骤
        """
        super().__init__(**kwargs)

        # 从配置中获取隐藏大小和层归一化的 epsilon 值
        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps

        # 定义自注意力层和归一化层
        self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")

        # 定义从标记到图像的交叉注意力层和归一化层
        self.cross_attn_token_to_image = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
        )
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")

        # 定义多层感知机块和归一化层
        self.mlp = TFSamMLPBlock(config, name="mlp")
        self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")

        # 定义从图像到标记的交叉注意力层的归一化层
        self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
        self.cross_attn_image_to_token = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
        )

        # 是否跳过第一层添加 query_point_embedding
        self.skip_first_layer_pe = skip_first_layer_pe

    def call(
        self,
        queries: tf.Tensor,
        keys: tf.Tensor,
        query_point_embedding: tf.Tensor,
        key_point_embedding: tf.Tensor,
        output_attentions: bool = False,
    ):
        # Self attention block
        # 如果设置了跳过第一层的位置编码,则使用自注意力机制
        if self.skip_first_layer_pe:
            queries = self.self_attn(query=queries, key=queries, value=queries)
        else:
            # 否则,将位置编码加到查询中,然后进行自注意力计算
            query = queries + query_point_embedding
            attn_out = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)

        # Cross attention block, tokens attending to image embedding
        # 将位置编码加到查询中,图像的键和位置编码相加后进行交叉注意力计算
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
        queries = queries + attn_out
        queries = self.layer_norm2(queries)

        # MLP block
        # 使用多层感知机(MLP)处理查询
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        # Cross attention block, image embedding attending to tokens
        # 将各自的位置编码分别加到 token 和图像嵌入上,再以图像嵌入为查询、token 为键/值进行交叉注意力计算
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out
        keys = self.layer_norm4(keys)

        # 构建输出元组
        outputs = (queries, keys)

        # 如果需要输出注意力权重,则添加到输出元组中
        if output_attentions:
            outputs = outputs + (attn_out,)
        else:
            outputs = outputs + (None,)

        return outputs

    def build(self, input_shape=None):
        # 如果模型已经构建,则直接返回
        if self.built:
            return
        self.built = True
        
        # 构建自注意力层(如果存在)
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        
        # 构建层归一化1(如果存在)
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, None, self.hidden_size])
        
        # 构建标记到图像的交叉注意力层(如果存在)
        if getattr(self, "cross_attn_token_to_image", None) is not None:
            with tf.name_scope(self.cross_attn_token_to_image.name):
                self.cross_attn_token_to_image.build(None)
        
        # 构建层归一化2(如果存在)
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, None, self.hidden_size])
        
        # 构建多层感知机(MLP)层(如果存在)
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        
        # 构建层归一化3(如果存在)
        if getattr(self, "layer_norm3", None) is not None:
            with tf.name_scope(self.layer_norm3.name):
                self.layer_norm3.build([None, None, None, self.hidden_size])
        
        # 构建图像到标记的交叉注意力层(如果存在)
        if getattr(self, "cross_attn_image_to_token", None) is not None:
            with tf.name_scope(self.cross_attn_image_to_token.name):
                self.cross_attn_image_to_token.build(None)
# 定义一个名为 TFSamTwoWayTransformer 的自定义层,继承自 keras.layers.Layer
class TFSamTwoWayTransformer(keras.layers.Layer):
    # 初始化函数,接受一个 config 对象和其他关键字参数
    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        # 调用父类的初始化函数
        super().__init__(**kwargs)
        # 将传入的 config 对象保存为属性
        self.config = config

        # 从 config 中获取隐藏层数目并保存为属性
        self.num_hidden_layers = config.num_hidden_layers
        # 初始化一个空列表用于保存多个自定义注意力块
        self.layers = []

        # 循环创建指定数量的自定义注意力块,并添加到 self.layers 中
        for i in range(self.num_hidden_layers):
            self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))

        # 创建一个用于最终注意力层的对象,并命名为 final_attn_token_to_image
        self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
        # 创建一个 LayerNormalization 层,并使用 config 中的 epsilon 参数,命名为 layer_norm_final_attn
        self.layer_norm_final_attn = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
        )

    # 定义 call 方法,处理输入张量并执行前向传播
    def call(
        self,
        point_embeddings: tf.Tensor,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFBaseModelOutput]:
        # 确定是否输出注意力权重,默认从 self.config 中获取
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 确定是否输出隐藏状态,默认从 self.config 中获取
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 确定是否使用返回字典,默认从 self.config 中获取
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 初始化一个空元组,用于存储所有注意力权重
        all_attentions = ()

        # 如果 image_embeddings 为 None,则抛出 ValueError
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")

        # 对 image_embeddings 进行转置和扁平化处理,保持一致性,并添加一个维度
        image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
        # 对 image_positional_embeddings 进行转置和扁平化处理,保持一致性,并添加一个维度
        image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]

        # 准备查询向量,使用 point_embeddings
        queries = point_embeddings
        # 准备键向量,使用 image_embeddings
        keys = image_embeddings

        # 遍历 self.layers 中的每个注意力块,并应用于查询和键向量
        for layer in self.layers:
            queries, keys, attention_outputs = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                output_attentions=output_attentions,
            )

            # 如果设置了 output_attentions,则将注意力权重输出存储到 all_attentions 中
            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)

        # 应用从点到图像的最终注意力层
        query = queries + point_embeddings
        key = keys + image_positional_embeddings

        # 调用 self.final_attn_token_to_image 执行注意力操作,输出结果存储在 attn_out 中
        attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)

        # 将 attn_out 加到 queries 中,并通过 layer_norm_final_attn 进行规范化处理
        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)

        # 返回处理后的 queries、keys 和 all_attentions(如果有)
        return queries, keys, all_attentions
    # 构建方法,用于构建模型结构
    def build(self, input_shape=None):
        # 如果已经构建过,直接返回,避免重复构建
        if self.built:
            return
        # 设置标志为已构建
        self.built = True
        
        # 如果存在最终注意力机制的映射到图像的部分,构建这部分
        if getattr(self, "final_attn_token_to_image", None) is not None:
            # 在命名空间下构建最终注意力机制映射到图像的层
            with tf.name_scope(self.final_attn_token_to_image.name):
                self.final_attn_token_to_image.build(None)
        
        # 如果存在最终注意力层归一化部分,构建这部分
        if getattr(self, "layer_norm_final_attn", None) is not None:
            # 在命名空间下构建最终注意力层归一化层
            with tf.name_scope(self.layer_norm_final_attn.name):
                self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
        
        # 遍历所有层,分别在其命名空间下构建每一层
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)
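
# 补充示例(非源码,仅作说明):进入双向 Transformer 前,图像嵌入从 (B, C, H, W)
# 展平并转置为 (B, 1, H*W, C),新增的维度 1 对应 point_batch 维度。
# 假设 B=2、C=256、H=W=64:
#
#     import tensorflow as tf
#
#     image_embeddings = tf.random.normal((2, 256, 64, 64))        # (B, C, H, W)
#     flattened = tf.reshape(image_embeddings, (2, 256, 64 * 64))  # 等价于 flatten(image_embeddings, 2)
#     keys = tf.transpose(flattened, perm=(0, 2, 1))[:, None]      # (B, 1, H*W, C)
#     print(keys.shape)                                            # (2, 1, 4096, 256)
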
class TFSamFeedForward(keras.layers.Layer):
    # 定义一个自定义层 TFSamFeedForward,继承自 keras.layers.Layer
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        # 调用父类的初始化方法
        self.num_layers = num_layers
        # 设置层的数量
        self.activation = keras.layers.ReLU()
        # 设置激活函数为 ReLU
        self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
        # 定义输入投影层,将输入维度映射到隐藏维度
        self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
        # 定义输出投影层,将隐藏维度映射到输出维度
        self.layers = [
            keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
            for i in range(num_layers - 2)
        ]
        # 定义多个隐藏层,除去输入和输出层
        self.sigmoid_output = sigmoid_output
        # 设置是否使用 sigmoid 输出
        self.hidden_dim = hidden_dim
        # 存储隐藏层维度
        self.input_dim = input_dim
        # 存储输入维度

    def call(self, hidden_states):
        # 定义层的前向传播过程
        hidden_states = self.proj_in(hidden_states)
        # 输入投影层处理输入数据
        hidden_states = self.activation(hidden_states)
        # 使用激活函数处理投影后的数据
        for layer in self.layers:
            hidden_states = self.activation(layer(hidden_states))
            # 遍历并使用激活函数处理每个隐藏层的数据

        hidden_states = self.proj_out(hidden_states)
        # 输出投影层处理隐藏层输出数据
        if self.sigmoid_output:
            hidden_states = tf.sigmoid(hidden_states)
            # 如果需要 sigmoid 输出,则应用 sigmoid 函数
        return hidden_states
        # 返回最终输出数据

    def build(self, input_shape=None):
        # 定义层的构建方法
        if self.built:
            return
        self.built = True
        # 标记层已构建
        if getattr(self, "proj_in", None) is not None:
            with tf.name_scope(self.proj_in.name):
                self.proj_in.build([None, None, self.input_dim])
        # 如果输入投影层存在,使用其名称作用域构建
        if getattr(self, "proj_out", None) is not None:
            with tf.name_scope(self.proj_out.name):
                self.proj_out.build([None, None, self.hidden_dim])
        # 如果输出投影层存在,使用其名称作用域构建
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build([None, None, self.hidden_dim])
        # 遍历每个隐藏层,使用其名称作用域构建
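
# 补充示例(非源码,仅作说明):TFSamFeedForward 等价于一个逐层 ReLU 的 MLP。
# 假设 input_dim=256、hidden_dim=256、output_dim=32、num_layers=3,
# 则结构为 proj_in(256→256) → ReLU → 1 个中间层(256→256) → ReLU → proj_out(256→32):
#
#     import tensorflow as tf
#
#     x = tf.random.normal((2, 4, 256))
#     proj_in = tf.keras.layers.Dense(256)
#     hidden = tf.keras.layers.Dense(256)   # num_layers - 2 = 1 个中间层
#     proj_out = tf.keras.layers.Dense(32)
#
#     y = proj_out(tf.nn.relu(hidden(tf.nn.relu(proj_in(x)))))
#     print(y.shape)  # (2, 4, 32)
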


class TFSamMaskDecoder(keras.layers.Layer):
    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)  # 调用父类的初始化方法

        self.hidden_size = config.hidden_size  # 设置隐藏层大小

        self.num_multimask_outputs = config.num_multimask_outputs  # 多掩模输出数量
        self.num_mask_tokens = config.num_multimask_outputs + 1  # 掩模令牌数量

        self.transformer = TFSamTwoWayTransformer(config, name="transformer")  # 创建一个双向变换器对象

        self.upscale_conv1 = keras.layers.Conv2DTranspose(
            self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
        )  # 第一个上采样卷积层,将隐藏层大小的四分之一作为输出,步长为2,使用通道优先的数据格式

        self.upscale_conv2 = keras.layers.Conv2DTranspose(
            self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
        )  # 第二个上采样卷积层,将隐藏层大小的八分之一作为输出,步长为2,使用通道优先的数据格式

        self.upscale_layer_norm = TFSamLayerNorm(
            self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
        )  # 上采样层的归一化层,以隐藏层大小的四分之一作为输入,使用通道优先的数据格式

        self.activation = tf.nn.gelu  # 激活函数设置为 GELU 函数

        mlps_list = []
        for i in range(self.num_mask_tokens):
            mlps_list += [
                TFSamFeedForward(
                    self.hidden_size,
                    self.hidden_size,
                    self.hidden_size // 8,
                    3,
                    name=f"output_hypernetworks_mlps_._{i}",
                )
            ]  # 构建多个前馈网络,并添加到列表中作为超网络的输出

        self.output_hypernetworks_mlps = mlps_list  # 超网络的输出层列表

        self.iou_prediction_head = TFSamFeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            name="iou_prediction_head",
        )  # IOU 预测头部的前馈网络

    def build(self, input_shape=None):
        if self.built:
            return  # 如果已经构建过,直接返回

        self.built = True  # 标记为已构建

        self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)  # 添加 IOU 令牌权重参数

        self.mask_tokens = self.add_weight(
            shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
        )  # 添加掩模令牌权重参数

        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)  # 构建变换器对象

        if getattr(self, "upscale_conv1", None) is not None:
            with tf.name_scope(self.upscale_conv1.name):
                self.upscale_conv1.build([None, self.hidden_size, None, None])  # 构建第一个上采样卷积层

        if getattr(self, "upscale_conv2", None) is not None:
            with tf.name_scope(self.upscale_conv2.name):
                self.upscale_conv2.build([None, self.hidden_size // 4, None, None])  # 构建第二个上采样卷积层

        if getattr(self, "upscale_layer_norm", None) is not None:
            with tf.name_scope(self.upscale_layer_norm.name):
                self.upscale_layer_norm.build(None)  # 构建上采样层的归一化层

        if getattr(self, "iou_prediction_head", None) is not None:
            with tf.name_scope(self.iou_prediction_head.name):
                self.iou_prediction_head.build(None)  # 构建 IOU 预测头部的前馈网络

        for mlp in self.output_hypernetworks_mlps:
            with tf.name_scope(mlp.name):
                mlp.build(None)  # 构建超网络的输出层列表中的每个前馈网络
    # 定义一个方法 `call`,接受多个参数作为输入
    def call(
        self,
        # 图像嵌入向量,使用 TensorFlow 的张量表示
        image_embeddings: tf.Tensor,
        # 图像位置嵌入向量,使用 TensorFlow 的张量表示
        image_positional_embeddings: tf.Tensor,
        # 稀疏提示嵌入向量,使用 TensorFlow 的张量表示
        sparse_prompt_embeddings: tf.Tensor,
        # 密集提示嵌入向量,使用 TensorFlow 的张量表示
        dense_prompt_embeddings: tf.Tensor,
        # 是否输出多掩码的结果,布尔类型
        multimask_output: bool,
        # 是否输出注意力信息,可选的布尔类型参数,默认为 None
        output_attentions: Optional[bool] = None,
class TFSamPositionalEmbedding(keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.scale = config.hidden_size // 2  # 计算缩放因子,用于位置编码
        self.config = config

    def build(self, input_shape):
        # 构建层时,创建一个不可训练的权重矩阵作为位置编码的基础
        self.positional_embedding = self.add_weight(
            name="positional_embedding",
            shape=(2, self.config.num_pos_feats),  # 设置权重矩阵的形状
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),  # 使用随机正态分布初始化权重
            trainable=False,  # 设置权重为不可训练
        )
        super().build(input_shape)

    def call(self, input_coords, input_shape=None):
        """Positionally encode points that are normalized to [0,1]."""
        coordinates = tf.identity(input_coords)

        if input_shape is not None:
            coordinates = tf.stack(
                [
                    tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],  # 将 x 坐标归一化到 [0,1]
                    tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],  # 将 y 坐标归一化到 [0,1]
                ],
                axis=-1,
            )

        # 将归一化后的坐标转换到 [-1, 1] 区间
        coordinates = 2 * coordinates - 1
        coordinates = tf.cast(coordinates, self.positional_embedding.dtype)  # 转换坐标数据类型以匹配位置编码的数据类型
        coordinates = tf.matmul(coordinates, self.positional_embedding)  # 计算坐标与位置编码的乘积
        coordinates = 2 * np.pi * coordinates  # 乘以 2π,作为后续正弦/余弦编码的相位
        # 输出正弦和余弦函数的组合,用于位置编码
        return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
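
# 补充示例(非源码,仅作说明):随机傅里叶特征式的位置编码。
# 把 [0,1] 内的坐标映射到 [-1,1],乘以一个固定的随机高斯矩阵,再取 sin/cos 拼接。
# 假设 num_pos_feats=128(每个点得到 256 维编码),scale 即 hidden_size // 2:
#
#     import numpy as np
#
#     num_pos_feats, scale = 128, 128.0
#     rng = np.random.default_rng(0)
#     positional_matrix = rng.normal(0.0, scale, size=(2, num_pos_feats))  # 固定、不参与训练
#
#     coords = np.array([[0.25, 0.75]])          # 一个归一化到 [0,1] 的 (x, y) 点
#     coords = 2 * coords - 1                    # 映射到 [-1, 1]
#     proj = 2 * np.pi * (coords @ positional_matrix)
#     encoding = np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)
#     print(encoding.shape)                      # (1, 256)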


class TFSamMaskEmbedding(keras.layers.Layer):
    def __init__(self, config: SamPromptEncoderConfig, **kwargs):
        super().__init__(**kwargs)
        self.mask_input_channels = config.mask_input_channels // 4  # 计算掩码输入通道数的四分之一
        self.activation = ACT2FN[config.hidden_act]  # 激活函数由配置决定
        self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")  # 第一个卷积层
        self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")  # 第二个卷积层
        self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")  # 第三个卷积层
        self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")  # 第一个层归一化层
        self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")  # 第二个层归一化层
        self.config = config  # 保存配置信息
    def call(self, masks):
        # 转置输入张量,将通道维度移到最后一个维度
        masks = tf.transpose(masks, perm=(0, 2, 3, 1))  # Convert to channels-last
        # 第一层卷积操作
        hidden_states = self.conv1(masks)
        # 第一层层归一化
        hidden_states = self.layer_norm1(hidden_states)
        # 激活函数处理
        hidden_states = self.activation(hidden_states)

        # 第二层卷积操作
        hidden_states = self.conv2(hidden_states)
        # 第二层层归一化
        hidden_states = self.layer_norm2(hidden_states)
        # 激活函数处理
        hidden_states = self.activation(hidden_states)
        # 第三层卷积操作
        dense_embeddings = self.conv3(hidden_states)
        # 转置张量,将通道维度移到第二个位置,回到 channels-first 格式
        dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2))  # Convert back to channels-first
        return dense_embeddings

    def build(self, input_shape=None):
        # 由于此类不会使用标准的虚拟输入,因此需要显式的 build 方法
        if self.built:
            return
        self.built = True
        # 使用 tf.name_scope 确定每个层的输入形状
        with tf.name_scope("conv1"):
            self.conv1.build([None, None, None, 1])
        with tf.name_scope("conv2"):
            self.conv2.build([None, None, None, self.mask_input_channels])
        with tf.name_scope("conv3"):
            self.conv3.build([None, None, None, self.mask_input_channels * 4])
        with tf.name_scope("layer_norm1"):
            self.layer_norm1.build([None, None, None, self.mask_input_channels])
        with tf.name_scope("layer_norm2"):
            self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
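
# 补充示例(非源码,仅作说明):掩码提示的下采样过程。
# 假设 mask_input_channels=16、hidden_size=256、输入掩码为 256x256,
# 两个 stride=2 的卷积把空间尺寸降到 64x64,与图像嵌入的分辨率对齐:
#
#     import tensorflow as tf
#
#     masks = tf.random.normal((1, 1, 256, 256))                   # channels-first 输入
#     x = tf.transpose(masks, perm=(0, 2, 3, 1))                   # 转为 channels-last
#     x = tf.keras.layers.Conv2D(4, kernel_size=2, strides=2)(x)   # (1, 128, 128, 4)
#     x = tf.keras.layers.Conv2D(16, kernel_size=2, strides=2)(x)  # (1, 64, 64, 16)
#     x = tf.keras.layers.Conv2D(256, kernel_size=1)(x)            # (1, 64, 64, 256)
#     dense_embeddings = tf.transpose(x, perm=(0, 3, 1, 2))        # 回到 channels-first: (1, 256, 64, 64)
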
class TFSamPromptEncoder(keras.layers.Layer):
    def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
        super().__init__(**kwargs)
        self.shared_embedding = shared_patch_embedding  # 设置共享的补丁嵌入对象
        self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")  # 创建一个 TFSamMaskEmbedding 实例作为 mask_embed
        self.no_mask_embed = None  # 初始化为 None,在 build 方法中将被赋值为权重
        self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)  # 图像嵌入的尺寸
        self.input_image_size = config.image_size  # 输入图像的尺寸

        self.point_embed = []  # 初始化为空列表,用于存储点嵌入的权重
        self.hidden_size = config.hidden_size  # 隐藏层的大小
        self.not_a_point_embed = None  # 初始化为 None,在 build 方法中将被赋值为权重
        self.config = config  # 保存配置对象

    def build(self, input_shape=None):
        self.no_mask_embed = self.add_weight(
            name="no_mask_embed.weight",
            shape=(1, self.hidden_size),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            trainable=True,
        )  # 添加一个权重变量,用于表示没有 mask 的嵌入

        self.point_embed = [
            self.add_weight(
                name=f"point_embed_._{i}.weight",
                shape=(1, self.hidden_size),
                initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
                trainable=True,
            )
            for i in range(self.config.num_point_embeddings)
        ]  # 添加多个权重变量,用于表示点嵌入

        self.not_a_point_embed = self.add_weight(
            name="not_a_point_embed.weight",
            shape=(1, self.hidden_size),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            trainable=True,
        )  # 添加一个权重变量,用于表示非点嵌入

        with tf.name_scope("mask_embed"):
            # 显式构建 mask_embed,因为它不会被标准的虚拟输入所触及
            self.mask_embed.build(
                (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
            )  # 构建 mask_embed 的结构

        if self.built:
            return
        self.built = True

        if getattr(self, "mask_embed", None) is not None:
            with tf.name_scope(self.mask_embed.name):
                self.mask_embed.build(None)  # 如果 mask_embed 存在,则进一步构建它的结构
    def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
        """Embeds point prompts."""
        # 将点坐标加上0.5以将其移动到像素中心
        points = points + 0.5  # Shift to center of pixel
        if pad:
            # 构建目标点的形状,用于填充
            target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
            target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
            # 创建零填充的点和标签
            padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
            padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
            # 在点和标签的第三个维度上连接填充内容
            points = tf.concat([points, padding_point], axis=2)
            labels = tf.concat([labels, padding_label], axis=2)
        input_shape = (self.input_image_size, self.input_image_size)
        # 使用共享的嵌入层嵌入点坐标
        point_embedding = self.shared_embedding(points, input_shape)

        # 根据标签值进行条件选择和嵌入
        point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)

        point_embedding = tf.where(
            labels[..., None] != -10,
            point_embedding,
            tf.zeros_like(point_embedding),
        )
        point_embedding = tf.where(
            (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
        )
        point_embedding = tf.where(
            (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
        )
        return point_embedding

    def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
        """Embeds box prompts."""
        # 将框坐标加上0.5以将其移动到像素中心
        boxes = boxes + 0.5  # Shift to center of pixel
        batch_size, nb_boxes = shape_list(boxes)[:2]
        # 重塑框的坐标形状以适应嵌入层
        coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
        input_shape = (self.input_image_size, self.input_image_size)
        # 使用共享的嵌入层嵌入角点坐标
        corner_embedding = self.shared_embedding(coords, input_shape)
        # 根据条件在角点嵌入上添加偏移量
        corner_embedding += tf.where(
            tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
            self.point_embed[2][0],
            self.point_embed[3][0],
        )
        return corner_embedding
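
    # 补充示例(非源码,仅作说明):一个框 (x1, y1, x2, y2) 被重排为左上、右下两个角点,
    # 之后分别加上 point_embed[2] 和 point_embed[3] 两个可学习嵌入:
    #
    #     import tensorflow as tf
    #
    #     boxes = tf.constant([[[10.0, 20.0, 110.0, 220.0]]])   # (batch=1, nb_boxes=1, 4)
    #     corners = tf.reshape(boxes + 0.5, (1, 1, 2, 2))       # 两个 (x, y) 角点
    #     print(corners.numpy())                                # [[[[ 10.5  20.5] [110.5 220.5]]]]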

    def call(
        self,
        batch_size: Optional[int],
        input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
        input_labels: tf.Tensor | None,
        input_boxes: tf.Tensor | None,
        input_masks: tf.Tensor | None,
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            points (`tf.Tensor`, *optional*):
                point coordinates and labels to embed.
            boxes (`tf.Tensor`, *optional*):
                boxes to embed
            masks (`tf.Tensor`, *optional*):
                masks to embed
        """
        # 初始化稀疏和密集嵌入为 None
        sparse_embeddings = None

        # 如果输入的点不为空,则进行点嵌入
        if input_points is not None:
            # 获取批量大小和点批次大小
            batch_size, point_batch_size = shape_list(input_points)[:2]
            # 如果输入的标签为空,则抛出数值错误
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            # 使用内部方法 _embed_points 进行点的嵌入
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            # 创建全零的稀疏嵌入张量
            sparse_embeddings = tf.zeros(
                (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
            )
            # 将点嵌入拼接到稀疏嵌入张量中
            sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)

        # 如果输入的盒子不为空,则进行盒子的嵌入
        if input_boxes is not None:
            # 获取批量大小
            batch_size = shape_list(input_boxes)[0]
            # 使用内部方法 _embed_boxes 进行盒子的嵌入
            box_embeddings = self._embed_boxes(input_boxes)
            # 如果稀疏嵌入张量为空,则将盒子嵌入设为稀疏嵌入张量
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                # 否则将盒子嵌入拼接到稀疏嵌入张量中
                sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)

        # 如果输入的掩码不为空,则进行掩码的嵌入
        if input_masks is not None:
            # 使用 mask_embed 方法进行掩码的嵌入
            dense_embeddings = self.mask_embed(input_masks)
        else:
            # 否则使用无掩码嵌入的第一个元素作为密集嵌入
            dense_embeddings = self.no_mask_embed[0]
            # 调整密集嵌入的形状
            dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
            # 在指定维度上复制密集嵌入
            dense_embeddings = tf.tile(
                dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
            )

        # 如果稀疏嵌入张量仍为空,则创建全零的稀疏嵌入张量
        if sparse_embeddings is None:
            sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)

        # 返回稀疏嵌入张量和密集嵌入张量
        return sparse_embeddings, dense_embeddings
# 定义一个自定义的注意力层,用于处理多头注意力机制和相对位置编码
class TFSamVisionAttention(keras.layers.Layer):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size, **kwargs):
        super().__init__(**kwargs)
        # 计算输入大小,根据配置和窗口大小决定
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )
        self.input_size = input_size

        # 注意力头的数量
        self.num_attention_heads = config.num_attention_heads
        # 每个注意力头的维度
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        # 缩放因子,用于缩放注意力分数
        self.scale = head_dim**-0.5
        # 注意力层的 dropout 概率
        self.dropout = config.attention_dropout

        # QKV 查询键值对应的全连接层
        self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
        # 投影层,用于最终的输出
        self.proj = keras.layers.Dense(config.hidden_size, name="proj")

        # 是否使用相对位置编码
        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            # 如果使用相对位置编码,确保输入大小已提供
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")
        self.config = config

    def build(self, input_shape=None):
        # 如果输入大小不为 None,则初始化相对位置编码
        if self.input_size is not None:
            # 高度方向(垂直)的相对位置编码权重矩阵
            self.rel_pos_h = self.add_weight(
                shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
            )
            # 宽度方向(水平)的相对位置编码权重矩阵
            self.rel_pos_w = self.add_weight(
                shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
            )

        if self.built:
            return
        self.built = True
        # 构建 QKV 全连接层
        if getattr(self, "qkv", None) is not None:
            with tf.name_scope(self.qkv.name):
                self.qkv.build([None, None, self.config.hidden_size])
        # 构建投影层
        if getattr(self, "proj", None) is not None:
            with tf.name_scope(self.proj.name):
                self.proj.build([None, None, self.config.hidden_size])
    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
            query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`tf.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        # Calculate the maximum relative distance based on query and key sizes
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        
        # Interpolate rel_pos if its length does not match max_rel_dist
        if rel_pos.shape[0] != max_rel_dist:
            # Resize rel_pos using bilinear interpolation to match max_rel_dist
            rel_pos_resized = tf.image.resize(
                tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
                size=(max_rel_dist, rel_pos.shape[1]),
                method="bilinear",
            )
            # Reshape the interpolated rel_pos to match the expected shape
            rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
        else:
            # Use rel_pos directly if its length matches max_rel_dist
            rel_pos_resized = rel_pos

        # Calculate relative coordinates scaled according to the sizes of q and k
        q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
        k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

        # Gather positional embeddings based on the calculated relative coordinates
        return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
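
    # 补充示例(非源码,仅作说明):当 q_size == k_size == 3 时,
    # 相对坐标 (q_i - k_j) + (k_size - 1) 的取值范围是 0..4,正好索引长度为 2*3-1=5 的相对位置表:
    #
    #     import tensorflow as tf
    #
    #     q_size = k_size = 3
    #     q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1)   # (3, 1)
    #     k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0)   # (1, 3)
    #     relative_coords = (q_coords - k_coords) + (k_size - 1) * 1.0
    #     print(relative_coords.numpy())
    #     # [[2. 1. 0.]
    #     #  [3. 2. 1.]
    #     #  [4. 3. 2.]]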

    def add_decomposed_rel_pos(
        self,
        attn: tf.Tensor,
        query: tf.Tensor,
        rel_pos_h: tf.Tensor,
        rel_pos_w: tf.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> tf.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`tf.Tensor`):
                attention map.
            query (`tf.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`tf.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`tf.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`tf.Tensor`):
                attention map with added relative positional embeddings.
        """
        # 解包查询和键的空间尺寸
        query_height, query_width = q_size
        key_height, key_width = k_size
        
        # 获取相对位置嵌入的高度和宽度分量
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
        
        # 获取查询张量的形状信息
        batch_size, _, dim = shape_list(query)
        
        # 将查询张量重塑为四维张量,以便进行相对位置计算
        reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
        
        # 使用 Einstein Summation 计算高度和宽度方向上的相对位置加权
        rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        
        # 将注意力张量重塑为五维张量,以便应用相对位置加权
        attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
        
        # 将相对位置加权应用到注意力张量上,并将其重塑为二维张量
        attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
        attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
        
        # 返回加入相对位置嵌入后的注意力张量
        return attn
    # 定义一个方法,输入为 hidden_states (TensorFlow 张量),output_attentions 和 training 的布尔参数
    def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
        # 获取 batch_size, height, width 和通道数
        batch_size, height, width, _ = shape_list(hidden_states)
        # 通过 qkv 层处理 hidden_states,得到一个形状为 (batch_size, height * width, 3, num_attention_heads, -1) 的张量
        qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
        # 调整维度,将 qkv 张量从 (batch_size, height * width, 3, num_attention_heads, -1) 转换为 (3, batch_size, num_attention_heads, height * width, -1)
        qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
        # 将 qkv 张量重新调整为 (batch_size * num_attention_heads, height * width, channel),并拆分为 query, key, value
        query, key, value = tf.unstack(
            tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
        )
        # 计算注意力权重,query 与 key 的乘积,缩放后进行矩阵乘法
        attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)

        # 如果使用相对位置编码,调用 add_decomposed_rel_pos 方法添加相对位置编码
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        # 对注意力权重进行 softmax 操作,归一化权重
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        # 根据是否训练,应用 dropout
        if training:
            attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
        else:
            attn_probs = attn_weights

        # 计算注意力输出,使用注意力权重与 value 的矩阵乘法,重塑形状为 (batch_size, num_attention_heads, height, width, -1)
        attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
        # 调整输出维度,将其从 (batch_size, num_attention_heads, height, width, -1) 转换为 (batch_size, height, width, num_attention_heads, -1)
        attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
        # 将注意力输出重新调整形状为 (batch_size, height, width, hidden_size)
        attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))

        # 通过 proj 层对注意力输出进行投影
        attn_output = self.proj(attn_output)

        # 如果需要输出注意力权重,返回包含注意力输出和注意力权重的元组,否则只返回注意力输出
        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs
# 定义自定义层 TFSamVisionLayer,继承自 keras.layers.Layer
class TFSamVisionLayer(keras.layers.Layer):
    
    # 初始化方法,接受配置 config 和窗口大小 window_size 作为参数
    def __init__(self, config, window_size, **kwargs):
        super().__init__(**kwargs)
        
        # 创建 LayerNormalization 层,用于归一化数据,设定 epsilon 为 config 中的 layer_norm_eps,命名为 "layer_norm1"
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
        
        # 创建 TFSamVisionAttention 层,用于处理注意力机制,传入配置 config 和窗口大小 window_size,命名为 "attn"
        self.attn = TFSamVisionAttention(config, window_size, name="attn")
        
        # 创建第二个 LayerNormalization 层,设定 epsilon 为 config 中的 layer_norm_eps,命名为 "layer_norm2"
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
        
        # 创建 TFSamMLPBlock 层,用于多层感知机处理,传入配置 config,命名为 "mlp"
        self.mlp = TFSamMLPBlock(config, name="mlp")
        
        # 设置窗口大小和配置参数为实例变量
        self.window_size = window_size
        self.config = config

    # 定义窗口划分方法,接受 hidden_states 和 window_size 作为输入,返回划分后的窗口和填充后的高度和宽度元组
    def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
        # 获取 hidden_states 的形状信息
        batch_size, height, width, channel = shape_list(hidden_states)

        # 计算高度和宽度的填充量,使其能够被 window_size 整除
        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        
        # 如果存在填充,则在高度和宽度上进行填充
        if pad_h > 0 or pad_w > 0:
            hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
        
        # 计算填充后的高度和宽度
        pad_height, pad_width = height + pad_h, width + pad_w
        
        # 将 hidden_states 重新形状为 batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
        hidden_states = tf.reshape(
            hidden_states,
            [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
        )
        
        # 将形状变换后的 hidden_states 进行转置和重塑,得到 windows,形状为 [-1, window_size, window_size, channel]
        windows = tf.reshape(
            tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
        )
        
        # 返回 windows 和填充后的高度和宽度元组
        return windows, (pad_height, pad_width)

    # 定义窗口反划分方法,接受 windows、window_size、padding_shape 和 original_shape 作为输入,返回反划分后的 hidden_states
    def window_unpartition(
        self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
    ) -> tf.Tensor:
        # 获取填充后的高度和宽度
        pad_height, pad_width = padding_shape
        
        # 获取原始的高度和宽度
        height, width = original_shape
        
        # 计算 batch_size
        batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
        
        # 将 windows 重新形状为 batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
        hidden_states = tf.reshape(
            windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
        )
        
        # 进行转置和重塑,得到 hidden_states,形状为 batch_size, pad_height, pad_width, -1
        hidden_states = tf.reshape(
            tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
        )
        
        # 如果填充后的高度或宽度大于原始高度或宽度,则截取对应部分
        if pad_height > height or pad_width > width:
            hidden_states = hidden_states[:, :height, :width, :]
        
        # 返回反划分后的 hidden_states
        return hidden_states
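
下面用一个尺寸很小的假设张量演示这对划分/反划分操作:高宽无法被窗口大小整除时先填充,反向操作再裁剪回原始尺寸,整个过程是无损的(与上面两个方法逻辑一致,但以独立代码书写,仅作形状演示):

```python
import tensorflow as tf

batch_size, height, width, channel, window_size = 1, 5, 7, 4, 3
x = tf.random.normal((batch_size, height, width, channel))

pad_h = (window_size - height % window_size) % window_size        # 1
pad_w = (window_size - width % window_size) % window_size         # 2
x_pad = tf.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
pad_height, pad_width = height + pad_h, width + pad_w              # 6, 9

x6d = tf.reshape(x_pad, [batch_size, pad_height // window_size, window_size,
                         pad_width // window_size, window_size, channel])
windows = tf.reshape(tf.transpose(x6d, perm=[0, 1, 3, 2, 4, 5]),
                     [-1, window_size, window_size, channel])
print(windows.shape)                                               # (6, 3, 3, 4):共 2x3 个窗口

# 反划分:逆向 reshape/transpose,再裁剪掉填充部分
back = tf.reshape(windows, [batch_size, pad_height // window_size, pad_width // window_size,
                            window_size, window_size, channel])
back = tf.reshape(tf.transpose(back, perm=[0, 1, 3, 2, 4, 5]),
                  [batch_size, pad_height, pad_width, channel])
print(bool(tf.reduce_all(back[:, :height, :width, :] == x)))       # True
```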

    # 定义调用方法 call,接受 hidden_states、output_attentions 和 training 作为输入
    def call(
        self,
        hidden_states: tf.Tensor,
        output_attentions: Optional[bool] = False,
        training: Optional[bool] = False,
    ) -> Tuple[tf.Tensor]:
        # 保留原始隐藏状态作为残差连接的基础
        residual = hidden_states

        # 应用 Layer Normalization 到隐藏状态
        hidden_states = self.layer_norm1(hidden_states)
        
        # 如果窗口大小大于0,则进行窗口划分
        if self.window_size > 0:
            # 获取隐藏状态的高度和宽度
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            # 对隐藏状态进行窗口划分,同时获取填充形状
            hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)

        # 进行自注意力机制操作,并获取注意力权重
        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
            output_attentions=output_attentions,
            training=training,
        )
        
        # 如果窗口大小大于0,则进行窗口反划分
        if self.window_size > 0:
            hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))

        # 添加残差连接到经过注意力操作后的隐藏状态
        hidden_states = residual + hidden_states
        
        # 应用 Layer Normalization 到加和后的隐藏状态
        layernorm_output = self.layer_norm2(hidden_states)
        
        # 应用 MLP(多层感知机)到 Layer Normalization 后的输出
        hidden_states = hidden_states + self.mlp(layernorm_output)

        # 准备输出,包含最终的隐藏状态
        outputs = (hidden_states,)
        
        # 如果需要输出注意力权重,则将它们加入到输出元组中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回
        if self.built:
            return
        # 标记为已构建
        self.built = True
        
        # 构建 layer_norm1 层
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, None, self.config.hidden_size])
        
        # 构建 attention 层
        if getattr(self, "attn", None) is not None:
            with tf.name_scope(self.attn.name):
                self.attn.build(None)
        
        # 构建 layer_norm2 层
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, None, self.config.hidden_size])
        
        # 构建 MLP 层
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
# 定义自定义的视觉编码器层,继承自 Keras 的 Layer 类
class TFSamVisionEncoder(keras.layers.Layer):
    # 初始化方法,接收配置对象和其他关键字参数
    def __init__(self, config: SamVisionConfig, **kwargs):
        super().__init__(**kwargs)
        # 将配置对象保存到实例变量中
        self.config = config
        # 设置图像大小属性
        self.image_size = config.image_size

        # 创建图像块嵌入层对象,并命名为 "patch_embed"
        self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")

        # 初始化位置嵌入变量,暂时设为 None
        self.pos_embed = None

        # 初始化层列表,用于保存多个视觉层对象
        self.layers = []
        # 循环创建指定数量的视觉层对象
        for i in range(config.num_hidden_layers):
            # 创建单个视觉层对象,名称中包括层索引 i
            layer = TFSamVisionLayer(
                config,
                window_size=config.window_size if i not in config.global_attn_indexes else 0,
                name=f"layers_._{i}",
            )
            # 将创建的视觉层对象添加到层列表中
            self.layers.append(layer)

        # 创建视觉颈部对象,并命名为 "neck"
        self.neck = TFSamVisionNeck(config, name="neck")
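
上面只有不在 `global_attn_indexes` 中的层才使用窗口注意力(window_size > 0),其余层做全局注意力。按 `SamVisionConfig` 的默认值(12 层、window_size=14、global_attn_indexes=[2, 5, 8, 11],实际以所用配置为准),各层的窗口大小大致如下:

```python
from transformers import SamVisionConfig

config = SamVisionConfig()  # 默认配置,仅作演示,实际数值以使用的检查点为准
window_sizes = [
    config.window_size if i not in config.global_attn_indexes else 0
    for i in range(config.num_hidden_layers)
]
print(window_sizes)  # 预期为 [14, 14, 0, 14, 14, 0, 14, 14, 0, 14, 14, 0]
```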
    # 构建模型,初始化模型的权重和结构
    def build(self, input_shape=None):
        # 如果模型已经构建过,则直接返回
        if self.built:
            return
        # 标记模型为已构建状态
        self.built = True
        
        # 如果配置要求使用绝对位置编码
        if self.config.use_abs_pos:
            # 初始化绝对位置嵌入,其形状与预训练图像大小相关
            self.pos_embed = self.add_weight(
                shape=[
                    1,
                    self.config.image_size // self.config.patch_size,
                    self.config.image_size // self.config.patch_size,
                    self.config.hidden_size,
                ],
                initializer="zeros",
                trainable=True,
                name="pos_embed",
            )

        # 如果已定义 patch_embed 属性,则构建其内部结构
        if getattr(self, "patch_embed", None) is not None:
            with tf.name_scope(self.patch_embed.name):
                self.patch_embed.build(None)
        
        # 如果已定义 neck 属性,则构建其内部结构
        if getattr(self, "neck", None) is not None:
            with tf.name_scope(self.neck.name):
                self.neck.build(None)
        
        # 遍历模型的所有层,并构建每一层的结构
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)
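
以 `SamVisionConfig` 的默认值(image_size=1024、patch_size=16、hidden_size=768,仅按默认配置推断)来看,上面 build 中创建的绝对位置嵌入形状应为 (1, 64, 64, 768),与 patch 网格一一对应:

```python
from transformers import SamVisionConfig

config = SamVisionConfig()
grid = config.image_size // config.patch_size
print(grid)                                 # 默认情况下为 64
print((1, grid, grid, config.hidden_size))  # (1, 64, 64, 768)
```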

    # 返回输入嵌入
    def get_input_embeddings(self):
        return self.patch_embed

    # 模型调用函数,用于执行前向传播
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,

        # 函数参数说明:
        # pixel_values: 输入的像素张量,可以为 None
        # output_attentions: 是否输出注意力信息,可选布尔值
        # output_hidden_states: 是否输出隐藏状态信息,可选布尔值
        # return_dict: 是否返回字典格式的结果,可选布尔值
        # training: 是否处于训练模式,可选布尔值,默认为 False
        ) -> Union[Tuple, TFSamVisionEncoderOutput]:
        # 确定是否输出注意力权重
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 确定是否输出隐藏层状态
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 确定是否使用返回字典形式
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果像素值为 None,则抛出数值错误异常
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 使用 patch_embed 方法处理像素值,得到隐藏状态
        hidden_states = self.patch_embed(pixel_values)
        # 如果存在位置编码,则加上位置编码
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        # 如果输出隐藏状态为真,则初始化存储所有隐藏状态的元组
        all_hidden_states = () if output_hidden_states else None
        # 如果输出注意力权重为真,则初始化存储所有注意力权重的元组
        all_self_attentions = () if output_attentions else None

        # 遍历所有层并逐层处理
        for i, layer_module in enumerate(self.layers):
            # 如果输出隐藏状态为真,则将当前隐藏状态添加到所有隐藏状态中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 调用当前层的处理方法,获取当前层的输出
            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)

            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]

            # 如果输出注意力权重为真,则将当前层的注意力权重添加到所有注意力权重中
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # 如果输出隐藏状态为真,则将最终隐藏状态添加到所有隐藏状态中
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # 使用 neck 方法处理最终的隐藏状态
        hidden_states = self.neck(hidden_states)

        # 如果不使用返回字典形式,则按顺序返回隐藏状态、所有隐藏状态、所有注意力权重
        if not return_dict:
            outputs = (hidden_states,)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

        # 如果使用返回字典形式,则返回 TFSamVisionEncoderOutput 类的实例
        return TFSamVisionEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
@add_start_docstrings(
    "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
    " optional 2D location and bounding boxes.",
    SAM_START_DOCSTRING,
)
class TFSamModel(TFSamPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # 初始化共享的图像嵌入层,使用配置中的视觉配置
        self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")

        # 初始化视觉编码器,使用配置中的视觉配置
        self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
        
        # 初始化提示编码器,使用配置中的提示编码器配置和共享的图像嵌入层
        self.prompt_encoder = TFSamPromptEncoder(
            config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
        )
        
        # 初始化掩码解码器,使用配置中的掩码解码器配置
        self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
        
        # 保存配置以供后续调用使用
        self.config = config

    def get_input_embeddings(self):
        # 获取视觉编码器的输入嵌入
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self):
        # 获取图像广域位置嵌入

        # 图像嵌入尺寸
        size = self.config.prompt_encoder_config.image_embedding_size
        
        # 创建尺寸为(size, size)的全1张量
        grid = tf.ones((size, size))
        
        # 沿着垂直方向累积求和,并进行中心化处理
        y_embed = tf.math.cumsum(grid, axis=0) - 0.5
        
        # 沿着水平方向累积求和,并进行中心化处理
        x_embed = tf.math.cumsum(grid, axis=1) - 0.5
        
        # 将嵌入位置坐标归一化
        y_embed = y_embed / size
        x_embed = x_embed / size
        
        # 使用共享的图像嵌入层获取位置嵌入张量
        positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
        
        # 调整维度顺序为 channel x height x width,并扩展维度为 (1, channel, height, width)
        return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0)
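
这里的 cumsum 减 0.5 再除以 size,得到的是每个网格单元中心点的归一化坐标(取值落在 0 到 1 之间)。一个 size=4 的小例子(真实模型中 `image_embedding_size` 通常更大,例如 64):

```python
import tensorflow as tf

size = 4
grid = tf.ones((size, size))
y_embed = (tf.math.cumsum(grid, axis=0) - 0.5) / size
x_embed = (tf.math.cumsum(grid, axis=1) - 0.5) / size
print(y_embed.numpy()[:, 0])  # [0.125 0.375 0.625 0.875]:各行中心的归一化 y 坐标
print(x_embed.numpy()[0, :])  # [0.125 0.375 0.625 0.875]:各列中心的归一化 x 坐标
```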

    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.TFModelOutput`] instead of a plain tuple.

        """
        # 使用视觉编码器处理像素值,返回图像嵌入向量
        vision_output = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 提取图像嵌入向量
        image_embeddings = vision_output[0]
        return image_embeddings
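
一个假设性的调用示意(随机初始化模型,只为说明输入输出形状;形状注释按默认配置推断,实际使用通常通过 `from_pretrained` 加载预训练权重):

```python
import tensorflow as tf
from transformers import SamConfig, TFSamModel

model = TFSamModel(SamConfig())              # 随机初始化,仅作演示;实际可用 TFSamModel.from_pretrained(...)
pixel_values = tf.zeros((1, 3, 1024, 1024))  # 处理器填充后的标准输入尺寸 (batch, channels, H, W)
image_embeddings = model.get_image_embeddings(pixel_values)
print(image_embeddings.shape)                # 按默认配置预期为 (1, 256, 64, 64)
```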

    def get_prompt_embeddings(
        self,
        input_points: tf.Tensor | None = None,
        input_labels: tf.Tensor | None = None,
        input_boxes: tf.Tensor | None = None,
        input_masks: tf.Tensor | None = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

        Args:
            input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the point is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. users can also pass manually the input boxes.
            input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        # 使用提示编码器处理输入的点、标签、框和掩码,返回提示嵌入向量
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @unpack_inputs
    @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
    # 定义一个方法 `call`,用于执行模型推断或训练过程中的前向传播
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        input_points: tf.Tensor | None = None,
        input_labels: tf.Tensor | None = None,
        input_boxes: tf.Tensor | None = None,
        input_masks: tf.Tensor | None = None,
        image_embeddings: tf.Tensor | None = None,
        multimask_output: bool = True,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ):
        # call 方法整合视觉编码器、提示编码器与掩码解码器,输出预测掩码与 IoU 评分
        ...

    # 定义 serving_output 方法,用于在服务化/导出时按配置整理模型输出
    def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
        # 如果模型配置要求输出隐藏状态,则将视觉隐藏状态转换为 TensorFlow 张量,否则设为 None
        hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
        # 如果模型配置要求输出注意力权重,则将视觉注意力转换为 TensorFlow 张量,否则设为 None
        attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None

        # 返回一个 `TFSamImageSegmentationOutput` 对象,根据模型配置决定是否包含隐藏状态和注意力权重
        return TFSamImageSegmentationOutput(
            iou_scores=output.iou_scores,
            pred_masks=output.pred_masks,
            vision_hidden_states=hs if self.config.output_hidden_states else None,
            vision_attentions=attns if self.config.output_attentions else None,
            mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
        )

    # 定义一个方法 `build`,用于构建模型
    def build(self, input_shape=None):
        # 如果模型已经构建完成,则直接返回
        if self.built:
            return
        # 将模型状态标记为已构建
        self.built = True
        # 如果模型具有 `shared_image_embedding` 属性,则构建共享图像嵌入
        if getattr(self, "shared_image_embedding", None) is not None:
            with tf.name_scope(self.shared_image_embedding.name):
                self.shared_image_embedding.build(None)
        # 如果模型具有 `vision_encoder` 属性,则构建视觉编码器
        if getattr(self, "vision_encoder", None) is not None:
            with tf.name_scope(self.vision_encoder.name):
                self.vision_encoder.build(None)
        # 如果模型具有 `prompt_encoder` 属性,则构建提示编码器
        if getattr(self, "prompt_encoder", None) is not None:
            with tf.name_scope(self.prompt_encoder.name):
                self.prompt_encoder.build(None)
        # 如果模型具有 `mask_decoder` 属性,则构建掩码解码器
        if getattr(self, "mask_decoder", None) is not None:
            with tf.name_scope(self.mask_decoder.name):
                self.mask_decoder.build(None)

.\models\sam\processing_sam.py

# coding=utf-8
# 设置文件编码为 UTF-8,确保支持多语言字符

# 版权声明和许可证信息
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processor class for SAM.
"""
# 导入深拷贝函数和类型提示模块
from copy import deepcopy
from typing import Optional, Union

# 导入 NumPy 库
import numpy as np

# 导入处理工具混合类、批编码类和工具函数
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available

# 如果可用,导入 PyTorch 库
if is_torch_available():
    import torch

# 如果可用,导入 TensorFlow 库
if is_tf_available():
    import tensorflow as tf

# SAM 处理器类,继承自处理器混合类
class SamProcessor(ProcessorMixin):
    r"""
    Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
    single processor.

    [`SamProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of
    [`~SamImageProcessor.__call__`] for more information.

    Args:
        image_processor (`SamImageProcessor`):
            An instance of [`SamImageProcessor`]. The image processor is a required input.
    """

    # 类属性:包含的属性列表
    attributes = ["image_processor"]
    # 类属性:图像处理器类名
    image_processor_class = "SamImageProcessor"

    # 初始化方法,接受图像处理器实例作为参数
    def __init__(self, image_processor):
        # 调用父类的初始化方法
        super().__init__(image_processor)
        # 设置当前处理器为图像处理器
        self.current_processor = self.image_processor
        # 设置点填充值为 -10
        self.point_pad_value = -10
        # 设置目标尺寸为图像处理器定义的最长边尺寸
        self.target_size = self.image_processor.size["longest_edge"]

    # 对象调用方法,用于处理输入数据
    def __call__(
        self,
        images=None,
        segmentation_maps=None,
        input_points=None,
        input_labels=None,
        input_boxes=None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        ...

    def _pad_points_and_labels(self, input_points, input_labels):
        """
        The method pads the 2D points and labels to the maximum number of points in the batch.
        """
        # Determine the maximum number of points among all input batches
        expected_nb_points = max([point.shape[0] for point in input_points])
        processed_input_points = []
        # Iterate over each batch of points and labels
        for i, point in enumerate(input_points):
            # If the number of points in the current batch is less than the maximum,
            # pad the points with `self.point_pad_value` up to the maximum number of points
            if point.shape[0] != expected_nb_points:
                point = np.concatenate(
                    [point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
                )
                # Append padding values to the corresponding labels array
                input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
            processed_input_points.append(point)
        input_points = processed_input_points
        return input_points, input_labels
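
下面用独立的 NumPy 小例子(数据为假设值)复现这一填充逻辑:点数不足的样本会补上 `point_pad_value`(-10),对应的标签也同步补 -10,使一个 batch 内各图的点数一致:

```python
import numpy as np

point_pad_value = -10
input_points = [np.array([[10.0, 20.0], [30.0, 40.0]]), np.array([[5.0, 6.0]])]
input_labels = [np.array([1, 0]), np.array([1])]

expected_nb_points = max(p.shape[0] for p in input_points)  # 2
for i, point in enumerate(input_points):
    if point.shape[0] != expected_nb_points:
        pad = np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value
        input_points[i] = np.concatenate([point, pad], axis=0)
        input_labels[i] = np.append(input_labels[i], [point_pad_value])

print(input_points[1])  # [[  5.   6.] [-10. -10.]]
print(input_labels[1])  # [  1 -10]
```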

    def _normalize_coordinates(
        self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False
    ) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H, W) format.
        """
        # 从参数 original_size 中获取原始图像的高度和宽度
        old_h, old_w = original_size
        # 使用 self.image_processor._get_preprocess_shape 方法获取预处理后的图像尺寸
        new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size)
        # 深拷贝坐标数组,并将其转换为浮点数类型
        coords = deepcopy(coords).astype(float)

        # 如果 is_bounding_box 为 True,则将坐标数组重新整形为 (N, 2, 2) 的形式
        if is_bounding_box:
            coords = coords.reshape(-1, 2, 2)

        # 将坐标数组中 x 轴坐标缩放至新的宽度比例
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        # 将坐标数组中 y 轴坐标缩放至新的高度比例
        coords[..., 1] = coords[..., 1] * (new_h / old_h)

        # 如果 is_bounding_box 为 True,则将坐标数组重新整形为 (N, 4) 的形式
        if is_bounding_box:
            coords = coords.reshape(-1, 4)

        # 返回处理后的坐标数组
        return coords
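
举一个手工推算的小例子(数值为假设):原图为 600x800(H, W),target_size=1024 时最长边被缩放到 1024,预处理后的尺寸约为 768x1024,于是点 (400, 300) 会被映射到 (512, 384):

```python
import numpy as np

old_h, old_w = 600, 800
target_size = 1024
scale = target_size / max(old_h, old_w)                            # 1.28,按最长边缩放
new_h, new_w = int(old_h * scale + 0.5), int(old_w * scale + 0.5)  # 768, 1024

coords = np.array([[400.0, 300.0]])  # (x, y)
coords[..., 0] *= new_w / old_w
coords[..., 1] *= new_h / old_h
print((new_h, new_w), coords)        # (768, 1024) [[512. 384.]]
```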

    def _check_and_preprocess_points(
        self,
        input_points=None,
        input_labels=None,
        input_boxes=None,
    ):
        r"""
        Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they
        are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`,
        it is converted to a `numpy.ndarray` and then to a `list`.
        """
        # 如果 input_points 不为 None,则进行下列处理
        if input_points is not None:
            # 如果 input_points 具有 "numpy" 属性,说明是 Torch 或 TF 张量,将其转换为 numpy.ndarray 再转换为 list
            if hasattr(input_points, "numpy"):  # Checks for TF or Torch tensor
                input_points = input_points.numpy().tolist()

            # 检查 input_points 是否为有效格式,必须是浮点数列表的列表
            if not isinstance(input_points, list) or not isinstance(input_points[0], list):
                raise ValueError("Input points must be a list of list of floating points.")
            # 将每个 input_points 转换为 numpy 数组
            input_points = [np.array(input_point) for input_point in input_points]
        else:
            input_points = None

        # 如果 input_labels 不为 None,则进行下列处理
        if input_labels is not None:
            # 如果 input_labels 具有 "numpy" 属性,将其转换为 list
            if hasattr(input_labels, "numpy"):
                input_labels = input_labels.numpy().tolist()

            # 检查 input_labels 是否为有效格式,必须是整数列表的列表
            if not isinstance(input_labels, list) or not isinstance(input_labels[0], list):
                raise ValueError("Input labels must be a list of list integers.")
            # 将每个 input_labels 转换为 numpy 数组
            input_labels = [np.array(label) for label in input_labels]
        else:
            input_labels = None

        # 如果 input_boxes 不为 None,则进行下列处理
        if input_boxes is not None:
            # 如果 input_boxes 具有 "numpy" 属性,将其转换为 list
            if hasattr(input_boxes, "numpy"):
                input_boxes = input_boxes.numpy().tolist()

            # 检查 input_boxes 是否为有效格式,必须是浮点数列表的列表的列表
            if (
                not isinstance(input_boxes, list)
                or not isinstance(input_boxes[0], list)
                or not isinstance(input_boxes[0][0], list)
            ):
                raise ValueError("Input boxes must be a list of list of list of floating points.")
            # 将每个 input_boxes 转换为 numpy 数组,并指定数据类型为 np.float32
            input_boxes = [np.array(box).astype(np.float32) for box in input_boxes]
        else:
            input_boxes = None

        # 返回处理后的 input_points, input_labels, input_boxes
        return input_points, input_labels, input_boxes

    @property
    def model_input_names(self):
        # 获取 self.image_processor 中的 model_input_names 属性,并去重后返回为列表
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(image_processor_input_names))

    def post_process_masks(self, *args, **kwargs):
        # 调用 self.image_processor 中的 post_process_masks 方法,将参数传递并返回其结果
        return self.image_processor.post_process_masks(*args, **kwargs)
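
综合起来,SamProcessor 与 TFSamModel 的一个假设性端到端用法大致如下(检查点名称、输入图片与点坐标均为示例值;输出形状按常见默认配置推断,实际以具体模型为准):

```python
import numpy as np
from PIL import Image
from transformers import SamProcessor, TFSamModel

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = TFSamModel.from_pretrained("facebook/sam-vit-base")

image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # 用全黑图代替真实图片
input_points = [[[320, 240]]]                                     # 1 张图、1 组提示、1 个点

inputs = processor(images=image, input_points=input_points, return_tensors="tf")
outputs = model(**inputs)

masks = processor.post_process_masks(
    outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
)
print(outputs.iou_scores.shape)  # 预期 (1, 1, 3):每个提示点默认输出 3 个候选掩码
print(masks[0].shape)            # 预期 (1, 3, 480, 640):掩码已还原到原图分辨率
```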

.\models\sam\__init__.py

# 引入类型检查模块
from typing import TYPE_CHECKING

# 引入所需的依赖和模块
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_torch_available,
    is_vision_available,
)

# 定义模块的导入结构
_import_structure = {
    "configuration_sam": [
        "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "SamConfig",
        "SamMaskDecoderConfig",
        "SamPromptEncoderConfig",
        "SamVisionConfig",
    ],
    "processing_sam": ["SamProcessor"],
}

# 检查是否可以导入 torch,若不可用则引发异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加相关模块到导入结构
    _import_structure["modeling_sam"] = [
        "SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SamModel",
        "SamPreTrainedModel",
    ]

# 检查是否可以导入 tensorflow,若不可用则引发异常
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加相关模块到导入结构
    _import_structure["modeling_tf_sam"] = [
        "TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFSamModel",
        "TFSamPreTrainedModel",
    ]

# 检查是否可以导入视觉处理模块,若不可用则引发异常
try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果可用,则添加相关模块到导入结构
    _import_structure["image_processing_sam"] = ["SamImageProcessor"]

# 如果是类型检查阶段
if TYPE_CHECKING:
    # 导入配置相关的类
    from .configuration_sam import (
        SAM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SamConfig,
        SamMaskDecoderConfig,
        SamPromptEncoderConfig,
        SamVisionConfig,
    )
    # 导入处理相关的类
    from .processing_sam import SamProcessor

    # 检查是否可以导入 torch,若不可用则跳过
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 torch 模型相关的类
        from .modeling_sam import SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, SamPreTrainedModel

    # 检查是否可以导入 tensorflow,若不可用则跳过
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入 tensorflow 模型相关的类
        from .modeling_tf_sam import TF_SAM_PRETRAINED_MODEL_ARCHIVE_LIST, TFSamModel, TFSamPreTrainedModel

    # 检查是否可以导入视觉处理模块,若不可用则跳过
    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 导入视觉处理模块相关的类
        from .image_processing_sam import SamImageProcessor

# 如果不是类型检查阶段,则进行懒加载模块的设置
else:
    import sys

    # 将当前模块设为懒加载模块
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
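
这种懒加载意味着:只有在真正访问对应属性时,上述子模块才会被导入;对使用者来说与普通导入无异。一个假设性的小例子(需要已安装相应依赖):

```python
import transformers

# 访问属性时才触发 transformers.models.sam 下各子模块的真正导入
print(transformers.SamConfig().model_type)               # "sam"
print(transformers.SamImageProcessor.model_input_names)  # ['pixel_values'](需安装视觉相关依赖)
```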

.\models\seamless_m4t\configuration_seamless_m4t.py

# 设置文件编码为 UTF-8
# 版权声明,2023 年由 HuggingFace Inc. 团队保留所有权利
#
# 根据 Apache 许可证 2.0 版本授权使用该文件
# 除非符合许可证规定,否则不得使用此文件
# 您可以在以下网址获取许可证的副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按"原样"分发本软件
# 不提供任何明示或暗示的担保或条件
# 请查看许可证以获取更多详细信息
""" SeamlessM4T 模型配置"""

# 从 transformers 库导入预训练配置类 PretrainedConfig
# 从 utils 模块导入 logging 函数
from ...configuration_utils import PretrainedConfig
from ...utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 映射预训练配置文件的 URL 地址
SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/hf-seamless-m4t-medium": "https://huggingface.co/facebook/hf-seamless-m4t-medium/resolve/main/config.json",
    # 查看所有 SeamlessM4T 模型的地址:https://huggingface.co/models?filter=seamless_m4t
}

# SeamlessM4TConfig 类,继承自 PretrainedConfig 类
class SeamlessM4TConfig(PretrainedConfig):
    r"""
    这是一个配置类,用于存储 [`~SeamlessM4TModel`] 的配置。根据指定的参数实例化一个 SeamlessM4T 模型配置,
    定义模型的架构。使用默认配置实例化将产生类似于 SeamlessM4T
    ["facebook/hf-seamless-m4t-medium"](https://huggingface.co/"facebook/hf-seamless-m4t-medium") 架构的配置。

    配置对象继承自 [`PretrainedConfig`],可用于控制模型的输出。阅读 [`PretrainedConfig`] 的文档获取更多信息。

    ```
    >>> from transformers import SeamlessM4TModel, SeamlessM4TConfig

    >>> # 初始化 SeamlessM4T "facebook/hf-seamless-m4t-medium" 风格的配置
    >>> configuration = SeamlessM4TConfig()

    >>> # 根据 "facebook/hf-seamless-m4t-medium" 风格的配置初始化模型
    >>> model = SeamlessM4TModel(configuration)

    >>> # 访问模型配置
    >>> configuration = model.config
    ```
    """

    # 模型类型设为 "seamless_m4t"
    model_type = "seamless_m4t"
    # 初始化函数,用于创建一个新的对象实例
    def __init__(
        self,
        vocab_size=256102,  # 词汇表大小,默认为 256102
        t2u_vocab_size=10082,  # t2u 词汇表大小,默认为 10082

        # 共享配置
        hidden_size=1024,  # 隐藏层大小,默认为 1024
        initializer_range=0.02,  # 初始化范围,默认为 0.02
        layer_norm_eps=1e-5,  # 层归一化的 epsilon,默认为 1e-5
        use_cache=True,  # 是否使用缓存,默认为 True
        max_position_embeddings=1024,  # 最大位置嵌入数,默认为 1024
        is_encoder_decoder=True,  # 是否为编码器-解码器模型,默认为 True
        encoder_layerdrop=0.05,  # 编码器层的层丢弃率,默认为 0.05
        decoder_layerdrop=0.05,  # 解码器层的层丢弃率,默认为 0.05
        activation_function="relu",  # 激活函数,默认为 "relu"
        dropout=0.1,  # 普通的 dropout 率,默认为 0.1
        attention_dropout=0.1,  # 注意力 dropout 率,默认为 0.1
        activation_dropout=0.0,  # 激活函数的 dropout 率,默认为 0.0
        scale_embedding=True,  # 是否缩放嵌入,默认为 True

        # 文本编码器|解码器配置
        encoder_layers=24,  # 编码器层数,默认为 24
        encoder_ffn_dim=8192,  # 编码器 FFN 维度,默认为 8192
        encoder_attention_heads=16,  # 编码器注意力头数,默认为 16
        decoder_layers=24,  # 解码器层数,默认为 24
        decoder_ffn_dim=8192,  # 解码器 FFN 维度,默认为 8192
        decoder_attention_heads=16,  # 解码器注意力头数,默认为 16
        decoder_start_token_id=3,  # 解码器起始标记 ID,默认为 3
        max_new_tokens=256,  # 最大新标记数,默认为 256
        pad_token_id=0,  # 填充标记 ID,默认为 0
        bos_token_id=2,  # 开始标记 ID,默认为 2
        eos_token_id=3,  # 结束标记 ID,默认为 3

        # 语音编码器配置
        speech_encoder_layers=24,  # 语音编码器层数,默认为 24
        speech_encoder_attention_heads=16,  # 语音编码器注意力头数,默认为 16
        speech_encoder_intermediate_size=4096,  # 语音编码器中间层大小,默认为 4096
        speech_encoder_hidden_act="swish",  # 语音编码器隐藏层激活函数,默认为 "swish"
        speech_encoder_dropout=0.0,  # 语音编码器 dropout 率,默认为 0.0
        add_adapter=True,  # 是否添加适配器,默认为 True
        speech_encoder_layerdrop=0.1,  # 语音编码器层的层丢弃率,默认为 0.1
        feature_projection_input_dim=160,  # 特征投影输入维度,默认为 160
        num_conv_pos_embeddings=128,  # 卷积位置嵌入数,默认为 128
        num_conv_pos_embedding_groups=16,  # 卷积位置嵌入分组数,默认为 16
        adaptor_kernel_size=8,  # 适配器卷积核大小,默认为 8
        adaptor_stride=8,  # 适配器卷积步长,默认为 8
        adaptor_dropout=0.1,  # 适配器 dropout 率,默认为 0.1
        num_adapter_layers=1,  # 适配器层数,默认为 1
        position_embeddings_type="relative",  # 位置嵌入类型,默认为 "relative"
        rotary_embedding_base=10000,  # 旋转嵌入基数,默认为 10000
        max_source_positions=4096,  # 最大源位置数,默认为 4096
        conv_depthwise_kernel_size=31,  # 深度卷积核大小,默认为 31

        # t2u 配置
        t2u_bos_token_id=0,  # t2u 开始标记 ID,默认为 0
        t2u_pad_token_id=1,  # t2u 填充标记 ID,默认为 1
        t2u_eos_token_id=2,  # t2u 结束标记 ID,默认为 2
        t2u_decoder_start_token_id=2,  # t2u 解码器起始标记 ID,默认为 2
        t2u_max_new_tokens=1024,  # t2u 最大新标记数,默认为 1024
        t2u_encoder_layers=6,  # t2u 编码器层数,默认为 6
        t2u_encoder_ffn_dim=8192,  # t2u 编码器 FFN 维度,默认为 8192
        t2u_encoder_attention_heads=16,  # t2u 编码器注意力头数,默认为 16
        t2u_decoder_layers=6,  # t2u 解码器层数,默认为 6
        t2u_decoder_ffn_dim=8192,  # t2u 解码器 FFN 维度,默认为 8192
        t2u_decoder_attention_heads=16,  # t2u 解码器注意力头数,默认为 16
        t2u_max_position_embeddings=2048,  # t2u 最大位置嵌入数,默认为 2048

        # hifi-gan 语音合成器配置
        sampling_rate=16000,  # 采样率,默认为 16000
        upsample_initial_channel=512,  # 上采样初始通道数,默认为 512
        upsample_rates=[5, 4, 4, 2, 2],  # 上采样倍率列表,默认为 [5, 4, 4, 2, 2]
        upsample_kernel_sizes=[11, 8, 8, 4, 4],  # 上采样卷积核大小列表,默认为 [11, 8, 8, 4, 4]
        resblock_kernel_sizes=[3, 7, 11],  # ResBlock 卷积核大小列表,默认为 [3, 7, 11]
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],  # ResBlock 扩张大小列表,默认为 [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
        leaky_relu_slope=0.1,  # Leaky ReLU 斜率,默认为 0.1

        # 特定于 Code Hifi-Gan 的配置
        unit_hifi_gan_vocab_size=10000,  # Hifi-Gan 单元词汇表大小,默认为 10000
        unit_embed_dim=1280,  # 单元嵌入维度,默认为 1280
        lang_embed_dim=256,  # 语言嵌入维度,默认为 256
        spkr_embed_dim=256,  # 说话人嵌入维度,默认为 256
        vocoder_num_langs=36,  # 语音合成器支持的语言数,默认为 36
        vocoder_num_spkrs=200,  # 语音合成器支持的说话人数,默认为 200
        variance_predictor_kernel_size=3,  # 方差预测器卷积核大小,默认为 3
        var_pred_dropout=0.5,  # 方差预测器 dropout 率,默认为 0.5
        vocoder_offset=4,  # 语音合成器偏移量,默认为 4
        **kwargs,  # 其他参数,使用字典方式接收