Transformers Source Code Walkthrough (Part 60)

.\models\instructblip\processing_instructblip.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for InstructBLIP. Largely a copy of Blip2Processor, with an additional tokenizer for the Q-Former.
"""

import os
from typing import List, Optional, Union

# Importing necessary modules from the package
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ..auto import AutoTokenizer


class InstructBlipProcessor(ProcessorMixin):
    r"""
    Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
    processor.

    [`InstructBlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the
    docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.

    Args:
        image_processor (`BlipImageProcessor`):
            An instance of [`BlipImageProcessor`]. The image processor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
        qformer_tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
    """

    # Define class attributes
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "BlipImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, qformer_tokenizer):
        # Call the constructor of the superclass (ProcessorMixin)
        super().__init__(image_processor, tokenizer)

        # Initialize QFormer tokenizer attribute
        self.qformer_tokenizer = qformer_tokenizer
    # __call__ lets the processor instance be used like a function
    def __call__(
        self,
        # images: a single image or a list of images
        images: ImageInput = None,
        # text: a single text, pre-tokenized input, or a list of either
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        # whether to add special tokens
        add_special_tokens: bool = True,
        # padding behaviour: bool, string, or a PaddingStrategy
        padding: Union[bool, str, PaddingStrategy] = False,
        # truncation behaviour: bool, string, or a TruncationStrategy
        truncation: Union[bool, str, TruncationStrategy] = None,
        # maximum length of the returned sequences
        max_length: Optional[int] = None,
        # stride used when overflowing tokens are returned, defaults to 0
        stride: int = 0,
        # pad the sequence length up to a multiple of this value
        pad_to_multiple_of: Optional[int] = None,
        # whether to return the attention mask
        return_attention_mask: Optional[bool] = None,
        # whether to return overflowing tokens
        return_overflowing_tokens: bool = False,
        # whether to return the special tokens mask
        return_special_tokens_mask: bool = False,
        # whether to return the character offsets mapping
        return_offsets_mapping: bool = False,
        # whether to return token type ids
        return_token_type_ids: bool = False,
        # whether to return the sequence lengths
        return_length: bool = False,
        # whether to print verbose information, defaults to True
        verbose: bool = True,
        # type of tensors to return (string or TensorType)
        return_tensors: Optional[Union[str, TensorType]] = None,
        # any remaining keyword arguments
        **kwargs,
    ) -> BatchFeature:
        """
        使用 [`BlipImageProcessor.__call__`] 方法准备模型的图像数据,
        和 [`BertTokenizerFast.__call__`] 方法准备模型的文本数据。

        更多信息请参考上述两个方法的文档字符串。
        """
        # Raise an error if neither images nor text are provided
        if images is None and text is None:
            raise ValueError("You have to specify at least images or text.")

        # Create an empty BatchFeature to collect the encoded inputs
        encoding = BatchFeature()

        # If text is provided, encode it with the main (language model) tokenizer
        if text is not None:
            text_encoding = self.tokenizer(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_token_type_ids=return_token_type_ids,
                return_length=return_length,
                verbose=verbose,
                return_tensors=return_tensors,
                **kwargs,
            )
            # Merge the text encoding into the BatchFeature
            encoding.update(text_encoding)

            # Encode the same text with the Q-Former tokenizer
            qformer_text_encoding = self.qformer_tokenizer(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_token_type_ids=return_token_type_ids,
                return_length=return_length,
                verbose=verbose,
                return_tensors=return_tensors,
                **kwargs,
            )
            # Store the Q-Former input ids under a dedicated key
            encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            # Store the Q-Former attention mask under a dedicated key
            encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

        # If images are provided, process them with the image processor
        if images is not None:
            image_encoding = self.image_processor(images, return_tensors=return_tensors)
            # Merge the image features into the BatchFeature
            encoding.update(image_encoding)

        # Return the combined encoding
        return encoding

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        此方法将所有参数转发给 PreTrainedTokenizer 的 [`~PreTrainedTokenizer.batch_decode`] 方法。
        详细信息请参考该方法的文档字符串。
        """
        # 调用 tokenizer 的 batch_decode 方法,并返回结果
        return self.tokenizer.batch_decode(*args, **kwargs)
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        # Forward everything to the tokenizer's decode and return the result
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        # Collect the model input names of the tokenizer
        tokenizer_input_names = self.tokenizer.model_input_names
        # Collect the model input names of the image processor
        image_processor_input_names = self.image_processor.model_input_names
        # Return the de-duplicated union of both lists
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # overwrite to save the Q-Former tokenizer in a separate folder
    def save_pretrained(self, save_directory, **kwargs):
        # Raise an error if save_directory points to a file instead of a directory
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        # Create save_directory if it does not exist yet
        os.makedirs(save_directory, exist_ok=True)
        # Save path for the Q-Former tokenizer
        qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer")
        # Save the Q-Former tokenizer into its own subfolder
        self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path)
        # Let the parent class save the remaining attributes to save_directory
        return super().save_pretrained(save_directory, **kwargs)

    # overwrite to load the Q-Former tokenizer from a separate folder
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Load the Q-Former tokenizer from its dedicated subfolder
        qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
        # Resolve the remaining constructor arguments from the checkpoint
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Append the loaded Q-Former tokenizer to the argument list
        args.append(qformer_tokenizer)
        # Instantiate the processor with the collected arguments and return it
        return cls(*args)
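
To see how the pieces above fit together, here is a minimal usage sketch (not part of the source file; the checkpoint name and the `image` variable are assumptions for illustration):

# Sketch: one call produces the language-model inputs, the Q-Former inputs and the
# pixel values in a single BatchFeature. `image` is assumed to be a PIL image.
from transformers import InstructBlipProcessor

processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
inputs = processor(images=image, text="What is unusual about this image?", return_tensors="pt")
print(sorted(inputs.keys()))
# expected keys: attention_mask, input_ids, pixel_values, qformer_attention_mask, qformer_input_ids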

.\models\instructblip\__init__.py

# Copyright and license information
# Copyright 2023 The HuggingFace Team. All rights reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import the TYPE_CHECKING flag
from typing import TYPE_CHECKING

# Import the optional-dependency error and the lazy-module helper
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Define the import structure of the sub-package
_import_structure = {
    "configuration_instructblip": [
        "INSTRUCTBLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "InstructBlipConfig",
        "InstructBlipQFormerConfig",
        "InstructBlipVisionConfig",
    ],
    "processing_instructblip": ["InstructBlipProcessor"],
}

# If torch is not available, raise OptionalDependencyNotAvailable and skip the modeling imports
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # torch is available, so register the modeling module as well
    _import_structure["modeling_instructblip"] = [
        "INSTRUCTBLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "InstructBlipQFormerModel",
        "InstructBlipPreTrainedModel",
        "InstructBlipForConditionalGeneration",
        "InstructBlipVisionModel",
    ]

# During type checking, perform the imports eagerly so static analyzers can see them
if TYPE_CHECKING:
    # Import the configuration classes and constants
    from .configuration_instructblip import (
        INSTRUCTBLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
        InstructBlipConfig,
        InstructBlipQFormerConfig,
        InstructBlipVisionConfig,
    )
    # Import the processor class
    from .processing_instructblip import InstructBlipProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the modeling classes and constants
        from .modeling_instructblip import (
            INSTRUCTBLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            InstructBlipForConditionalGeneration,
            InstructBlipPreTrainedModel,
            InstructBlipQFormerModel,
            InstructBlipVisionModel,
        )

# At runtime, replace this module with a lazy module that defers the actual imports
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
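
The effect of the `_LazyModule` indirection is that importing the sub-package stays cheap: the heavy `modeling_instructblip` module (and therefore torch) is only imported when one of its attributes is first accessed. A rough illustration of the behaviour, assuming `transformers` is installed:

# Sketch: attribute access on the lazy module triggers the real import.
from transformers.models import instructblip

processor_cls = instructblip.InstructBlipProcessor  # loads processing_instructblip on first access
model_cls = instructblip.InstructBlipForConditionalGeneration  # loads modeling_instructblip (requires torch)
print(processor_cls.__name__, model_cls.__name__)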

.\models\jukebox\configuration_jukebox.py

# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Jukebox configuration"""

import os
from typing import List, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Module-level logger
logger = logging.get_logger(__name__)

# Map from pretrained checkpoint names to their config files
JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "openai/jukebox-5b-lyrics": "https://huggingface.co/openai/jukebox-5b-lyrics/blob/main/config.json",
    "openai/jukebox-1b-lyrics": "https://huggingface.co/openai/jukebox-1b-lyrics/blob/main/config.json",
}

# Large attention pattern: repeated (block_attn, transpose_block_attn, prev_block_attn)
# triples, with a cross_attention layer inserted after every few triples
_LARGE_ATTENTION = [
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
]
# The three attention types cycled through by the "raw column / previous row" layout
_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
# Fully dense attention pattern
_FullDenseAttention = ["dense_attention"]
# Prime / prime / dense attention pattern
_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]


# Return the (single) fully dense attention type, regardless of the layer index
def full_dense_attention(layer):
    return _FullDenseAttention[0]


# Cycle through block / transpose-block / previous-block attention based on the layer index
def raw_column_previous_row_attention(layer):
    return _RawColumnPreviousRowAttention[layer % 3]


# Look up the layer's attention type in the large separated enc/dec (with lyrics) pattern defined above
def large_separated_enc_dec_w_lyrics(layer):
    return _LARGE_ATTENTION[layer % 79]


# Encoder/decoder-with-lyrics pattern: every 16th layer picks from the prime/prime/dense triple
def enc_dec_with_lyrics(layer):
    if layer % 16 == 15:
        return _PrimePrimeDenseAttention[layer % 3]
    return _RawColumnPreviousRowAttention[layer % 3]


# Mapping from attention-pattern names to the functions defined above
ATTENTION_PATTERNS = {
    "full_dense_attention": full_dense_attention,
    "raw_column_previous_row_attention": raw_column_previous_row_attention,  # Alternates row, column and previous-row attention
    "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics,  # Used by the large separated enc/dec model with lyrics
    "enc_dec_with_lyrics": enc_dec_with_lyrics,  # Used by the encoder/decoder model with lyrics
}
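
A small illustration (not part of the file) of how `enc_dec_with_lyrics` cycles through the patterns: the first layers alternate block / transpose-block / previous-block attention, and every 16th layer (where `layer % 16 == 15`) picks from the prime/prime/dense triple instead.

# Sketch: print the attention type chosen for the first 16 layers.
for layer in range(16):
    print(layer, ATTENTION_PATTERNS["enc_dec_with_lyrics"](layer))
# 0 block_attn, 1 transpose_block_attn, 2 prev_block_attn, ..., 14 prev_block_attn, 15 prime_attn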


# Configuration class that stores the configuration of a JukeboxPrior model
class JukeboxPriorConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
    `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the top level prior from the
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    # Model type identifier
    model_type = "jukebox_prior"
    # Map generic config attribute names to the model's actual parameter names
    attribute_map = {
        "max_position_embeddings": "n_positions",  # parameter holding the maximum number of position embeddings
        "num_attention_heads": "n_head",  # parameter holding the number of attention heads
    }
    # Constructor: create a prior configuration instance
    def __init__(
        self,
        act_fn="quick_gelu",  # activation function
        level=0,  # level of this prior
        alignment_head=2,  # attention head used for lyric/music alignment
        alignment_layer=68,  # layer used for lyric/music alignment
        attention_multiplier=0.25,  # multiplier applied to the attention width
        attention_pattern="enc_dec_with_lyrics",  # name of the attention pattern to use
        attn_dropout=0,  # dropout probability of the attention
        attn_res_scale=False,  # whether to scale the attention residuals
        blocks=64,  # number of blocks
        conv_res_scale=None,  # convolution residual scale
        num_layers=72,  # number of layers
        emb_dropout=0,  # dropout probability of the embeddings
        encoder_config=None,  # configuration of the lyric encoder
        encoder_loss_fraction=0.4,  # fraction of the loss attributed to the encoder
        hidden_size=2048,  # hidden size
        init_scale=0.2,  # initialization scale
        is_encoder_decoder=True,  # whether this prior is an encoder-decoder model
        lyric_vocab_size=80,  # size of the lyric vocabulary
        mask=False,  # whether to use a mask
        max_duration=600,  # maximum duration
        max_nb_genres=1,  # maximum number of genres
        merged_decoder=True,  # whether the decoder is merged
        metadata_conditioning=True,  # whether to condition on metadata
        metadata_dims=[604, 7898],  # dimensions of the metadata embeddings
        min_duration=0,  # minimum duration
        mlp_multiplier=1.0,  # multiplier applied to the MLP width
        music_vocab_size=2048,  # size of the music (codebook) vocabulary
        n_ctx=6144,  # context length
        n_heads=2,  # number of attention heads
        nb_relevant_lyric_tokens=384,  # number of relevant lyric tokens
        res_conv_depth=3,  # depth of the residual convolutions
        res_conv_width=128,  # width of the residual convolutions
        res_convolution_multiplier=1,  # residual convolution multiplier
        res_dilation_cycle=None,  # residual dilation cycle
        res_dilation_growth_rate=1,  # residual dilation growth rate
        res_downs_t=[3, 2, 2],  # temporal downsampling rates
        res_strides_t=[2, 2, 2],  # temporal strides
        resid_dropout=0,  # dropout probability of the residuals
        sampling_rate=44100,  # audio sampling rate
        spread=None,  # spread parameter
        timing_dims=64,  # number of timing dimensions
        zero_out=False,  # whether to zero out weights at initialization
        **kwargs,  # any remaining keyword arguments
    ):
        # Store every argument as an attribute of the configuration instance
        self.act_fn = act_fn
        self.alignment_head = alignment_head
        self.alignment_layer = alignment_layer
        self.attention_multiplier = attention_multiplier
        self.attention_pattern = attention_pattern
        self.attn_dropout = attn_dropout
        self.attn_res_scale = attn_res_scale
        self.blocks = blocks
        self.conv_res_scale = conv_res_scale
        self.num_layers = num_layers
        self.emb_dropout = emb_dropout
        self.music_vocab_size = music_vocab_size
        # If an encoder config is provided, wrap it in a JukeboxPriorConfig
        if encoder_config is not None:
            self.encoder_config = JukeboxPriorConfig(**encoder_config)
        else:
            self.encoder_config = None
        self.encoder_loss_fraction = encoder_loss_fraction
        self.init_scale = init_scale
        self.is_encoder_decoder = is_encoder_decoder
        self.lyric_vocab_size = lyric_vocab_size
        self.level = level
        self.mask = mask
        self.max_duration = max_duration
        self.max_nb_genres = max_nb_genres
        self.merged_decoder = merged_decoder
        self.metadata_conditioning = metadata_conditioning
        self.metadata_dims = metadata_dims
        self.min_duration = min_duration
        self.mlp_multiplier = mlp_multiplier
        self.n_ctx = n_ctx
        self.n_heads = n_heads
        self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
        self.res_conv_depth = res_conv_depth
        self.res_conv_width = res_conv_width
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_cycle = res_dilation_cycle
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_downs_t = res_downs_t
        self.res_strides_t = res_strides_t
        self.resid_dropout = resid_dropout
        self.sampling_rate = sampling_rate
        self.spread = spread
        self.timing_dims = timing_dims
        self.hidden_size = hidden_size
        self.zero_out = zero_out

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
    ) -> "PretrainedConfig":
        # Handle the auth token passed through kwargs
        cls._set_token_in_kwargs(kwargs)

        # Retrieve the config dict and the updated kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # If the config dict describes a full "jukebox" model, pick the prior config for the requested level
        if config_dict.get("model_type") == "jukebox":
            config_dict = config_dict[f"prior_{level}"]

        # Warn if the model type in the config dict does not match this class's model type
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # Instantiate and return the configuration from the dict
        return cls.from_dict(config_dict, **kwargs)
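
As a hedged sketch of what the `level` argument does (assuming the published openai/jukebox-1b-lyrics config stores its priors under the keys `prior_0`, `prior_1` and `prior_2`):

# Sketch: load only the top-level prior config out of the full jukebox config.
from transformers import JukeboxPriorConfig

top_prior_config = JukeboxPriorConfig.from_pretrained("openai/jukebox-1b-lyrics", level=2)
print(top_prior_config.n_ctx, top_prior_config.attention_pattern)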
# Configuration class that stores the configuration of a JukeboxVQVAE model
class JukeboxVQVAEConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a
    `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VQVAE from
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function of the model.
        nb_discrete_codes (`int`, *optional*, defaults to 2048):
            Number of discrete codes of the VQVAE.
        commit (`float`, *optional*, defaults to 0.02):
            Commit loss multiplier.
        conv_input_shape (`int`, *optional*, defaults to 1):
            Number of audio channels.
        conv_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
        embed_dim (`int`, *optional*, defaults to 64):
            Embedding dimension of the codebook vectors.
        hop_fraction (`List[float]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
            Fractions of non-intersecting window used when continuing the sampling process.
        levels (`int`, *optional*, defaults to 3):
            Number of hierarchical levels used in the VQVAE.
        lmu (`float`, *optional*, defaults to 0.99):
            Exponential moving average coefficient used for the codebook update.
        multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            Depth and width multipliers used for each level.
        res_conv_depth (`int`, *optional*, defaults to 4):
            Depth of the encoder and decoder blocks.
        res_conv_width (`int`, *optional*, defaults to 32):
            Width of the encoder and decoder blocks.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle value used in the `JukeboxResnet`.
        res_dilation_growth_rate (`int`, *optional*, defaults to 3):
            Resnet dilation growth rate used in the VQVAE.
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rates for each level of the hierarchical VQ-VAE.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Strides used for each level of the hierarchical VQ-VAE.
        sample_length (`int`, *optional*, defaults to 1058304):
            Maximum input length of the VQVAE.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scale.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """
    # Model type identifier
    model_type = "jukebox_vqvae"

    # Constructor: create a VQVAE configuration instance
    def __init__(
        self,
        act_fn="relu",  # activation function
        nb_discrete_codes=2048,  # number of discrete codes in the codebook
        commit=0.02,  # commit loss multiplier
        conv_input_shape=1,  # number of audio channels
        conv_res_scale=False,  # whether to scale the convolution residuals
        embed_dim=64,  # embedding dimension of the codebook vectors
        hop_fraction=[0.125, 0.5, 0.5],  # hop fractions used while sampling
        levels=3,  # number of hierarchical levels
        lmu=0.99,  # exponential moving average coefficient for codebook updates
        multipliers=[2, 1, 1],  # depth/width multipliers per level
        res_conv_depth=4,  # depth of the residual convolutions
        res_conv_width=32,  # width of the residual convolutions
        res_convolution_multiplier=1,  # residual convolution multiplier
        res_dilation_cycle=None,  # dilation cycle
        res_dilation_growth_rate=3,  # dilation growth rate
        res_downs_t=[3, 2, 2],  # temporal downsampling rates
        res_strides_t=[2, 2, 2],  # temporal strides
        sample_length=1058304,  # maximum input length
        init_scale=0.2,  # initialization scale
        zero_out=False,  # whether to zero out weights at initialization
        **kwargs,  # any remaining keyword arguments
    ):
        self.hop_fraction = hop_fraction
        self.conv_input_shape = conv_input_shape
        self.sample_length = sample_length

        # VQVAE parameters (all of them are used)
        self.levels = levels
        self.embed_dim = embed_dim
        self.nb_discrete_codes = nb_discrete_codes
        self.res_conv_width = res_conv_width
        self.res_conv_depth = res_conv_depth
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_dilation_cycle = res_dilation_cycle
        self.multipliers = multipliers
        self.res_downs_t = res_downs_t
        self.res_strides_t = res_strides_t
        self.lmu = lmu
        self.commit = commit
        self.conv_res_scale = conv_res_scale
        self.act_fn = act_fn
        self.init_scale = init_scale
        self.zero_out = zero_out
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)  # handle the auth token passed through kwargs

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)  # config dict and updated kwargs

        # If the config dict describes a full "jukebox" model, extract its vqvae sub-config
        if config_dict.get("model_type") == "jukebox":
            config_dict = config_dict["vqvae_config"]

        # Warn if the model type in the config dict does not match this class's model type
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # Instantiate the configuration from the dict and the remaining kwargs
        return cls.from_dict(config_dict, **kwargs)
class JukeboxConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
    yield a similar configuration to that of
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.


    The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling =
    (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256
    to get the second level codes. This is mostly true for training the top level prior and the upsamplers.

    Args:
        vqvae_config (`JukeboxVQVAEConfig`, *optional*):
            Configuration for the `JukeboxVQVAE` model.
        prior_config_list (`List[JukeboxPriorConfig]`, *optional*):
            List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
        nb_priors (`int`, *optional*, defaults to 3):
            Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive
            (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
            trained using a top prior and 2 upsampler priors.
        sampling_rate (`int`, *optional*, defaults to 44100):
            Sampling rate of the raw audio.
        timing_dims (`int`, *optional*, defaults to 64):
            Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
            layer. The timing embedding layer converts the absolute and relative position in the currently sampled
            audio to a tensor of length `timing_dims` that will be added to the music tokens.
        min_duration (`int`, *optional*, defaults to 0):
            Minimum duration of the audios to generate
        max_duration (`float`, *optional*, defaults to 600.0):
            Maximum duration of the audios to generate
        max_nb_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres that can be used to condition a single sample.
        metadata_conditioning (`bool`, *optional*, defaults to `True`):
            Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
            duration.

    Example:

    ```
    >>> from transformers import JukeboxModel, JukeboxConfig

    >>> # Initializing a Jukebox configuration
    >>> configuration = JukeboxConfig()

    >>> # Initializing a model from the configuration
    >>> model = JukeboxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Type identifier marking this as a "jukebox" configuration
    model_type = "jukebox"
    # Constructor: instantiate a JukeboxConfig object
    def __init__(
        self,
        vqvae_config=None,
        prior_config_list=None,
        nb_priors=3,
        sampling_rate=44100,
        timing_dims=64,
        min_duration=0,
        max_duration=600.0,
        max_nb_genres=5,
        metadata_conditioning=True,
        **kwargs,
    ):
        # If vqvae_config is None, fall back to an empty dict
        if vqvae_config is None:
            vqvae_config = {}
            # Log that the JukeboxVQVAE config is being initialized with default values
            logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")

        # Build a JukeboxVQVAEConfig from the provided dict and store it
        self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)

        # If a list of prior configs was provided, instantiate a JukeboxPriorConfig for each entry
        if prior_config_list is not None:
            self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list]
        else:
            # Otherwise start from an empty list
            self.prior_configs = []
            # For every prior index, try to pop its config from kwargs, defaulting to an empty dict
            for prior_idx in range(nb_priors):
                prior_config = kwargs.pop(f"prior_{prior_idx}", None)
                if prior_config is None:
                    prior_config = {}
                    # Log that this prior config is being initialized with default values
                    logger.info(
                        f"prior_{prior_idx}'s  config is None. Initializing the JukeboxPriorConfig list with default"
                        " values."
                    )
                # Build a JukeboxPriorConfig from the dict and append it to the list
                self.prior_configs.append(JukeboxPriorConfig(**prior_config))

        # Mirror the hop_fraction of the VQVAE config on the top-level config
        self.hop_fraction = self.vqvae_config.hop_fraction

        # Store the remaining metadata-related arguments as attributes
        self.nb_priors = nb_priors
        self.max_nb_genres = max_nb_genres
        self.sampling_rate = sampling_rate
        self.timing_dims = timing_dims
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.metadata_conditioning = metadata_conditioning

        # Call the parent constructor with the remaining kwargs
        super().__init__(**kwargs)

    @classmethod
    def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
        r"""
        Instantiate a [`JukeboxConfig`] (or a derived class) from clip text model configuration and clip vision model
        configuration.

        Returns:
            [`JukeboxConfig`]: An instance of a configuration object
        """
        # 将 prior_configs 列表中每个配置对象转换为字典形式,存入 prior_config_list
        prior_config_list = [config.to_dict() for config in prior_configs]
        # 调用当前类的初始化方法,传入 prior_config_list 和 vqvae_config 的字典形式,以及 kwargs 参数
        return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs)

    def to_dict(self):
        # Override the parent to_dict so the nested prior configs are serialized as well
        result = super().to_dict()
        # Serialize every prior config and store the list under the "prior_config_list" key
        result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
        return result
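
A short round-trip sketch (illustration only) of `from_configs` and the overridden `to_dict`:

# Sketch: build a composite config from sub-configs, then serialize it back to a dict.
from transformers import JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig

prior_configs = [JukeboxPriorConfig(level=i) for i in range(3)]
vqvae_config = JukeboxVQVAEConfig()
config = JukeboxConfig.from_configs(prior_configs, vqvae_config)

config_dict = config.to_dict()
print(len(config_dict["prior_config_list"]))  # 3, one dict per prior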

.\models\jukebox\convert_jukebox.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Jukebox checkpoints"""

import argparse  # command-line argument parsing
import json  # JSON serialization
import os  # filesystem helpers
from pathlib import Path  # path handling

import requests  # HTTP requests used to download the original checkpoints
import torch  # PyTorch

from transformers import JukeboxConfig, JukeboxModel  # Jukebox configuration and model classes
from transformers.utils import logging  # logging utilities


logging.set_verbosity_info()  # set the logging level to INFO
logger = logging.get_logger(__name__)  # logger for this module


PREFIX = "https://openaipublic.azureedge.net/jukebox/models/"  # URL prefix of the original Jukebox checkpoints
MODEL_MAPPING = {
    "jukebox-1b-lyrics": [
        "5b/vqvae.pth.tar",
        "5b/prior_level_0.pth.tar",
        "5b/prior_level_1.pth.tar",
        "1b_lyrics/prior_level_2.pth.tar",
    ],
    "jukebox-5b-lyrics": [
        "5b/vqvae.pth.tar",
        "5b/prior_level_0.pth.tar",
        "5b/prior_level_1.pth.tar",
        "5b_lyrics/prior_level_2.pth.tar",
    ],
}


def replace_key(key):
    if key.endswith(".model.1.bias") and len(key.split(".")) > 10:
        key = key.replace(".model.1.bias", ".conv1d_1.bias")  # first convolution bias
    elif key.endswith(".model.1.weight") and len(key.split(".")) > 10:
        key = key.replace(".model.1.weight", ".conv1d_1.weight")  # first convolution weight
    elif key.endswith(".model.3.bias") and len(key.split(".")) > 10:
        key = key.replace(".model.3.bias", ".conv1d_2.bias")  # second convolution bias
    elif key.endswith(".model.3.weight") and len(key.split(".")) > 10:
        key = key.replace(".model.3.weight", ".conv1d_2.weight")  # second convolution weight

    if "conditioner_blocks.0." in key:
        key = key.replace("conditioner_blocks.0", "conditioner_blocks")  # drop the index of the single conditioner block

    if "prime_prior" in key:
        key = key.replace("prime_prior", "encoder")  # the "prime prior" is called the encoder in the HF implementation

    if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key:
        key = key.replace(".emb.", ".")  # drop the ".emb." level from most embedding keys

    if key.endswith("k"):  # keys ending in ".k" map to the codebook
        return key.replace(".k", ".codebook")
    if "y_emb." in key:
        return key.replace("y_emb.", "metadata_embedding.")  # metadata (artist/genre/timing) embedding

    if "x_emb.emb." in key:
        key = key.replace("0.x_emb.emb", "embed_tokens")  # token embedding of the lyric encoder

    if "prime_state_ln" in key:
        return key.replace("prime_state_ln", "encoder.final_layer_norm")
    if ".ln" in key:
        return key.replace(".ln", ".layer_norm")
    if "_ln" in key:
        return key.replace("_ln", "_layer_norm")

    if "prime_state_proj" in key:
        return key.replace("prime_state_proj", "encoder.proj_in")
    if "prime_x_out" in key:
        return key.replace("prime_x_out", "encoder.lm_head")
    # If "prior.x_out" appears in the key, rename "x_out" to "fc_proj_out"
    if "prior.x_out" in key:
        return key.replace("x_out", "fc_proj_out")
    # If "x_emb" appears in the key, rename it to "embed_tokens"
    if "x_emb" in key:
        return key.replace("x_emb", "embed_tokens")

    # Otherwise return the key unchanged
    return key
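
A few illustrative remappings run through `replace_key` (the keys below are hypothetical, chosen only to trigger the branches above):

print(replace_key("priors.0.prior.x_out.weight"))       # -> priors.0.prior.fc_proj_out.weight
print(replace_key("vqvae.bottleneck.level_blocks.0.k"))  # -> vqvae.bottleneck.level_blocks.0.codebook
print(replace_key("priors.2.y_emb.artist_emb.weight"))   # -> priors.2.metadata_embedding.artist_emb.weight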
def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
    # Dictionary that will hold the remapped weights
    new_dict = {}
    # Regular expressions are used to match the original weight names
    import re

    # Encoder block: input convolution
    re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
    # Encoder block: resnet layers
    re_encoder_block_resnet = re.compile(
        r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
    )
    # Encoder block: output projection
    re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")

    # Decoder block: output convolution
    re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
    # Decoder block: resnet layers
    re_decoder_block_resnet = re.compile(
        r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
    )
    # Decoder block: input projection
    re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")

    # Prior conditioner block: output convolution
    re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)")
    # Prior conditioner block: resnet layers
    re_prior_cond_resnet = re.compile(
        r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
    )
    # Prior conditioner block: input projection
    re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)")

    # Return the (still empty) dictionary of remapped weights
    return new_dict


@torch.no_grad()
def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
    """
    Copy/paste/tweak model's weights to our Jukebox structure.
    """
    # Iterate over every checkpoint file needed for this model
    for file in MODEL_MAPPING[model_name]:
        # If a file is missing locally, download it from the OpenAI bucket and save it
        if not os.path.isfile(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"):
            r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
            os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True)
            open(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}", "wb").write(r.content)

    # Load the pretrained configuration and build an (uninitialized) HF model for this checkpoint
    model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]]
    config = JukeboxConfig.from_pretrained(model_name)
    model = JukeboxModel(config)

    # List that will hold the converted state dicts
    weight_dict = []
    # Dictionary that records the key mapping
    mapping = {}

    # Iterate over the original checkpoints that need to be converted
    for i, dict_name in enumerate(model_to_convert):
        # Load the original state dict from the downloaded PyTorch file
        old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}")["model"]

        # Build a new dict with partially renamed keys
        new_dic = {}
        # Rename keys depending on their suffix
        for k in old_dic.keys():
            if k.endswith(".b"):
                new_dic[k.replace("b", "bias")] = old_dic[k]
            elif k.endswith(".w"):
                new_dic[k.replace("w", "weight")] = old_dic[k]
            elif "level_2" not in dict_name and "cond.model." in k:
                new_dic[k.replace(".blocks.", ".model.")] = old_dic[k]
            else:
                new_dic[k] = old_dic[k]

        # Fix the remaining Jukebox key names using the appropriate prefix
        key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}"
        new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
        # Append the converted state dict to the list
        weight_dict.append(new_dic)

    # Pop the VQ-VAE state dict from the list and load it into the model
    vqvae_state_dict = weight_dict.pop(0)
    model.vqvae.load_state_dict(vqvae_state_dict)
    # Load the remaining state dicts into the prior models
    for i in range(len(weight_dict)):
        model.priors[i].load_state_dict(weight_dict[2 - i])

    # Make sure the output directory exists
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    # Save the key mapping as a JSON file next to the converted weights
    with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
        json.dump(mapping, txtfile)

    # Log where the converted model is being saved
    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")

    # Save the converted model weights with the standard HF method
    model.save_pretrained(pytorch_dump_folder_path)

    # Return the list of converted state dicts
    return weight_dict
if __name__ == "__main__":
    # 如果当前脚本作为主程序执行,则进入条件判断块

    parser = argparse.ArgumentParser()
    # 创建参数解析器对象

    # Required parameters
    parser.add_argument(
        "--model_name",
        default="jukebox-5b-lyrics",
        type=str,
        help="Name of the model you'd like to convert.",
    )
    # 添加必选参数:模型名称,设置默认值为"jukebox-5b-lyrics",类型为字符串,用于指定要转换的模型名称

    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="jukebox-5b-lyrics-converted",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    # 添加必选参数:PyTorch 模型输出文件夹路径,设置默认值为"jukebox-5b-lyrics-converted",类型为字符串,用于指定输出转换后的模型的存储路径

    args = parser.parse_args()
    # 解析命令行参数并存储到 args 变量中

    convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path)
    # 调用函数 convert_openai_checkpoint,传入模型名称和输出文件夹路径作为参数,执行模型转换操作

.\models\jukebox\modeling_jukebox.py

# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Jukebox model."""

import math  # math utilities
import os  # filesystem helpers
from typing import List, Optional, Tuple  # typing helpers

import numpy as np  # NumPy
import torch  # PyTorch
import torch.nn.functional as F  # functional ops
from torch import nn  # neural-network modules
from torch.nn import LayerNorm as FusedLayerNorm  # layer normalization

from ...activations import ACT2FN  # activation function registry
from ...modeling_utils import PreTrainedModel  # base class for pretrained models
from ...utils import add_start_docstrings, logging  # docstring helper and logging
from ...utils.logging import tqdm  # progress bar
from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig  # configs

logger = logging.get_logger(__name__)  # logger for this module

JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openai/jukebox-1b-lyrics",
    "openai/jukebox-5b-lyrics",
    # See all Jukebox models at https://huggingface.co/models?filter=jukebox
]


def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits (`torch.Tensor`):
            logits distribution shape (vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            When `top_k >0` keep only top key tokens with highest probability (top-k filtering).
        top_p (`int`, *optional*, defaults to 0):
            When `top_p>0.0` keep the top tokens with cumulative probability >= `top_p` (nucleus filtering).
    """
    logits = logits.clone()  # work on a copy so the original logits are not modified
    top_k = min(top_k, logits.size(-1))  # safety check: top_k cannot exceed the vocabulary size

    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[indices_to_remove] = filter_value  # set the removed logits to filter_value
    # Nucleus (top-p) filtering
    if top_p > 0.0:
        # Sort the logits in descending order, keeping track of the original indices
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)

        # Cumulative softmax probabilities of the sorted logits
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark the indices whose cumulative probability exceeds the threshold
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the mask one position to the right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the mask back to the original (unsorted) positions
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )

        # Set the removed logits to filter_value
        logits[indices_to_remove] = filter_value

    # Return the filtered logits
    return logits
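
A tiny sanity check (illustration only) of the top-k branch:

# Sketch: keep only the two largest logits of a 1x5 distribution.
logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])
print(filter_logits(logits, top_k=2))
# tensor([[  -inf, 3.0000,   -inf, 2.0000,   -inf]])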
# Extract only the lyric tokens that are relevant for the current audio window. A total of
# `max_n_lyric_tokens` tokens is returned: shorter sequences are left-padded, longer ones are
# cropped to `max_n_lyric_tokens // 2` tokens on each side of the midpoint of the window,
# which focuses on the tokens that are most relevant in time.

def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration):
    """
    Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be
    returned. If the provided token sequence is smaller, it will be padded, otherwise, only characters ranging from the
    midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses* on
    the most relevant tokens (in time) for the sequence.

    Args:
        full_tokens (`List[int]`):
            List containing the token ids of the entire lyrics.
        max_n_lyric_tokens (`int`):
            Maximum number of lyric tokens to return.
        total_length (`int`):
            Total expected length of the music (not all of it is generated, see duration), in samples.
        offset (`int`):
            Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted take that into
            account
        duration (`int`):
            Expected duration of the generated music, in samples. The duration has to be smaller than the total length,
            which represent the overall length of the signal,
    """
    full_tokens = full_tokens[0]  # take the first element of the list (the tensor of lyric token ids)
    if len(full_tokens) < max_n_lyric_tokens:
        # If the sequence is shorter than max_n_lyric_tokens, left-pad it with zeros
        tokens = torch.cat(
            [torch.zeros(max_n_lyric_tokens - len(full_tokens), dtype=torch.long).to(full_tokens.device), full_tokens]
        )
        indices = [-1] * (max_n_lyric_tokens - len(full_tokens)) + list(range(0, len(full_tokens)))
    else:
        # Compute the midpoint of the currently sampled window within the lyrics
        midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length)
        # Clamp the midpoint so the extracted window stays inside the sequence
        midpoint = min(max(midpoint, max_n_lyric_tokens // 2), len(full_tokens) - max_n_lyric_tokens // 2)
        # Extract the tokens centered around the midpoint
        tokens = full_tokens[midpoint - max_n_lyric_tokens // 2 : midpoint + max_n_lyric_tokens // 2]
        indices = list(range(midpoint - max_n_lyric_tokens // 2, midpoint + max_n_lyric_tokens // 2))
    return tokens.unsqueeze(dim=0), indices
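
A worked example (illustration only) of the cropping branch:

# Sketch: 10 lyric tokens, keep the 4 most relevant ones for a window in the middle
# of the song. midpoint = int(10 * (40 + 20 / 2) / 100) = 5, so tokens 3..6 are kept.
full_tokens = [torch.arange(10)]
tokens, indices = get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens=4, total_length=100, offset=40, duration=20)
print(tokens, indices)  # tensor([[3, 4, 5, 6]]) [3, 4, 5, 6]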


# Split `total_length` into windows of size `n_ctx`, spaced `hop_length` samples apart
def get_starts(total_length, n_ctx, hop_length):
    starts = []
    for start in range(0, total_length - n_ctx + hop_length, hop_length):
        if start + n_ctx >= total_length:
            # The last window would otherwise be smaller, so shift it back to keep a full `n_ctx` of context
            start = total_length - n_ctx
        starts.append(start)
    return starts
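
For example (illustration only), a sequence of 20 tokens with a context of 8 and a hop of 4 yields window starts [0, 4, 8, 12]; the clamp above makes sure the final window ends exactly at the end of the sequence:

print(get_starts(total_length=20, n_ctx=8, hop_length=4))  # [0, 4, 8, 12]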


# Compute the lyric-to-music alignment from the music tokens, labels, top-level prior and config
def get_alignment(music_tokens, labels, prior, config):
    level = prior.levels - 1  # use the top level
    n_ctx = prior.n_ctx
    tokens = music_tokens[level]
    batch_size, total_length = tokens.shape[0], tokens.shape[1]
    if total_length < n_ctx:
        # If the sequence is shorter than n_ctx, right-pad it with zeros
        padding_length = n_ctx - total_length
        tokens = torch.cat(
            [tokens, torch.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype, device=tokens.device)], dim=1
        )
        total_length = tokens.shape[1]
    else:
        padding_length = 0

    # Hop length, derived from the configured hop fraction for this level
    hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx)
    # Attention head and layer used for the alignment, taken from the config
    alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0]
    # Set containing the single layer whose attention weights we need
    attn_layers = {alignment_layer}
    # Alignment weights collected per window start
    alignment_hops = {}
    # Lyric-token indices collected per window start
    indices_hops = {}
    # Iterate over every window start, showing a progress bar
    for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "):
        end = start + n_ctx
        # Get the metadata and the lyric-token indices for this window from the prior
        metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0)
        # Split the tokens and metadata into per-sample chunks
        tokens_bs = torch.chunk(tokens, batch_size, dim=0)
        metadata_bs = torch.chunk(metadata, batch_size, dim=0)
        # Attention weights collected for each sample of the batch
        w_hops = []
        for tokens_i, metadata_i in zip(tokens_bs, metadata_bs):
            # Run the prior on this window and request the attention weights of attn_layers
            w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers)
            # Keep only the alignment head of the first returned attention map
            w_hops.append(w_hop[0][:, alignment_head])
            # Free the intermediate tensor
            del w_hop
        # Concatenate the per-sample attention weights
        weights = torch.cat(w_hops, dim=0)
        del w_hops
        # Move the weights to CPU as a float numpy array
        alignment_hop = weights.float().cpu().numpy()
        del weights

        # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens)
        # indices_hop is a list of length bs, each entry of length hps.nb_relevant_lyric_tokens
        # Store the indices and alignment weights for this window start
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop

    # Combine the per-window attentions into a full-range alignment, using the indices to place
    # each window's weights at the positions of the corresponding lyric tokens
    alignments = []
    for item in range(batch_size):
        # Note: each item can have lyrics of a different length
        full_tokens = labels[0, 3:]
        # Full alignment matrix of shape (total_length, len(full_tokens) + 1), initialized to zero
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        # Walk through the window starts in reverse order
        for start in reversed(get_starts(total_length, n_ctx, hop_length)):
            end = start + n_ctx
            # Alignment weights and lyric indices of this window for the current item
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            # Place the window's weights in the corresponding lyric-token columns
            alignment[start:end, indices] = alignment_hop
        # Strip the token padding and the last lyric index
        alignment = alignment[: total_length - padding_length, :-1]
        alignments.append(alignment)
    # Return the list of per-item alignments
    return alignments
# Save intermediate audio samples to disk as .npy files
def save_temp_audio(fname, lvl, metas, aud):
    # Clamp the audio to [-1, 1] and convert it to a numpy array
    aud = torch.clamp(aud, -1, 1).cpu().numpy()
    # Iterate over every sample of the batch
    for i in list(range(aud.shape[0])):
        # If metadata is provided, encode it in the file name
        if metas is not None:
            # Artist, genre and lyrics of the current sample
            artists, genres, lyrics = list(metas)[i].values()
            # Path containing the level, artist, genre, first 5 lyric characters and sample index
            path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}"
            # Save the sample as a .npy file
            np.save(path, aud[i])
        else:
            # Without metadata, the file name only contains the level and the sample index
            np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i])


# Build an attention mask tensor depending on the requested mask type
def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t):
    # No mask is needed when mask is None or the query has length 1
    if mask is None or query_length == 1:
        return None
    # Offset of the query relative to the keys/values
    offset = sample_t - query_length if sample else max(key_value_length - query_length, 0)
    # Build the mask according to the requested type
    if mask == "autoregressive":
        # Autoregressive (causal) mask: lower-triangular matrix
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    elif mask == "summary":
        # Summary mask: used when the input is summarized block by block
        mask = torch.ones(query_length, query_length, device=device).tril()
        mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :]
        mask = (
            torch.nn.functional.pad(
                mask,
                (0, 0, 1, 0),
                value=1,
            )
            .contiguous()
            .view(query_length, key_value_length)
        )
    elif mask == "prime":
        # Prime mask: another lower-triangular mask, used for the lyric tokens
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    return mask.view(1, 1, query_length, key_value_length)
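
A quick illustration (not part of the file) of the autoregressive case:

# Sketch: a 4x4 causal mask, broadcastable over batch and head dimensions.
mask = get_mask(
    "autoregressive", query_length=4, key_value_length=4,
    blocks=None, spread=None, device="cpu", sample=False, sample_t=None,
)
print(mask.shape)  # torch.Size([1, 1, 4, 4])
print(mask[0, 0])  # lower-triangular matrix of ones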


# 1D "convolution" used throughout Jukebox: a linear projection over the channel dimension
class JukeboxConv1D(nn.Module):
    def __init__(self, input_width, output_width):
        super().__init__()
        self.input_width = input_width
        self.output_width = output_width
        # Weight and bias parameters
        weight = torch.empty(input_width, output_width)
        bias = torch.zeros(output_width)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

    def forward(self, hidden_states):
        # Output size: same leading dimensions, last dimension becomes output_width
        size_out = (*hidden_states.size()[:-1], self.output_width)
        # Matrix multiplication plus bias over the flattened input
        hidden_states = torch.addmm(
            self.bias.type_as(hidden_states),
            hidden_states.view(-1, hidden_states.size(-1)),
            self.weight.type_as(hidden_states),
        )
        # Reshape back to the original leading dimensions and return
        hidden_states = hidden_states.view(*size_out)
        return hidden_states
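
Despite its name, the module is effectively a linear projection applied to the last dimension; a small shape check (illustration only):

layer = JukeboxConv1D(input_width=4, output_width=8)
hidden = torch.randn(2, 16, 4)  # (batch, sequence, channels)
print(layer(hidden).shape)      # torch.Size([2, 16, 8])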


# Residual 1D convolution block used by the Jukebox VQ-VAE
class JukeboxResConv1DBlock(nn.Module):
    def __init__(self, config, conv_width, depth=1, res_scale=1.0):
        super().__init__()
        # Hidden dimension, dilation and padding derived from the config
        hidden_dim = config.res_convolution_multiplier * conv_width
        dilation = config.res_dilation_growth_rate**depth
        padding = dilation

        self.res_scale = res_scale
        self.activation = nn.ReLU()
        # Two 1D convolutions: a dilated kernel-3 conv followed by a 1x1 projection back to conv_width
        self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation)
        self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0)

    # Forward pass of the residual block
    def forward(self, hidden_states):
        # Keep the input as the residual branch
        residuals = hidden_states
        # Activation, first convolution, activation, second convolution
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv1d_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv1d_2(hidden_states)
        # Add the (scaled) output back onto the residual branch
        return residuals + self.res_scale * hidden_states
# Stack of residual 1D convolution blocks
class JukeboxResnet1D(nn.Module):
    def __init__(self, config, conv_width, n_depth, reverse_dilation=False):
        super().__init__()
        # Dilation cycle taken from the config
        self.dilation_cycle = config.res_dilation_cycle
        # If residual scaling is enabled, scale by 1/sqrt(depth)
        res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth)

        blocks = []
        for depth in range(n_depth):
            # When a dilation cycle is set, the block depth cycles modulo that value
            block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle
            blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale))

        # Optionally reverse the block order
        if reverse_dilation:
            blocks = blocks[::-1]
        self.resnet_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        # Apply every residual block in sequence
        for block in self.resnet_block:
            hidden_states = block(hidden_states)
        return hidden_states


# 定义 JukeboxEncoderConvBlock 类,继承自 nn.Module 类,实现了编码器的卷积块结构
class JukeboxEncoderConvBlock(nn.Module):
    # 初始化函数,接受配置 config、嵌入维度 embed_dim、隐藏维度 hidden_dim、深度 depth、down_t、stride_t 参数
    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t):
        # 调用父类的初始化方法
        super().__init__()
        # 创建空的块列表
        blocks = []
        # 计算滤波器大小 filter_t 和填充大小 pad_t
        filter_t = stride_t * 2
        pad_t = stride_t // 2
        # 如果 down_t 大于 0,则循环添加卷积层和残差卷积块到块列表中
        if down_t > 0:
            for i in range(down_t):
                # 添加 1 维卷积层到块列表中
                blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t))
                # 添加 JukeboxResnet1D 模块到块列表中
                blocks.append(JukeboxResnet1D(config, hidden_dim, depth))

        # 创建输出投影层
        self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1)
        # 将块列表转换为 nn.ModuleList 类型的模块列表并赋值给实例变量
        self.downsample_block = nn.ModuleList(blocks)

    # 前向传播函数,接受输入 hidden_states
    def forward(self, hidden_states):
        # 遍历每个块,依次对输入进行处理
        for block in self.downsample_block:
            hidden_states = block(hidden_states)
        # 将处理后的 hidden_states 经过投影层处理后返回
        hidden_states = self.proj_out(hidden_states)
        return hidden_states
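
# Illustrative sketch (added for this walkthrough, not part of the original file): each strided
# Conv1d above uses filter_t = 2 * stride_t and pad_t = stride_t // 2, so (for even stride_t) one
# layer divides the time axis by stride_t, and the whole block divides it by stride_t ** down_t:
#   L_out = (L_in + 2*pad_t - filter_t) // stride_t + 1 = (L_in - stride_t) // stride_t + 1 = L_in // stride_t
def _demo_encoder_downsampling(seq_len=64, down_t=2, stride_t=2):
    for _ in range(down_t):
        seq_len = (seq_len + 2 * (stride_t // 2) - 2 * stride_t) // stride_t + 1
    return seq_len  # 64 // 2**2 = 16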


# 定义 JukeboxEncoder 类,继承自 nn.Module 类,实现了 Jukebox 编码器结构
class JukeboxEncoder(nn.Module):
    # 初始化函数,接受配置 config、宽度 width、深度 depth、层级 levels、downs_t、strides_t 参数
    def __init__(self, config, width, depth, levels, downs_t, strides_t):
        # 调用父类的初始化方法
        super().__init__()
        # 设置层级数
        self.levels = levels
        # 创建模块列表 level_blocks
        self.level_blocks = nn.ModuleList()

        # 使用 zip 函数迭代 levels、downs_t 和 strides_t,并根据迭代结果生成 JukeboxEncoderConvBlock 模块并添加到 level_blocks 中
        iterator = zip(list(range(self.levels)), downs_t, strides_t)
        for i, down_t, stride_t in iterator:
            self.level_blocks.append(
                JukeboxEncoderConvBlock(
                    config, config.conv_input_shape if i == 0 else config.embed_dim, width, depth, down_t, stride_t
                )
            )

    # 前向传播函数,接受输入 hidden_states
    def forward(self, hidden_states):
        # 创建空列表 all_hidden_states 用于存储所有层级的隐藏状态
        all_hidden_states = []

        # 遍历每个层级
        for level in range(self.levels):
            # 获取当前层级的 JukeboxEncoderConvBlock 模块
            level_block = self.level_blocks[level]
            # 对输入 hidden_states 应用当前层级的模块处理
            hidden_states = level_block(hidden_states)
            # 将处理后的隐藏状态添加到 all_hidden_states 列表中
            all_hidden_states.append(hidden_states)

        # 返回所有层级的隐藏状态列表
        return all_hidden_states
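
# Illustrative sketch (added for this walkthrough, not part of the original file): the encoder
# returns one hidden state per level, with cumulative temporal downsampling. For hypothetical
# downs_t=(3, 2, 2) and strides_t=(2, 2, 2) the per-level hop lengths are 8, 32 and 128.
def _demo_encoder_hop_lengths(downs_t=(3, 2, 2), strides_t=(2, 2, 2)):
    hops, factor = [], 1
    for down_t, stride_t in zip(downs_t, strides_t):
        factor *= stride_t**down_t
        hops.append(factor)
    return hops  # [8, 32, 128]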


# Define the JukeboxDecoderConvBock class (nn.Module): the upsampling convolutional block used by the decoder
class JukeboxDecoderConvBock(nn.Module):
    # 初始化函数,用于初始化类实例
    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True):
        # 设置类的属性 embed_dim 和 hidden_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # 调用父类的初始化方法
        super().__init__()
        # 初始化空列表用于存储模块
        blocks = []
        # 如果 down_t 大于 0,执行以下操作
        if down_t > 0:
            # 计算滤波器长度和填充长度
            filter_t = stride_t * 2
            pad_t = stride_t // 2
            # 创建输入投影层,将 embed_dim 维度的输入转换为 hidden_dim 维度
            self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1)
            # 循环 down_t 次,添加 JukeboxResnet1D 模块和反卷积层到 blocks 列表中
            for i in range(down_t):
                blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation))
                blocks.append(
                    nn.ConvTranspose1d(
                        hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t
                    )
                )
        # 将 blocks 列表作为 ModuleList 赋给实例的 upsample_block 属性
        self.upsample_block = nn.ModuleList(blocks)

    # 前向传播函数,处理输入的隐藏状态
    def forward(self, hidden_states):
        # 将输入的隐藏状态通过投影层 proj_in 进行转换
        hidden_states = self.proj_in(hidden_states)
        # 对 upsample_block 中的每个模块进行前向传播
        for block in self.upsample_block:
            hidden_states = block(hidden_states)
        # 返回处理后的隐藏状态
        return hidden_states
class JukeboxDecoder(nn.Module):
    # 定义JukeboxDecoder类,继承自nn.Module
    def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for level, down_t, stride_t in zip(list(range(self.levels)), downs_t, strides_t):
            self.level_blocks.append(
                JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t)
            )

        self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1)
        # 初始化各个网络层

    def forward(self, hidden_states, all_levels=True):
        hidden_state = hidden_states[-1]

        # 32, 64 ...
        for level in reversed(range(self.levels)):
            level_block = self.level_blocks[level]
            hidden_state = level_block(hidden_state)

            if level != 0 and all_levels:
                hidden_state = hidden_state + hidden_states[level - 1]
        # 在不同的层级进行前向传播,并根据需要进行级联

        hidden_state = self.out(hidden_state)
        return hidden_state
        # 返回隐藏状态


class JukeboxBottleneckBlock(nn.Module):
    # 定义JukeboxBottleneckBlock类,继承自nn.Module
    def __init__(self, config: JukeboxVQVAEConfig):
        super().__init__()
        self.nb_discrete_codes = config.nb_discrete_codes
        self.codebook_width = config.embed_dim
        self.mu = config.lmu
        self.threshold = 1.0
        self.init = False
        self.codebook_sum = None
        self.codebook_elem = None
        self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width))
        # 初始化相关变量,并注册缓冲区

    def _tile(self, hidden_states):
        dim, embed_width = hidden_states.shape
        if dim < self.nb_discrete_codes:
            n_repeats = (self.nb_discrete_codes + dim - 1) // dim
            std = 0.01 / np.sqrt(embed_width)
            hidden_states = hidden_states.repeat(n_repeats, 1)
            hidden_states = hidden_states + torch.randn_like(hidden_states) * std
        return hidden_states
        # 定义辅助函数_tile,用于重复和扩展隐藏状态

    def init_codebook(self, hidden_states):
        nb_discrete_codes = self.nb_discrete_codes
        self.init = True
        codes = self._tile(hidden_states)
        self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]
        self.codebook_sum = self.codebook
        self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device)
        # 初始化码书信息
    # 更新代码本函数,更新代码簿中的中心点
    def update_codebook(self, hidden_states, latent_states):
        # 从对象属性中获取参数
        mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes
        # 禁止梯度计算
        with torch.no_grad():
            # 计算新的中心点
            # 将离散状态转换为独热编码
            latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device)
            latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1)

            # 计算每个簇的加权和
            _codebook_sum = torch.matmul(latent_states_onehot, hidden_states)
            # 计算每个簇的元素数量
            _codebook_elem = latent_states_onehot.sum(dim=-1)  # nb_discrete_codes
            # 复制隐藏状态以扩展簇的数量
            codes = self._tile(hidden_states)
            # 随机选取一些代码本的样本
            _random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]

            # 更新中心点
            old_codebook = self.codebook
            # 更新加权和
            self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum
            # 更新簇的元素数量
            self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem  # nb_discrete_codes
            # 计算每个簇的使用情况
            usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float()

            # 归一化簇的中心点
            norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view(
                nb_discrete_codes, 1
            )
            # 更新代码本
            self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook
            # 计算每个簇的概率
            _codebook_prob = _codebook_elem / torch.sum(_codebook_elem)  # prob of each bin
            # 计算熵,用于衡量多样性
            entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))  # entropy ie how diverse
            # 计算当前使用的簇的数量
            used_curr = (_codebook_elem >= self.threshold).sum()
            # 计算簇的使用情况
            usage = torch.sum(usage)
            # Measure how much the codebook moved: normalised distance between the old and new codebook (not a KL divergence)
            dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
        # 返回更新结果
        return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}

    # 预处理函数,用于规范化隐藏状态
    def preprocess(self, hidden_states):
        # 调整张量形状以便后续处理
        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # 如果隐藏状态的维度等于代码本的宽度
        if hidden_states.shape[-1] == self.codebook_width:
            # 计算预规范化值
            prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape))
        # 如果隐藏状态的维度是代码本宽度的两倍
        elif hidden_states.shape[-1] == 2 * self.codebook_width:
            # 分离隐藏状态的两部分
            x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
            # 分别计算两部分的预规范化值,并相加
            prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
                torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
            )
            # 合并隐藏状态的两部分
            hidden_states = x1 + x2

        # 返回预处理后的隐藏状态及其规范化值
        return hidden_states, prenorm
    def postprocess(self, latent_states, dequantised_states, x_shape):
        # 获取输入数据的批次大小和时间步数
        batch_size, time = x_shape
        # 重新组织 dequantised_states 的形状,使其变为 (batch_size, -1, time)
        dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous()
        # 重新组织 latent_states 的形状,使其变为 (batch_size, time)
        latent_states = latent_states.view(batch_size, time)
        return latent_states, dequantised_states

    def quantise(self, latent_states):
        # 计算 latent_states 与 codebook 中的距离
        codebook_weights = self.codebook.t()
        distance = (
            torch.sum(latent_states**2, dim=-1, keepdim=True)
            - 2 * torch.matmul(latent_states, codebook_weights)
            + torch.sum(codebook_weights**2, dim=0, keepdim=True)
        )  # 形状为 (batch_size * latent_states , codebook_weights)
        # 找到每个 latent_state 最接近的 codebook 中的索引 music_tokens
        min_distance, music_tokens = torch.min(distance, dim=-1)
        # 计算平均最小距离
        fit = torch.mean(min_distance)
        return music_tokens, fit
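
    # Worked expansion behind the distance above (illustrative): for a latent row z and a codebook
    # column c,
    #   ||z - c||^2 = ||z||^2 - 2 * z.c + ||c||^2
    # which is exactly the three-term sum computed in `quantise`; the argmin over the codebook axis
    # gives the nearest code index stored in `music_tokens`.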

    def dequantise(self, music_tokens):
        # 使用 music_tokens 从 codebook 中获取对应的 dequantised_states
        dequantised_states = F.embedding(music_tokens, self.codebook)
        return dequantised_states

    def encode(self, latent_states):
        samples, _, seq_len = latent_states.shape

        # 数据预处理
        latent_states, _ = self.preprocess(latent_states)

        # 量化过程
        music_tokens, _ = self.quantise(latent_states)

        # 后处理
        music_tokens = music_tokens.view(samples, seq_len)
        return music_tokens

    def decode(self, music_tokens):
        samples, seq_len = music_tokens.shape

        # 反量化过程
        dequantised_states = self.dequantise(music_tokens)

        # 后处理
        dequantised_states = (
            dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous()
        )
        return dequantised_states

    def forward(self, hidden_states, update_codebook=True):
        samples, _, seq_len = hidden_states.shape

        # 数据预处理
        hidden_states, prenorm = self.preprocess(hidden_states)

        # 如果需要更新 codebook 并且未初始化,则进行初始化
        if update_codebook and not self.init:
            self.init_codebook(hidden_states)

        # 通过编码和解码过程量化和反量化
        music_tokens, fit = self.quantise(hidden_states)
        dequantised_states = self.dequantise(music_tokens)

        # 如果需要更新 codebook,则更新相关指标
        if update_codebook:
            update_metrics = self.update_codebook(hidden_states, music_tokens)
        else:
            update_metrics = {}

        # 计算损失
        commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape)

        # 通过传递增强数据流
        dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()

        # 后处理
        music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len))
        return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics)
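
# Illustrative sketch (added for this walkthrough, not part of the original file): the line
# `hidden_states + (dequantised_states - hidden_states).detach()` above is the straight-through
# estimator. The forward value equals the quantised vector, while the gradient flows back to
# `hidden_states` as if quantisation were the identity. `torch.round` stands in for the codebook
# lookup here.
def _demo_straight_through():
    hidden_states = torch.randn(4, 8, requires_grad=True)
    quantised = torch.round(hidden_states)
    out = hidden_states + (quantised - hidden_states).detach()
    out.sum().backward()
    # forward ~ quantised, backward: gradient of 1 w.r.t. every element of hidden_states
    return torch.allclose(out, quantised), bool((hidden_states.grad == 1.0).all())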

# 定义一个名为 JukeboxBottleneck 的类,继承自 nn.Module
class JukeboxBottleneck(nn.Module):
    
    # 初始化方法,接受 config 和 levels 参数
    def __init__(self, config, levels):
        super().__init__()
        self.levels = levels  # 设置 levels 属性
        self.level_blocks = nn.ModuleList()  # 初始化一个 nn.ModuleList 用于存储每个 level 的 block
        
        # 遍历 levels 创建 JukeboxBottleneckBlock,并添加到 level_blocks 中
        for level in range(self.levels):
            self.level_blocks.append(JukeboxBottleneckBlock(config))

    # 编码方法,接受 raw_audio 参数
    def encode(self, raw_audio):
        # 使用列表推导式对每个 level_block 和对应的 hidden_states 进行编码
        music_tokens = [
            level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio)
        ]
        return music_tokens  # 返回编码后的音乐 tokens

    # 解码方法,接受 music_tokens、start_level 和 end_level 参数
    def decode(self, music_tokens, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels  # 如果未指定 end_level,默认为 levels
        
        # 使用列表推导式对每个 level_block 和对应的 music_tokens 进行解码
        quantised_audio = [
            level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens)
        ]
        return quantised_audio  # 返回量化后的音频数据

    # 前向传播方法,接受 input_audio 参数
    def forward(self, input_audio):
        music_tokens, quantised_states, commit_losses, metrics = [], [], [], []
        
        # 遍历每个 level
        for level in range(self.levels):
            level_block = self.level_blocks[-level - 1]  # 获取当前 level 的 block
            hidden_states = input_audio[level]  # 获取对应的输入音频的隐藏状态
            
            # 调用 level_block 进行处理,获取返回值
            sampled_tokens, quantised_state, commit_loss, metric = level_block(
                hidden_states, update_codebook=self.training
            )
            
            music_tokens.append(sampled_tokens)  # 将 sampled_tokens 添加到 music_tokens 列表中
            
            if not self.training:
                # 在非训练模式下,确保编码器权重不会从直通估计中更改
                quantised_state = quantised_state.detach()
            
            quantised_states.append(quantised_state)  # 将 quantised_state 添加到 quantised_states 列表中
            commit_losses.append(commit_loss)  # 将 commit_loss 添加到 commit_losses 列表中
            
            if self.training:
                metrics.append(metric)  # 在训练模式下,将 metric 添加到 metrics 列表中
        
        # 返回 music_tokens、quantised_states、commit_losses 和 metrics
        return music_tokens, quantised_states, commit_losses, metrics

# 设置 JUKEBOX_START_DOCSTRING 变量,包含模型的一些基本文档信息
JUKEBOX_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config (`JukeboxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# 使用 @add_start_docstrings 装饰器添加额外的文档信息到 JukeboxVQVAE 类
@add_start_docstrings(
    """The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams, Sam
Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/abs/2002.08111).

    """,
    JUKEBOX_START_DOCSTRING,
)
# 定义 JukeboxVQVAE 类,继承自 PreTrainedModel
class JukeboxVQVAE(PreTrainedModel):
    config_class = JukeboxVQVAEConfig  # 设置 config_class 属性
    # 设置基础模型前缀为 "vqvae"
    base_model_prefix = "vqvae"

    # 初始化权重的函数,用于初始化模块的权重
    def _init_weights(self, module):
        # 如果模块是 nn.Embedding 类型,例如 embed_tokens
        if isinstance(module, nn.Embedding):
            # 初始化权重为正态分布,均值为 0,标准差为 0.02 乘以配置参数中的初始化比例
            module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
        # 如果模块是 JukeboxConv1D 类型
        elif isinstance(module, JukeboxConv1D):
            # 根据配置参数决定是否将权重初始化为零,否则初始化为正态分布
            if self.config.zero_out:
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
        # 如果模块是 JukeboxResConv1DBlock 类型,并且配置参数中指定了 zero_out 为 True
        elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
            # 将第二个卷积层的权重和偏置初始化为零
            module.conv1d_2.weight.data.zero_()
            module.conv1d_2.bias.data.zero_()
        # 如果模块是 nn.LayerNorm 类型
        if isinstance(module, nn.LayerNorm):
            # 将偏置初始化为零
            module.bias.data.zero_()
            # 将权重初始化为全 1
            module.weight.data.fill_(1.0)
        # 如果模块是 nn.Linear 类型,并且有偏置项
        if isinstance(module, nn.Linear) and module.bias is not None:
            # 将偏置项初始化为零
            module.bias.data.zero_()

    # 初始化函数,接受一个 JukeboxVQVAEConfig 类型的配置参数
    def __init__(self, config: JukeboxVQVAEConfig):
        # 调用父类的初始化方法,传入配置参数
        super().__init__(config)
        # 获取配置参数中的 res_downs_t 和 res_strides_t
        downs_t = config.res_downs_t
        strides_t = config.res_strides_t
        # 如果配置参数中没有指定 sample_length
        if not config.sample_length:
            # 计算 downsamples 数组,每个元素为 stride**down 的结果
            downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
            # 计算 top_raw_to_tokens,即 downsamples 的乘积
            top_raw_to_tokens = np.prod(downsamples)
            # 根据采样率和 top_raw_to_tokens 计算 sample_length
            config.sample_length = (
                config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens
            ) * top_raw_to_tokens
            # 将 sample_length 转换为整数类型
            config.sample_length = config.sample_length.astype(int)

        # 设置一些模型参数,从配置中获取
        self.nb_discrete_codes = config.nb_discrete_codes
        self.commit = config.commit
        self.sample_length = config.sample_length

        # 计算 downsamples 数组和 hop_lengths 数组
        self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
        self.hop_lengths = np.cumprod(self.downsamples)
        self.levels = levels = config.levels
        # 计算 music_tokens_shapes 数组
        self.music_tokens_shapes = [
            (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels)
        ]

        # 设置 multipliers 数组,如果配置中没有指定,则全部设置为 1
        self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels

        # 初始化 encoders 和 decoders,都是 nn.ModuleList 类型
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            # 计算当前层的宽度和深度
            width = config.res_conv_width * self.multipliers[level]
            depth = config.res_conv_depth * self.multipliers[level]
            # 分别创建 JukeboxEncoder 和 JukeboxDecoder 并加入到 encoders 和 decoders 中
            self.encoders.append(
                JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
            )
            self.decoders.append(
                JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
            )

        # 初始化 bottleneck 层
        self.bottleneck = JukeboxBottleneck(config, levels)
    # 解码函数,将音乐编码 `music_tokens` 解码为原始音频表示
    def _decode(self, music_tokens, start_level=0, end_level=None):
        # 如果未指定结束级别,则使用最大级别
        if end_level is None:
            end_level = self.levels
        # 使用瓶颈网络进行解码,获取潜在状态
        latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level)
        # 只使用最低级别的解码器
        decoder, dequantised_state = self.decoders[start_level], latent_states[0:1]
        # 使用解码器对去量化状态进行解码
        dequantised_state = decoder(dequantised_state, all_levels=False)
        # 调整维度顺序,将时间轴移至第二个维度
        dequantised_state = dequantised_state.permute(0, 2, 1)
        return dequantised_state

    # 解码函数,将音乐编码 `music_tokens` 解码为原始音频表示,支持批处理
    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor:
        """
        将输入的 `music_tokens` 解码为它们的 `raw_audio` 表示。

        Args:
            music_tokens (`torch.LongTensor`):
                音乐编码的张量,通过使用码本将其解码为原始音频。每个音乐编码应该是码本中相应 `code` 向量的索引。
            start_level (`int`, *optional*):
                解码过程开始的级别。默认为 0。
            end_level (`int`, *optional*):
                解码过程结束的级别。默认为 None。
            bs_chunks (int, *optional*):
                同时处理的块数。

        Returns:
            `torch.Tensor`: 解码后的原始音频张量。
        """
        # 将音乐编码分块,以便并行处理
        token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens]
        dequantised_states = []
        for i in range(bs_chunks):
            music_tokens_i = [chunks[i] for chunks in token_chunks]
            # 调用 `_decode` 函数进行解码
            dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level)
            dequantised_states.append(dequantised_state)
        # 拼接所有解码后的状态张量
        return torch.cat(dequantised_states, dim=0)

    # 编码函数,将原始音频 `raw_audio` 编码为音乐编码 `music_tokens`
    def _encode(self, raw_audio, start_level=0, end_level=None):
        # 编码
        if end_level is None:
            end_level = self.levels
        # 调整输入音频的维度顺序,确保正确的输入格式
        input_audio = raw_audio.permute(0, 2, 1).float()
        latent_states = []
        # 遍历所有级别的编码器,获取潜在状态
        for level in range(self.levels):
            encoder = self.encoders[level]
            latent_state = encoder(input_audio)
            latent_states.append(latent_state[-1])  # 仅保留每级别最后一个潜在状态
        # 使用瓶颈网络对潜在状态进行编码,得到音乐编码 `music_tokens`
        music_tokens = self.bottleneck.encode(latent_states)
        return music_tokens[start_level:end_level]
def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
    # Split the input audio into chunks so that each chunk is encoded independently
    audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0)
    # 初始化一个空列表,用于存储每个音频块的离散表示
    music_tokens_list = []
    # 遍历每个音频块
    for chunk_i in audio_chunks:
        # 调用内部方法 `_encode` 对当前音频块进行编码,生成其离散表示
        music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level)
        # 将编码后的离散表示添加到列表中
        music_tokens_list.append(music_tokens_i)
    # 将每个音频块的离散表示进行合并,按照维度0连接在一起,形成最终的音乐表示
    music_tokens = [torch.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)]
    # 返回整个音乐表示
    return music_tokens

# Draw random discrete codes for `n_samples` music samples and decode them to raw audio
def sample(self, n_samples):
    # 为每个离散表示形状生成随机整数,表示从0到nb_discrete_codes之间的离散码
    music_tokens = [
        torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu")
        for music_tokens_shape in self.music_tokens_shapes
    ]
    # 调用解码方法,将生成的离散表示解码为音频样本
    return self.decode(music_tokens)
def forward(self, raw_audio):
    # Encode/Decode
    input_audio = raw_audio.permute(0, 2, 1).float()
    # 将输入音频数据重新排列维度,使其符合模型要求的格式,并转换为浮点数类型

    latent_states = []
    # 创建空列表,用于存储每个级别的潜在状态

    for level in range(self.levels):
        # 遍历所有编码器级别
        encoder = self.encoders[level]
        # 获取当前级别的编码器
        latent_state = encoder(input_audio)
        # 对输入音频进行编码,得到潜在状态
        latent_states.append(latent_state[-1])
        # 将编码后的潜在状态加入列表中,取最后一个状态

    _, music_tokens, commit_losses, _ = self.bottleneck(latent_states)
    # Run the bottleneck over the latent states; note that the second value returned here is actually the
    # quantised states (re-used under the name `music_tokens` and decoded below), plus the commit losses

    dequantised_states = []
    # 创建空列表,用于存储每个级别的反量化状态

    for level in range(self.levels):
        # 遍历所有解码器级别
        decoder = self.decoders[level]
        # 获取当前级别的解码器
        dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False)
        # 使用解码器解码音乐编码得到反量化状态
        dequantised_states.append(dequantised_state.permute(0, 2, 1))
        # 将反量化状态重新排列维度并加入列表中

    commit_loss = sum(commit_losses)
    # 计算总的 commit loss
    loss = self.commit * commit_loss
    # 根据 commit 系数计算最终损失值

    return dequantised_states, loss
    # 返回解码后的状态列表及计算得到的损失值
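
# Illustrative usage sketch (added for this walkthrough, not part of the original file): a raw-audio
# round trip through the hierarchical VQ-VAE. The config override and shapes are hypothetical, and
# the default `JukeboxVQVAEConfig` values are assumed to be usable as-is; with random weights the
# reconstruction is meaningless, so load pretrained weights for real use.
def _demo_vqvae_roundtrip():
    config = JukeboxVQVAEConfig(sample_length=262144)
    model = JukeboxVQVAE(config).eval()
    raw_audio = torch.randn(1, config.sample_length, 1)  # (batch, time, channels)
    music_tokens = model.encode(raw_audio)       # one LongTensor of code indices per level
    reconstruction = model.decode(music_tokens)  # back to (batch, time, channels)
    return [tokens.shape for tokens in music_tokens], reconstruction.shape

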
class JukeboxMLP(nn.Module):
    def __init__(self, config):
        # 初始化函数,定义一个多层感知机(MLP)模型
        super().__init__()
        # 从配置中获取隐藏层大小作为嵌入维度
        embed_dim = config.hidden_size
        # 计算隐藏层大小的倍数作为MLP的隐藏层大小
        hidden_dim = int(config.mlp_multiplier * embed_dim)

        # 创建第一个卷积层,输入维度为embed_dim,输出维度为hidden_dim
        self.c_fc = JukeboxConv1D(embed_dim, hidden_dim)
        # 创建第二个卷积层,输入维度为hidden_dim,输出维度为embed_dim
        self.c_proj = JukeboxConv1D(hidden_dim, embed_dim)
        # 选择激活函数,从预定义的函数字典ACT2FN中获取对应配置的激活函数
        self.act = ACT2FN[config.act_fn]
        # 定义Dropout层,使用配置中的残差丢弃率
        self.dropout = nn.Dropout(config.resid_dropout)

    def forward(self, hidden_states):
        # MLP模型的前向传播函数
        # 应用第一个卷积层
        hidden_states = self.c_fc(hidden_states)
        # 应用激活函数
        hidden_states = self.act(hidden_states)
        # 应用第二个卷积层
        hidden_states = self.c_proj(hidden_states)
        # 应用Dropout层
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class JukeboxLayerNorm(FusedLayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        # 初始化函数,定义Jukebox模型的LayerNorm层
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        # 计算输入张量的总元素个数(维度的乘积)
        self.width = np.prod(normalized_shape)
        # 计算能够处理的最大元素个数,限制为65535 * self.width
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Jukebox模型LayerNorm层的前向传播函数
        if input.numel() > self.max_numel:
            # 如果输入张量的元素个数超过self.max_numel,使用PyTorch的layer_norm函数处理
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
        else:
            # 否则调用父类FusedLayerNorm的forward方法处理
            return super().forward(input).type_as(input)


class JukeboxAttention(nn.Module):
    # Attention module used by the Jukebox priors; it supports several dense and block-sparse attention patterns, detailed below
    # 初始化函数,用于初始化一个模型对象
    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        # 调用父类的初始化函数
        super().__init__()
        # 设置嵌入维度为配置文件中的隐藏大小
        self.embed_dim = config.hidden_size
        # 设置注意力头数为配置文件中的头数
        self.n_heads = config.n_heads
        # 设置注意力的dropout概率为配置文件中的注意力dropout
        self.dropout = config.attn_dropout
        # 计算隐藏层维度,根据注意力乘子乘以嵌入维度
        hidden_dim = int(config.attention_multiplier * self.embed_dim)

        # 设置每个头的维度
        self.head_dim = hidden_dim // config.n_heads
        # 设置上下文长度
        self.n_ctx = n_ctx
        # 设置隐藏层维度
        self.hidden_dim = hidden_dim
        # 设置缩放因子,用于注意力机制中的缩放
        self.scale = self.head_dim**-0.25
        # 设置是否使用掩码
        self.mask = config.mask

        # 根据注意力函数类型选择不同的处理方式
        if attn_func == "cross_attention":
            # 如果是交叉注意力,设置交叉注意力部分的卷积模块
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim)
            # 设置交叉注意力中的编码键值的卷积模块
            self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2)
        else:
            # 对于其他类型的注意力,设置通用的卷积模块
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3)

        # 设置投影层的卷积模块,用于最终的投影
        self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim)
        # 设置注意力的dropout层
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        # 设置残差连接的dropout层
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        # 根据序列长度seq_len将其分解为[块数, seq_len // 块数]的形式
        self.attn_func = attn_func
        # 根据注意力函数类型选择对应的QKV处理函数
        if attn_func == "cross_attention":
            self.qkv = self.decode_qkv
        elif attn_func == "prime_attn":
            self.qkv = self.prime_qkv
        else:
            self.qkv = self.factored_qkv

        # 定义不同注意力类型的映射关系
        ATTENTION_MAP = {
            "dense_attn": (self.dense_attn, "autoregressive"),
            "block_attn": (self.block_attn, "autoregressive"),
            "transpose_block_attn": (self.transpose_block_attn, "autoregressive"),
            "prev_block_attn": (self.prev_block_attn, None),
            "summary_attn": (self.summary_attn, "summary"),
            "summary_spread_attn": (self.summary_spread_attn, "summary"),
            "cross_attention": (self.dense_attn, None),
            "prime_attn": (self.prime_attn, "prime"),
        }
        # 根据传入的注意力函数名称选择对应的注意力函数及其掩码
        self.attn, self.attn_mask = ATTENTION_MAP[attn_func]

        # 设置块数和扩展数
        self.blocks = config.blocks
        self.spread = config.spread
        # 如果定义了块数,则设置块上下文长度
        if self.blocks is not None:
            self.block_ctx = self.n_ctx // self.blocks

        # 设置采样时间为0
        self.sample_t = 0
        # 初始化缓存字典
        self.cache = {}
        # 设置编码器长度,即编码器输入标识符的长度
        self.encoder_len = config.nb_relevant_lyric_tokens  # length of the encoder input ids
        # 记录是否记录注意力权重
        self.record_attn = False
    # 定义注意力机制函数,接受查询、键和值状态以及采样参数
    def _attn(self, query_states, key_states, value_states, sample):
        scale = self.scale
        # 如果处于训练阶段,应用缩放因子后计算注意力权重
        if self.training:
            attention_weight = torch.matmul(query_states * scale, key_states * scale)
        else:
            # 否则直接计算注意力权重,并乘以缩放因子的平方
            attention_weight = torch.matmul(query_states, key_states)
            attention_weight.mul_(scale * scale)
        attn_weight_type = attention_weight.dtype
        # 将注意力权重转换为 float 类型
        attention_weight = attention_weight.float()
        
        # 如果有掩码需求
        if self.mask:
            # 生成适当的掩码以遮蔽当前位置之前的所有位置
            # 对于稠密运算可能占用大量内存,因此可以缓存
            mask = get_mask(
                self.attn_mask,
                query_states.size(-2),
                key_states.size(-1),
                self.blocks,
                self.spread,
                attention_weight.device,
                sample,
                self.sample_t,
            )
            # If a mask was produced, keep the allowed positions and push the masked-out positions to a large negative value
            if mask is not None:
                attention_weight = attention_weight * mask + -1e9 * (1 - mask)
        
        # 对注意力权重进行 softmax 归一化,并根据原始类型重新转换
        attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type)
        
        # 如果记录注意力权重
        if self.record_attn:
            self.attention_prob = attention_prob
            # 如果使用的是特定的注意力函数,只保留音乐查询和歌词键/值对应的注意力权重
            if self.attn_func == "prime_attn":
                self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len]
        
        # 对注意力权重应用 dropout
        attention_prob = self.attn_dropout(attention_prob)
        
        # 计算上下文状态,通过注意力权重加权求和值状态
        context_states = torch.matmul(attention_prob, value_states)
        return context_states
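
    # Worked note on the scaling above (illustrative): `self.scale` is head_dim ** -0.25 and is
    # applied to both the query and the key (or multiplied in once as scale * scale at inference),
    # so the effective logit scaling is head_dim ** -0.5, the usual 1/sqrt(d_head) of scaled
    # dot-product attention, split in two to keep intermediate values in a safer numerical range.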

    # 合并多头注意力机制的结果
    def merge_heads(self, hidden_states):
        # 对隐藏状态进行维度置换,以便后续合并多头注意力的结果
        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1))
        # 将维度变换后的隐藏状态返回,与 TensorFlow 实现中的 merge_states 函数相对应
        return hidden_states.view(*new_hidden_states_shape)

    # 将隐藏状态拆分为多头注意力机制所需的形状
    def split_heads(self, hidden_states, is_key=False):
        # 计算新的隐藏状态形状,以便进行多头注意力机制的拆分
        new_hidden_states_shape = (
            *hidden_states.size()[:-1],
            self.n_heads,
            hidden_states.size(-1) // self.n_heads,
        )
        # 根据新形状对隐藏状态进行视图变换,与 TensorFlow 实现中的 split_states 函数对应
        hidden_states = hidden_states.view(*new_hidden_states_shape)
        
        # 如果是键,进一步置换维度以满足多头注意力机制的要求
        if is_key:
            return hidden_states.permute(0, 2, 3, 1)
        else:
            return hidden_states.permute(0, 2, 1, 3)

    # 密集注意力机制的实现,接受查询、键、值和采样参数
    def dense_attn(self, query, key, value, sample):
        # 对查询、键和值分别进行多头拆分
        query = self.split_heads(query)
        key = self.split_heads(key, is_key=True)
        value = self.split_heads(value)
        # 应用注意力机制计算上下文状态
        context_states = self._attn(query, key, value, sample)
        # 合并多头注意力机制的结果
        context_states = self.merge_heads(context_states)
        return context_states
    # 定义一个方法用于处理分块注意力机制,接受查询(query)、键(key)、值(value)和一个是否抽样的标志(sample)
    def block_attn(self, query, key, value, sample):
        # 将当前对象的块上下文(block_ctx)存储到局部变量block_ctx中
        block_ctx = self.block_ctx
        # 获取值(value)的形状,其中包括批量大小(batch_size)、序列长度(seq_len)和嵌入维度(embed_dim)
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        
        # 如果抽样标志为True,调用dense_attn方法处理注意力计算,并将结果调整为(batch_size, 1, embed_dim)的形状
        if sample:
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            # 否则,根据查询(query)的长度重新组织查询(query)张量
            query_length = query.shape[1]
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)
            
            # 如果查询长度小于序列长度(seq_len),更新序列长度为查询长度,同时截取键(key)和值(value)的最后一部分
            if query_length < seq_len:
                seq_len = query_length
                key = key[:, -seq_len:].contiguous()
                value = value[:, -seq_len:].contiguous()
            
            # 将键(key)和值(value)重新组织为适合分块上下文的形状
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            
            # 调用dense_attn方法计算分块注意力,并将结果调整为(batch_size, seq_len, embed_dim)的形状
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    # 定义一个方法用于处理转置的分块注意力机制,接受查询(query)、键(key)、值(value)和一个是否抽样的标志(sample)
    def transpose_block_attn(self, query, key, value, sample):
        # 将当前对象的块上下文(block_ctx)存储到局部变量block_ctx中
        block_ctx = self.block_ctx
        # 获取值(value)的形状,其中包括批量大小(batch_size)、序列长度(seq_len)和嵌入维度(embed_dim)
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        
        # 如果抽样标志为True,计算最后一个分块长度,截取键(key)和值(value)的特定分块,并调用dense_attn方法计算注意力
        if sample:
            block_len = (seq_len - 1) % block_ctx
            key = key[:, block_len::block_ctx, :]
            value = value[:, block_len::block_ctx, :]
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            # 否则,重新组织查询(query)、键(key)和值(value),以便进行分块转置操作
            query_length = query.shape[1]
            query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim)
            query = query.transpose(1, 2).contiguous()
            query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim)

            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            key = key.transpose(1, 2).contiguous()
            key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            value = value.transpose(1, 2).contiguous()
            value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            # 调用dense_attn方法计算分块注意力,并进行转置以匹配原始序列的形状
            block_attn = self.dense_attn(query, key, value, sample)
            block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim)
            block_attn = block_attn.transpose(1, 2).contiguous()
            block_attn = block_attn.view(batch_size, query_length, embed_dim)

            return block_attn
    # 定义一个方法,用于处理前一个块的注意力计算
    def prev_block_attn(self, query, key, value, sample):
        # 获取块的上下文大小
        block_ctx = self.block_ctx
        # 获取 value 的形状信息:batch_size(批大小)、seq_len(序列长度)、embed_dim(嵌入维度)
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        
        # 如果需要采样(sample=True),则处理前一个块的注意力
        if sample:
            # 计算当前块的数量
            block = (seq_len - 1) // block_ctx
            # 计算前一个块的长度
            prev_l = (block - 1) * block_ctx
            
            # 如果存在前一个块
            if block > 0:
                # 截取前一个块的 key 和 value
                key = key[:, prev_l : prev_l + block_ctx, :]
                value = value[:, prev_l : prev_l + block_ctx, :]
            else:
                # 如果不存在前一个块,则创建零张量填充
                key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
                value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
            
            # 调用 self.dense_attn 方法进行注意力计算,并将结果 reshape 成 (batch_size, 1, embed_dim) 的形式
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        
        # 如果不需要采样
        else:
            # 获取 query 的长度
            query_length = query.shape[1]
            # 将 query reshape 成适合块大小的形状
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)

            # 将 key 和 value 根据块大小进行 reshape
            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0))
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0))
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            # 如果 query 的长度小于 seq_len,则对 key 和 value 进行进一步处理以匹配 query 的长度
            if query_length < seq_len:
                nb_query_blocks = query_length // block_ctx
                nb_key_blocks = seq_len // block_ctx
                seq_len = query_length
                key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

                value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

            # 调用 self.dense_attn 方法进行注意力计算,并将结果 reshape 成 (batch_size, seq_len, embed_dim) 的形式
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
    # 计算自注意力摘要
    def summary_attn(self, query, key, value, sample):
        # 获取模型的块数和块上下文大小
        blocks = self.blocks
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # 获取值的形状,其中值的形状为(batch_size, seq_len, embed_dim),用于sample情况下,query_len= 1, key_len = value_len = sample_t
        if sample:
            # 对样本进行处理,目前未实现该分支的处理方式
            raise NotImplementedError
        else:
            # 对非样本进行处理
            # 调整key的形状以匹配块结构,并进行零填充以适应模型要求
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            key = torch.nn.functional.pad(key, (0, 0, 1, 0))  # 在最后一维上进行零填充,确保形状为(batch_size, blocks, embed_dim)

            # 调整value的形状以匹配块结构,并进行零填充以适应模型要求
            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            value = torch.nn.functional.pad(value, (0, 0, 1, 0))  # 在最后一维上进行零填充,确保形状为(batch_size, blocks, embed_dim)

            # 使用自定义的注意力函数dense_attn进行注意力计算,并重新调整输出的形状以匹配输入value的形状
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    # 计算分散注意力摘要
    def summary_spread_attn(self, query, key, value, sample):
        # 获取模型的块数和分散度大小
        blocks = self.blocks
        spread = self.spread

        batch_size, seq_len, embed_dim = value.shape  # 获取值的形状,其中值的形状为(batch_size, seq_len, embed_dim),用于sample情况下,query_len= 1, key_len = value_len = sample_t
        if sample:
            # 对样本进行处理,目前未实现该分支的处理方式
            raise NotImplementedError
        else:
            # 对非样本进行处理
            # 调整key的形状以匹配块结构并减少尾部的spread,然后进行零填充和连续化处理以适应模型要求
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous()  # 在维度1和2上进行零填充,确保形状为(batch_size, blocks * spread, embed_dim)

            # 调整value的形状以匹配块结构并减少尾部的spread,然后进行零填充和连续化处理以适应模型要求
            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous()  # 在维度1和2上进行零填充,确保形状为(batch_size, blocks * spread, embed_dim)

            # 使用自定义的注意力函数dense_attn进行注意力计算,并重新调整输出的形状以匹配输入value的形状
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    # Prime attention: the queries attend only to the encoder (lyric) tokens located at the start of the sequence
    def prime_attn(self, query, key, value, sample):
        # 获取编码器长度
        encoder_len = self._encoder_len

        # 调整key和value的形状以匹配编码器长度,并返回dense_attn函数计算的结果
        key = key[:, :encoder_len]
        value = value[:, :encoder_len]
        return self.dense_attn(query, key, value, sample)
    # 根据给定的隐藏状态张量计算查询、键、值
    def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        # 获取当前上下文大小
        curr_ctx = hidden_states.shape[1]
        # 如果存在上一个编码器的隐藏状态,则抛出类型错误
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")

        # 将隐藏状态张量按照最后一个维度分成查询、键、值三部分
        query, key, value = hidden_states.chunk(3, dim=2)
        
        # 如果需要进行采样
        if sample:
            # 增加采样计数器
            self.sample_t += curr_ctx
            # 将键和值追加到缓存中
            key, value = self._append_cache(key, value)
            # 计算当前缓存长度
            l_cache = self._suff_cache_len()
            # 如果整体缓存长度超过阈值,进行缓存切片
            if self._cache_len() > l_cache:
                self._slice_cache(-l_cache)
            # 如果当前上下文大于1
            if curr_ctx > 1:
                # 如果注意力函数不是 "dense_attn",对查询、键、值进行块填充
                if self.attn_func != "dense_attn":
                    query = self._pad_to_block_ctx(query, query=True)
                    key = self._pad_to_block_ctx(key)
                    value = self._pad_to_block_ctx(value)
                # 禁用采样标志
                sample = False
            else:
                # 如果当前上下文为1,则从缓存中获取键和值
                key = self.cache["key"]
                value = self.cache["value"]

        # 返回查询、键、值以及采样标志
        return query, key, value, sample

    # 根据给定的隐藏状态张量计算查询、键、值
    def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        # 获取当前上下文大小
        curr_ctx = hidden_states.shape[1]
        # 如果存在上一个编码器的隐藏状态,则抛出类型错误
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")
        
        # 将隐藏状态张量按照最后一个维度分成查询、键、值三部分
        query, key, value = hidden_states.chunk(3, dim=2)
        
        # 如果需要进行采样
        if sample:
            # 如果缓存长度小于编码器长度,则将键和值追加到缓存中
            if self._cache_len() < self._encoder_len:
                self._append_cache(key, value)
            # 如果缓存长度大于编码器长度,则对缓存进行切片操作
            if self._cache_len() > self._encoder_len:
                self._slice_cache(0, self._encoder_len)
            # 从缓存中获取键和值
            key, value = self.cache["key"], self.cache["value"]
            # 增加采样计数器
            self.sample_t += curr_ctx
        
        # 返回查询、键、值以及采样标志
        return query, key, value, sample

    # 根据给定的隐藏状态张量计算查询、键、值
    def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        # 获取当前上下文大小
        curr_ctx = hidden_states.shape[1]
        # 将隐藏状态作为查询
        query = hidden_states
        
        # 如果需要进行采样
        if sample:
            # 如果采样计数器为0,则从编码器的隐藏状态生成键和值,并存入缓存
            if self.sample_t == 0:
                self.cache["key"], self.cache["value"] = self.c_enc_kv(
                    last_encoder_hidden_states.type_as(hidden_states)
                ).chunk(2, dim=2)
            # 从缓存中获取键和值
            key, value = self.cache["key"], self.cache["value"]
            # 增加采样计数器
            self.sample_t += curr_ctx
        else:
            # 否则,根据给定的隐藏状态生成键和值
            key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2)
        
        # 返回查询、键、值以及采样标志
        return query, key, value, sample
    # 定义一个方法,用于进行前向传播计算
    def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        # 获取当前上下文的长度
        curr_ctx = hidden_states.shape[1]
        # 对输入的隐藏状态应用注意力机制
        hidden_states = self.c_attn(hidden_states)
        # 使用查询、键、值进行注意力机制的计算
        query, key, value, sample = self.qkv(
            hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
        )
        # 计算注意力分数
        attention_scores = self.attn(query, key, value, sample)
        # 如果注意力分数的长度与当前上下文长度不一致,则进行偏移操作
        if attention_scores.shape[1] != curr_ctx:
            offset = self._offset(curr_ctx)
            attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous()
        # 应用变换投影到输出空间
        attention_scores = self.c_proj(attention_scores)
        # 应用残差连接的dropout操作并返回结果
        return self.resid_dropout(attention_scores)

    # 定义一个属性方法,用于获取编码器的长度
    @property
    def _encoder_len(self):
        # 获取编码器长度属性
        encoder_len = self.encoder_len
        # 计算编码器块的数量
        encoder_blocks = (encoder_len // self.blocks) + 1
        # 返回调整后的编码器长度
        return encoder_blocks * self.blocks

    # 定义一个方法,用于计算偏移量
    def _offset(self, curr_ctx):
        # 如果使用密集注意力机制,则返回0
        if self.attn_func == "dense_attn":
            return 0
        # 否则,计算偏移量并返回
        return (self.sample_t - curr_ctx) % self.block_ctx

    # 定义一个方法,用于将隐藏状态填充到块上下文的长度
    def _pad_to_block_ctx(self, hidden_states, query=False):
        # 获取序列长度
        seq_len = hidden_states.shape[1]
        # 如果是查询,则计算偏移量
        offset = self._offset(seq_len) if query else 0
        # 计算块的数量
        n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx
        # 计算填充的长度
        pad = n_blocks * self.block_ctx - seq_len - offset
        # 如果无需填充,则直接返回隐藏状态
        if pad == 0 and offset == 0:
            return hidden_states
        else:
            # 否则,对隐藏状态进行填充并返回
            return F.pad(hidden_states, (0, 0, offset, pad))

    # 定义一个方法,用于获取缓存的长度
    def _cache_len(self):
        # 如果缓存中没有键值对,则返回0;否则返回键的长度
        return 0 if "key" not in self.cache else self.cache["key"].shape[1]

    # 定义一个方法,用于获取必要的缓存长度
    def _suff_cache_len(self):
        """
        前提条件:
            键和值已经附加了当前上下文,并且self.sample_t反映了上下文中的1索引样本位置。
        """
        # 计算前一个块的长度
        previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx
        # 定义必要的缓存长度字典
        REQUIRED_CACHE_LEN = {
            "dense_attn": self.sample_t,
            "block_attn": (self.sample_t - 1) % self.block_ctx + 1,
            "transpose_block_attn": self.sample_t,
            "prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length,
            "cross_attn": self.encoder_len,
            "prime_attn": min(self.sample_t, self._encoder_len),
        }
        # 返回根据注意力机制类型选择的必要缓存长度
        return REQUIRED_CACHE_LEN[self.attn_func]

    # 定义一个方法,用于对缓存进行切片
    def _slice_cache(self, start, end=None):
        # 对键和值缓存进行切片操作
        self.cache["key"] = self.cache["key"][:, start:end]
        self.cache["value"] = self.cache["value"][:, start:end]
    # 将键值对添加到缓存中,如果键不存在则创建新的缓存项,否则更新现有缓存项
    def _append_cache(self, key, value):
        # 检查缓存中是否已存在键
        if "key" not in self.cache:
            # 如果不存在,则将提供的键和值存入缓存
            self.cache["key"] = key
            self.cache["value"] = value
        else:
            # 如果存在,则合并现有键值和新的键值对,并更新缓存
            old_key, old_value = key, value
            key = torch.cat([self.cache["key"], old_key], dim=1)
            value = torch.cat([self.cache["value"], old_value], dim=1)
            # 删除旧的键和值以释放内存
            del self.cache["key"]
            del self.cache["value"]
            del old_key
            del old_value
            # 更新缓存的键和值
            self.cache["key"] = key
            self.cache["value"] = value
        # 返回更新后的键和值
        return self.cache["key"], self.cache["value"]
    
    # 清空缓存中的所有项,并重置样本计数器
    def del_cache(self):
        self.sample_t = 0  # 重置样本计数器为0
        if "key" in self.cache:
            del self.cache["key"]  # 删除缓存中的键
        if "value" in self.cache:
            del self.cache["value"]  # 删除缓存中的值
        self.cache = {}  # 清空整个缓存字典
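

# Illustrative sketch (added for this walkthrough, not part of the original file): the block-sparse
# attention variants above factor a sequence of length n_ctx into (blocks, block_ctx) pieces with
# block_ctx = n_ctx // blocks; `block_attn`, for instance, reshapes the tensors so that dense
# attention runs independently inside each block. Sizes below are hypothetical.
def _demo_block_factoring(batch_size=1, n_ctx=12, blocks=3, embed_dim=2):
    block_ctx = n_ctx // blocks
    hidden_states = torch.randn(batch_size, n_ctx, embed_dim)
    blocked = hidden_states.view(batch_size * n_ctx // block_ctx, block_ctx, embed_dim)
    return blocked.shape  # (batch_size * blocks, block_ctx, embed_dim) == (3, 4, 2)
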
class JukeboxBlock(nn.Module):
    # JukeboxBlock 类,用于实现一个模块
    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        super().__init__()
        # 设置模块的宽度为隐藏层大小
        self.width = config.hidden_size
        # 创建 JukeboxAttention 对象,并存储在 self.attn 中
        self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func)

        # 创建第一个 Layer Normalization 层,并存储在 self.layer_norm_0 中
        self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size)
        # 创建 JukeboxMLP 对象,并存储在 self.mlp 中
        self.mlp = JukeboxMLP(config)
        # 创建第二个 Layer Normalization 层,并存储在 self.layer_norm_1 中
        self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size)
        # 设置残差比例,如果启用注意力残差缩放,为 1/层数,否则为 1.0
        self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0
        # 存储注意力函数名称
        self.attn_func = attn_func

    def forward(self, hidden_states, last_encoder_hidden_states, sample=False):
        # 复制输入的隐藏状态作为残差
        residuals = hidden_states
        # 应用第一个 Layer Normalization 层
        hidden_states = self.layer_norm_0(hidden_states)
        # 应用注意力机制,并更新隐藏状态
        hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample)

        # 计算输出状态,结合残差和更新后的隐藏状态
        output_states = self.layer_norm_1(residuals + hidden_states)
        # 应用 MLP 层
        output_states = self.mlp(output_states)
        # 计算最终输出,结合残差、更新后的隐藏状态和 MLP 输出
        if self.res_scale == 1.0:
            output = residuals + hidden_states + output_states
        else:
            output = residuals + self.res_scale * (hidden_states + output_states)
        return output


class JukeboxLayerStack(nn.Module):
    # JukeboxLayerStack 类,用于堆叠多个 JukeboxBlock 模块
    def __init__(self, config, n_ctx):
        super().__init__()
        # 初始化上下文长度和宽度为隐藏层大小
        self.n_ctx = n_ctx
        self.width = config.hidden_size
        # 设置层数和块数
        self.num_layers = config.num_layers
        self.blocks = config.blocks
        # 设置注意力模式
        self.attention_pattern = config.attention_pattern
        # 如果定义了块数,则计算每个块的上下文长度
        if self.blocks is not None:
            self.block_ctx = n_ctx // self.blocks
        # 设置编码器长度
        self.encoder_len = config.nb_relevant_lyric_tokens
        # 设置头数
        self.n_heads = config.n_heads

        # 根据注意力模式创建注意力模块列表
        attention_pattern = ATTENTION_PATTERNS[self.attention_pattern]
        self._attn_mods = nn.ModuleList()
        for depth in range(self.num_layers):
            # 向 _attn_mods 列表添加 JukeboxBlock 模块
            self._attn_mods.append(JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth)))

        # 用于存储注意力权重
        self.saved_attn_weights = []

    def set_record_attn(self, record_attn):
        """
        设置是否记录注意力 softmax 到 self.saved_attn_weights 中。

        Args:
            record_attn (`Union[bool,set]`):
                若为 set 类型,表示要记录哪些层的注意力 softmax;若为 bool 类型,表示是否全部记录。
        """
        # 判断是否记录每一层的注意力 softmax
        def _should_record_attn(layer_idx):
            if isinstance(record_attn, bool):
                return record_attn
            return layer_idx in record_attn

        # 设置每个层的注意力记录属性
        for i, layer in enumerate(self._attn_mods):
            layer.attn.record_attn = _should_record_attn(i)

        # 若不记录任何注意力 softmax,则清空 self.saved_attn_weights
        if not record_attn:
            self.saved_attn_weights = []
    # 前向传播函数,用于处理隐藏状态和可能的编码器最后隐藏状态,支持采样
    def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        # 遍历注意力层模块列表
        for i, attn_layer in enumerate(self._attn_mods):
            # 如果当前注意力层为跨注意力机制,即跨编码器-解码器注意力
            if attn_layer.attn_func == "cross_attention":  # attend to the lyrics
                # 执行跨注意力机制,将当前隐藏状态和最后编码器隐藏状态作为参数传入
                hidden_states = attn_layer(
                    hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
                )
            else:
                # 否则,执行普通的注意力机制,不使用编码器的隐藏状态
                hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample)
            # 如果当前注意力层记录了注意力权重
            if attn_layer.attn.record_attn:
                # Store this layer's recorded weights in the list (the code keeps `c_attn.weight`, the QKV projection matrix)
                self.saved_attn_weights.append(attn_layer.attn.c_attn.weight)
        # 返回处理后的隐藏状态
        return hidden_states

    # 删除缓存函数,用于清空所有注意力层的缓存
    def del_cache(self):
        # 遍历所有注意力层模块
        for attn_layer in self._attn_mods:
            # 调用注意力层对象的删除缓存方法
            attn_layer.attn.del_cache()
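

# Illustrative usage sketch (added for this walkthrough, not part of the original file; the config
# object and tensor shapes are hypothetical):
#
#   layer_stack = JukeboxLayerStack(prior_config, n_ctx=256)
#   layer_stack.set_record_attn({0, 2})     # only layers 0 and 2 record
#   _ = layer_stack(hidden_states)          # the forward pass appends to saved_attn_weights
#   recorded = layer_stack.saved_attn_weights
#   layer_stack.set_record_attn(False)      # stop recording and clear the buffer
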
class JukeboxPositionalEmbedding(nn.Module):
    # JukeboxPositionalEmbedding 类定义,继承自 nn.Module
    def __init__(self, embed_dim, width):
        # 初始化方法
        super().__init__()
        # 创建一个可学习的参数 pos_emb,其形状为 (embed_dim, width)
        self.pos_emb = nn.Parameter(torch.empty((embed_dim, width)))

    def forward(self):
        # 前向传播方法
        pos_emb = self.pos_emb
        # 返回位置嵌入参数 pos_emb
        return pos_emb


class JukeboxConditionalAutoregressive(nn.Module):
    # JukeboxConditionalAutoregressive 类定义,继承自 nn.Module
    def __init__(
        self,
        config,
        n_ctx=None,
        embed_dim=None,
        audio_conditioning=False,
        metadata_conditioning=False,
        is_encoder=False,
    ):
        # Initialization method; see the docstring below for the arguments. The actual body,
        # including the `super().__init__()` call, follows the docstring.
        """
        Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly
        set for each configuration.

        Args:
            config (`JukeboxPriorConfig`):
                Model configuration class with all the parameters of the model. Initializing with a config file does
                not load the weights associated with the model, only the configuration. Check out the
                [`~PreTrainedModel.from_pretrained`] method to load the model weights.
            n_ctx (`int`, *optional*):
                Number of tokens or lyrics tokens provided in a single pass.
            embed_dim (`int`, *optional*):
                Either equals the dimension of the codebook, or the sum of n_vocab (lyrics) and the codebook
                dimension if the model combines lyrics and music tokens, or simply n_vocab if the model is a
                separate encoder
            audio_conditioning (`bool`, *optional*, defaults to `False`):
                Whether or not the prior supports conditioning on audio.
            metadata_conditioning (`bool`, *optional*, defaults to `False`):
                Whether or not the prior supports conditioning on artist, genres, lyrics, and timing.
            is_encoder (`bool`, *optional*, defaults to `False`):
                Whether the model is an encoder only model.
        """

        # Initialize the class inheriting from nn.Module
        super().__init__()

        # Set the width of the model from the configuration
        self.width = config.hidden_size
        # Set the number of layers from the configuration
        self.num_layers = config.num_layers
        # Set the context length from the provided argument or the configuration
        self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx
        # Set the embedding dimension based on the argument or configuration's music vocabulary size
        self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size

        # Initialize embedding tokens using nn.Embedding with embed_dim and hidden_size from configuration
        self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size)
        # Apply dropout to embed_tokens based on config's embedding dropout rate
        self.embed_tokens_dropout = nn.Dropout(config.emb_dropout)

        # Set metadata and audio conditioning flags
        self.metadata_conditioning = metadata_conditioning
        self.audio_conditioning = audio_conditioning

        # If metadata_conditioning is False, initialize start_token as a learnable parameter
        if not metadata_conditioning:
            self.start_token = nn.Parameter(torch.empty((1, config.hidden_size)))

        # Initialize positional embedding using JukeboxPositionalEmbedding with n_ctx and hidden_size
        self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size)
        # Apply dropout to positional embedding based on config's embedding dropout rate
        self.pos_emb_dropout = nn.Dropout(config.emb_dropout)

        # Initialize transformer layer stack using JukeboxLayerStack with config and n_ctx
        self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx)
        # Set whether the model is an encoder based on is_encoder flag
        self.is_encoder = is_encoder
        # Set encoder length from configuration's relevant lyric tokens count
        self.encoder_len = config.nb_relevant_lyric_tokens

        # Conditional setups based on config's merged_decoder flag
        if config.merged_decoder:
            self.add_cond_after_transformer = False
            self.share_embed_tokens_fc_proj_out = False
        else:
            self.add_cond_after_transformer = True
            self.share_embed_tokens_fc_proj_out = True

        # If not an encoder, initialize output projection layer and loss function
        if not is_encoder:
            # Linear projection layer from hidden_size to embed_dim
            self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False)
            # If sharing embed tokens and fc_proj_out weights, synchronize them
            if self.share_embed_tokens_fc_proj_out:
                self.fc_proj_out.weight = self.embed_tokens.weight
            # Cross-entropy loss function initialization
            self.loss = torch.nn.CrossEntropyLoss()
    def forward(
        self,
        tokens,
        audio_conditioning=None,
        metadata_conditioning=None,
        last_encoder_hidden_states=None,
        get_preds=False,
        get_acts=False,
        get_sep_loss=False,
    ):
        """
        Args:
            tokens (`torch.tensor`):
                Can represent music tokens, lyrics tokens or both, depending on the configuration.
        """
        # Preprocess.
        batch_size = tokens.shape[0]  # 获取批处理大小
        with torch.no_grad():
            tokens = tokens.view(batch_size, -1).long()  # 转换 tokens 的形状

        if not self.audio_conditioning:
            # 如果没有音频条件,则创建全零的音频条件张量
            audio_conditioning = torch.zeros(
                (batch_size, 1, self.width),
                device=tokens.device,
                dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype,
            )

        target = tokens  # 目标 tokens
        hidden_states = self.embed_tokens(tokens)  # 嵌入 tokens
        # Shift by 1, and fill in start token
        hidden_states = torch.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1)  # 将 tokens 向右移动一个位置,并填充起始 token
        if self.metadata_conditioning:
            hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width)  # 如果有元数据条件,则使用元数据条件
        else:
            hidden_states[:, 0] = self.start_token  # 否则使用预定义的起始 token

        hidden_states = (
            self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning
        )  # 添加嵌入 tokens 的 dropout、位置编码的 dropout 和音频条件

        hidden_states = self.transformer(
            hidden_states, last_encoder_hidden_states=last_encoder_hidden_states
        )  # 应用 transformer 模型进行编码

        if self.add_cond_after_transformer:  # 如果在 transformer 后添加条件
            hidden_states = hidden_states + audio_conditioning  # 添加音频条件

        activations = hidden_states  # 激活值等于隐藏状态
        if self.is_encoder:
            return hidden_states  # 如果是编码器,直接返回隐藏状态

        hidden_states = self.fc_proj_out(hidden_states)  # 使用全连接层进行预测
        loss_fn = nn.CrossEntropyLoss()  # 使用交叉熵损失函数

        if get_sep_loss:
            # 如果需要单独计算损失
            lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim)
            token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim)

            lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0)  # 计算歌词部分的损失
            music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0)  # 计算音乐 token 部分的损失

            loss = (lyric_loss, music_token_loss)  # 返回歌词损失和音乐 token 损失
        else:
            loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0)  # 计算整体损失

        if get_preds:
            return loss, hidden_states  # 如果需要预测,返回损失和隐藏状态
        elif get_acts:
            return loss, activations  # 如果需要激活值,返回损失和激活值
        else:
            return loss, None  # 否则只返回损失
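
Two details of this forward pass are easy to miss: the embeddings are rolled one position to the right before slot 0 is overwritten with the start/metadata token, and the cross-entropy is divided by `np.log(2.0)` so the reported loss is in bits per token. A toy standalone illustration (shapes invented for the example):

```
import numpy as np
import torch

# The right shift used before the transformer: embeddings [e0, e1, e2, e3] become
# [e3, e0, e1, e2]; slot 0 is then overwritten with the start token or metadata embedding.
emb = torch.arange(4).view(1, 4, 1).float()            # pretend embeddings, width=1
shifted = torch.cat((emb[:, -1:], emb[:, :-1]), dim=1)
print(shifted.view(-1).tolist())                        # [3.0, 0.0, 1.0, 2.0]

# Dividing a nats-based cross entropy by log(2) converts it to bits per token.
nats = torch.tensor(1.3863)                             # ~= ln(4), uniform over 4 classes
print(float(nats / np.log(2.0)))                        # ~= 2.0 bits
```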
    # 定义一个方法,用于获取嵌入表示
    def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning):
        # 如果是第一个样本
        if sample_t == 0:
            # 创建一个空的张量用于存储隐藏状态,形状为 (n_samples, 1, self.width),数据类型与权重张量相同,并移到相同的设备上
            hidden_states = torch.empty(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to(
                self.embed_tokens.weight.device
            )
            # 如果有元数据条件
            if self.metadata_conditioning:
                # 将元数据条件视图重塑为 (n_samples, self.width),并赋值给隐藏状态的第一个位置
                hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width)
            else:
                # 否则将起始标记赋值给隐藏状态的第一个位置
                hidden_states[:, 0] = self.start_token
        else:
            # 对于非第一个样本,使用嵌入的 token 表示 tokens
            hidden_states = self.embed_tokens(tokens)
        
        # 如果音频条件的形状与期望的形状相同
        if audio_conditioning.shape == (n_samples, self.n_ctx, self.width):
            # 则将对应的音频条件切片赋给 cond,形状为 (n_samples, 1, self.width)
            cond = audio_conditioning[:, sample_t : sample_t + 1, :]
        else:
            # 否则直接使用原始的音频条件
            cond = audio_conditioning
        
        # 添加位置嵌入和音频条件到隐藏状态中,位置嵌入在评估时的 dropout 是恒等映射
        hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond
        
        # 返回更新后的隐藏状态和条件
        return hidden_states, cond

    def sample(
        self,
        n_samples,
        audio_conditioning=None,
        metadata_conditioning=None,
        last_encoder_hidden_states=None,
        temp=1.0,
        top_k=0,
        top_p=0.0,
        get_preds=False,
        sample_tokens=None,
    ):
        # 如果未指定采样的 tokens 数量,则使用默认值 self.n_ctx
        if sample_tokens is None:
            sample_tokens = self.n_ctx

        # 如果不需要音频调节,则创建一个全零张量作为音频调节
        if not self.audio_conditioning:
            audio_conditioning = torch.zeros(
                (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype
            ).to(self.fc_proj_out.device)

        # 禁止梯度更新
        with torch.no_grad():
            sampled_tokens = []
            tokens = None
            if get_preds:
                preds = []

            # 使用 tqdm 创建进度条迭代器
            iter = tqdm(range(0, sample_tokens), leave=False)
            for sample_t in iter:
                iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True)
                # 获取嵌入向量和条件
                hidden_states, cond = self.get_emb(
                    sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning
                )

                # 使用 transformer 进行前向传播
                hidden_states = self.transformer(
                    hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
                )
                # 如果设置了在 transformer 后添加条件
                if self.add_cond_after_transformer:
                    hidden_states = hidden_states + cond
                # 使用全连接层进行预测
                hidden_states = self.fc_proj_out(hidden_states)  # Predictions
                # 如果需要获取预测值,则保存预测结果
                if get_preds:
                    preds.append(hidden_states.clone())
                # 调整 logits 的值
                hidden_states = hidden_states / temp
                hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
                # 从 logits 中采样生成 tokens
                tokens = torch.distributions.Categorical(logits=hidden_states).sample()
                sampled_tokens.append(tokens.clone())

            del tokens
            # 清除 transformer 的缓存
            self.transformer.del_cache()

            # 拼接所有采样的 tokens
            tokens = torch.cat(sampled_tokens, dim=1)
            if get_preds:
                preds = torch.cat(preds, dim=1)
        # 如果需要获取预测值,则返回 tokens 和 preds
        if get_preds:
            return tokens, preds
        # 否则,只返回 tokens
        else:
            return tokens
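
`filter_logits` is defined elsewhere in modeling_jukebox.py; as a rough stand-in (an assumption for illustration, not the library's implementation), a top-k filter simply pushes everything outside the k largest logits to `-inf` before `Categorical` sampling:

```
import torch

def top_k_filter(logits, top_k):
    # Keep the top_k largest logits per row and push the rest to -inf so that
    # torch.distributions.Categorical can never sample them (top_k=0 disables filtering).
    if top_k == 0:
        return logits
    kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return torch.where(logits < kth_value, torch.full_like(logits, float("-inf")), logits)

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
filtered = top_k_filter(logits, top_k=2)
print(filtered)  # tensor([[2., -inf, 1., -inf]])
tokens = torch.distributions.Categorical(logits=filtered).sample()
```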

    def split_chunks(self, length, chunk_size):
        # 计算分块的数量
        n_passes = (length + chunk_size - 1) // chunk_size
        # 计算每个分块的大小列表
        chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]
        return chunk_sizes
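
A quick worked example of the chunk arithmetic (standalone copy of `split_chunks` for illustration): the final chunk absorbs the remainder.

```
# Standalone copy of split_chunks: the last chunk gets whatever remains.
def split_chunks(length, chunk_size):
    n_passes = (length + chunk_size - 1) // chunk_size
    return [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]

print(split_chunks(10, 4))  # [4, 4, 2]
print(split_chunks(8, 4))   # [4, 4]
```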

    def primed_sample(
        self,
        n_samples,
        lyric_and_music_tokens,
        audio_conditioning=None,
        metadata_conditioning=None,
        last_encoder_hidden_states=None,
        temp=1.0,
        top_k=0,
        top_p=0.0,
        get_preds=False,
        chunk_size=None,
        sample_tokens=None,
    ):
        # The primed-sampling body (running the transformer over the provided prefix in
        # chunks, then ancestral sampling of the remaining tokens) is not shown here.

class JukeboxMusicTokenConditioner(nn.Module):
    """
    The `JukeboxMusicTokenConditioner` takes music tokens as an input (corresponding to the codes of the VQVAE's
    codebook) and upsamples it using a single layer of decoder convolution block (the same is used in the VQVAE).
    """

    def __init__(self, config, level):
        super().__init__()
        # Initialize an embedding layer for music tokens based on vocabulary size and hidden size
        self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size)
        # Set the embed_dim attribute in config to music_vocab_size for compatibility with JukeboxDecoder
        config.embed_dim = config.music_vocab_size  # setting correct argument for the `JukeboxDecoder`

        # Initialize the upsampler using a custom convolutional block
        self.upsampler = JukeboxDecoderConvBock(
            config,
            config.hidden_size,
            config.res_conv_width,
            config.res_conv_depth,
            config.res_downs_t[level],
            config.res_strides_t[level],
            reverse_dilation=False,
        )
        # Initialize layer normalization for the hidden states
        self.layer_norm = JukeboxLayerNorm(config.hidden_size)

    def forward(self, music_tokens, raw_audio_conditionning=None):
        """
        Args:
            music_tokens (`torch.LongTensor`):
                Music tokens from the upper level in range(nb_discrete_codes)
            raw_audio_conditionning (`torch.LongTensor`, *optional*):
                Audio used when primed sampling, raw audio information that conditions the generation
        """
        # Set default value for raw_audio_conditioning if not provided
        if raw_audio_conditionning is None:
            raw_audio_conditionning = 0.0
        # Convert music_tokens to long type
        music_tokens = music_tokens.long()
        # Embed music_tokens using the previously initialized embedding layer
        hidden_states = self.embed_tokens(music_tokens)
        # Add raw_audio_conditioning to the embedded music tokens
        hidden_states = hidden_states + raw_audio_conditionning

        # Permute dimensions for upsampling
        hidden_states = hidden_states.permute(0, 2, 1)
        # Apply the upsampler to the permuted hidden states
        hidden_states = self.upsampler(hidden_states)
        # Permute dimensions back to original shape
        hidden_states = hidden_states.permute(0, 2, 1)
        # Apply layer normalization to the processed hidden states
        hidden_states = self.layer_norm(hidden_states)
        # Return the normalized hidden states
        return hidden_states


class JukeboxRangeEmbedding(nn.Module):
    """
    The `JukeboxRangeEmbedding` interpolates the given [pos_start, pos_end] to obtain an equivalent of a time
    positional embedding of length `n_ctx`.

    Binning process: for each pos in the position tensor, find its bin in [start, end) and map it to
    [0, 1, ..., bins-1]: [start, end) -> [0, 1) -> [0, bins) -> floor -> [0, ..., bins-1]. NOTE: the interval is
    open-ended on the right, so start <= pos < end, not pos <= end.
    """

    def __init__(self, n_time, embed_dim, range, out_width, clamp=False):
        super().__init__()
        # Initialize an embedding layer with size embed_dim and output width out_width
        self.emb = nn.Embedding(embed_dim, out_width)
        self.n_time = n_time
        self.embed_dim = embed_dim
        # Define positional range [pos_min, pos_max]
        self.pos_min, self.pos_max = range
        self.clamp = clamp
    # 定义一个方法用于将位置起始点和结束点进行前向传播
    def forward(self, pos_start, pos_end=None):
        # 检查 pos_start 的形状是否为二维
        if not len(pos_start.shape) == 2:
            raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}")
        # 检查 pos_start 是否在指定范围 [pos_min, pos_max) 内
        if not ((self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all()):
            raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}")

        # 将 pos_start 转换为 float 类型
        pos_start = pos_start.float()
        # 如果 pos_end 不为 None
        if pos_end is not None:
            # 如果设置了 clamp 标志,将 pos_end 限制在 pos_min 和 pos_max 范围内
            if self.clamp:
                pos_end = pos_end.clamp(self.pos_min, self.pos_max)

            # 将 pos_end 转换为 float 类型
            pos_end = pos_end.float()

        # 计算插值以使得 [pos_start, ..., pos_end] <-> 长度为 n_ctx 的位置张量
        n_time = self.n_time
        if n_time != 1:
            # 生成插值张量,用于在 pos_start 到 pos_end 之间进行线性插值
            interpolation = (
                torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time
            )
            position = pos_start + (pos_end - pos_start) * interpolation
        else:
            position = pos_start

        # 将位置归一化到 [0, 1] 范围内
        normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min)
        # 将归一化后的位置映射到 bins_,用于离散化表示
        bins_ = (self.embed_dim * normalised_position).floor().long().detach()
        # 返回根据 bins_ 索引得到的嵌入向量
        return self.emb(bins_)
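
The binning step of `forward` in isolation, with hypothetical values (range `[0, 100)`, 10 bins): positions are normalised to `[0, 1)` and floored into bin ids, mirroring the last lines above.

```
import torch

pos_min, pos_max, n_bins = 0.0, 100.0, 10            # hypothetical range and bin count
position = torch.tensor([[0.0, 25.0, 99.9]])
normalised_position = (position - pos_min) / (pos_max - pos_min)  # -> [0, 1)
bins_ = (n_bins * normalised_position).floor().long()             # -> {0, ..., 9}
print(bins_)  # tensor([[0, 2, 9]])
```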
class JukeboxLabelConditioner(nn.Module):
    def __init__(self, config, include_time_signal):
        super().__init__()

        embed_dim = config.hidden_size  # 从配置中获取隐藏单元的维度
        timing_dims = config.timing_dims  # 从配置中获取时间维度
        sampling_rate = config.sampling_rate  # 从配置中获取采样率
        nb_genres, nb_artists = config.metadata_dims  # 从配置中获取流派和艺术家的维度
        music_tokens_shape = config.n_ctx  # 从配置中获取音乐令牌的形状

        self.max_nb_genres = config.max_nb_genres  # 设置最大流派数量
        self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim)  # 创建流派嵌入层
        self.artist_emb = nn.Embedding(nb_artists, embed_dim)  # 创建艺术家嵌入层
        self.include_time_signal = include_time_signal  # 设置是否包含时间信号的标志
        if self.include_time_signal:
            # 如果包含时间信号,设置总长度范围、绝对位置范围和相对位置范围
            total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate)
            absolute_pos_range = (0.0, config.max_duration * sampling_rate)
            relative_pos_range = (0.0, 1.0)
            # 创建总长度、绝对位置和相对位置的嵌入层
            self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim)
            self.absolute_pos_emb = JukeboxRangeEmbedding(
                music_tokens_shape, timing_dims, absolute_pos_range, embed_dim
            )
            self.relative_pos_emb = JukeboxRangeEmbedding(
                music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True
            )

    def forward(self, metadata):
        total_length = metadata[:, 0:1]  # 提取元数据中的总长度
        offset = metadata[:, 1:2]  # 提取元数据中的偏移量
        length = metadata[:, 2:3]  # 提取元数据中的长度
        artist = metadata[:, 3:4]  # 提取元数据中的艺术家
        genre = metadata[:, 4:]  # 提取元数据中的流派

        # 起始嵌入,长度为1
        artist_emb = self.artist_emb(artist)  # 计算艺术家的嵌入表示
        # 空的流派插槽用-1表示,对其进行屏蔽处理
        mask = (genre >= 0).float().unsqueeze(2)  # 创建流派屏蔽掩码
        genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)  # 计算流派的嵌入表示
        start_emb = genre_emb + artist_emb  # 合并艺术家和流派的嵌入表示作为起始嵌入

        # 位置嵌入,长度为n_ctx
        if self.include_time_signal:
            start, end = offset, offset + length  # 计算起始和结束位置
            total_length = total_length.float()  # 将总长度转换为浮点数
            start = start.float()  # 将起始位置转换为浮点数
            end = end.float()  # 将结束位置转换为浮点数
            # 计算总长度、绝对位置和相对位置的嵌入表示
            pos_emb = (
                self.total_length_emb(total_length)
                + self.absolute_pos_emb(start, end)
                + self.relative_pos_emb(start / total_length, end / total_length)
            )
        else:
            pos_emb = None  # 如果不包含时间信号,则位置嵌入为None
        return start_emb, pos_emb  # 返回起始嵌入和位置嵌入
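
The slicing in `forward` assumes a fixed column layout for `metadata`; the row below uses invented values purely to show how the columns split into total length, offset, length, artist id and `-1`-padded genre slots:

```
import torch

# Hypothetical metadata row: [total_length, offset, length, artist_id, genre ids...],
# with unused genre slots padded by -1 (these are masked out above).
metadata = torch.tensor([[882000, 0, 441000, 17, 3, 42, -1, -1, -1]])
total_length, offset, length = metadata[:, 0:1], metadata[:, 1:2], metadata[:, 2:3]
artist, genre = metadata[:, 3:4], metadata[:, 4:]
mask = (genre >= 0).float().unsqueeze(2)
print(genre.tolist(), mask.squeeze(2).tolist())  # [[3, 42, -1, -1, -1]] [[1.0, 1.0, 0.0, 0.0, 0.0]]
```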
    # 定义一个类变量,指定配置类为 JukeboxPriorConfig
    config_class = JukeboxPriorConfig

    # 初始化模型权重的方法,接受一个模块作为参数
    def _init_weights(self, module):
        # 从配置中获取初始化比例
        init_scale = self.config.init_scale

        # 如果模块是 nn.Embedding 类型
        if isinstance(module, nn.Embedding):
            # 对权重数据进行正态分布初始化,均值为 0,标准差为 0.02 * init_scale
            module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
        
        # 如果模块是 JukeboxConv1D 类型
        elif isinstance(module, JukeboxConv1D):
            # 如果配置中指定需要将权重置零
            if self.config.zero_out:
                # 将权重数据置零
                module.weight.data.zero_()
            else:
                # 否则对权重数据进行正态分布初始化,均值为 0,标准差为 0.02 * init_scale
                module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
        
        # 如果模块是 JukeboxPositionalEmbedding 类型
        elif isinstance(module, JukeboxPositionalEmbedding):
            # 对位置嵌入数据进行正态分布初始化,均值为 0,标准差为 0.01 * init_scale
            module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale)
        
        # 如果模块是 JukeboxRangeEmbedding 类型
        elif isinstance(module, JukeboxRangeEmbedding):
            # 对范围嵌入的权重数据进行正态分布初始化,均值为 0,标准差为 0.01 * init_scale
            module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale)
        
        # 如果模块是 JukeboxConditionalAutoregressive 类型,并且具有 lm_head 属性
        elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"):
            # 对 lm_head 的权重数据进行正态分布初始化,均值为 0,标准差为 0.02 * init_scale
            module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
        
        # 如果模块是 JukeboxConditionalAutoregressive 类型,并且具有 start_token 属性
        elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"):
            # 对 start_token 的数据进行正态分布初始化,均值为 0,标准差为 0.01 * init_scale
            module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale)
        
        # 如果模块是 JukeboxResConv1DBlock 类型,并且配置中指定需要置零
        elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
            # 将 conv1d_2 的权重和偏置数据置零
            module.conv1d_2.weight.data.zero_()
            module.conv1d_2.bias.data.zero_()
        
        # 如果模块是 nn.LayerNorm 类型
        if isinstance(module, nn.LayerNorm):
            # 将偏置数据置零
            module.bias.data.zero_()
            # 将权重数据填充为 1.0
            module.weight.data.fill_(1.0)
        
        # 如果模块是 nn.Linear 类型,并且具有偏置
        if isinstance(module, nn.Linear) and module.bias is not None:
            # 将偏置数据置零
            module.bias.data.zero_()
        def get_metadata(self, labels, start, total_length, offset, get_indices=False):
            # 克隆 labels 张量以避免修改原始数据,创建 metadata 张量
            metadata = labels.clone()
            # 设置 metadata 的第一列为总长度
            metadata[:, 0] = total_length
            # 设置 sample_length 列以匹配当前层级的样本长度
            metadata[:, 2] = int(self.sample_length)
    
            # 设置偏移量,计算偏移量在 token 的索引
            metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens)
            # 由于 metadata 包含完整的 token_list,只需选择相关的部分
    
            # 设置歌词 token,调用 set_metadata_lyric_tokens 方法
            metadata, indices = self.set_metadata_lyric_tokens(metadata)
            # 根据 get_indices 参数返回 metadata 和 indices
            if get_indices:
                return metadata, indices
            else:
                return metadata
    
        def set_metadata_lyric_tokens(self, labels):
            """
            处理完整的标签,只提取相关的歌词 token,并保持元数据的条件 token。
            """
            # 如果有相关的歌词 token
            if self.nb_relevant_lyric_tokens > 0:
                # 初始化 tokens_list 张量,尺寸为 (labels 行数, nb_relevant_lyric_tokens),数据类型为 long,设备为 labels 的设备
                tokens_list = torch.zeros(
                    (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device
                )
                indices_list = []  # 存储原始数组中每个字符的索引
                # 遍历每一行标签数据
                for idx in range(labels.shape[0]):
                    # 克隆 labels 的所有行,但不包括前四列和元数据嵌入的最大生成数量
                    full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :]
                    total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2]
                    # 获取相关的歌词 token 和其索引
                    tokens, indices = get_relevant_lyric_tokens(
                        full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration
                    )
                    # 将获取的 tokens 存入 tokens_list
                    tokens_list[idx, :] = tokens
                    indices_list.append(indices)
    
                # 返回更新后的 labels 和索引列表,合并原标签的前几列与 tokens_list
                return (
                    torch.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1),
                    indices_list,
                )
            else:
                # 如果没有相关的歌词 token,直接返回原 labels 和 None
                return labels, None
    
        def get_music_tokens_conds(self, music_tokens, start, end):
            """
            提取当前层级的条件音乐 token。
            """
            # 如果不是第一层级
            if self.level != 0:
                # 获取上一层级的音乐 token 条件
                music_tokens_cond = music_tokens[self.level - 1]
                # 根据 start 和 end 索引提取音乐 token
                music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample]
                # 计算缺失的条件长度
                missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1]
                # 如果有缺失的条件长度,填充零
                if missing_cond_len > 0:
                    init_cond = torch.zeros(1, missing_cond_len).to(music_tokens_cond.device)
                    music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long()
                # 返回处理后的音乐 token 条件列表
                music_tokens_conds = [music_tokens_cond]
            else:
                music_tokens_conds = None
            return music_tokens_conds
    def prior_preprocess(self, tokens, conds):
        """
        Shifts the input tokens to account for the dictionary merge. The embed_dim_shift gives by how much the music
        tokens should be shifted. It is equal to `lyric_vocab_size`.
        """
        # 获取批次大小
        batch_size = tokens[0].shape[0]
        # 对每个输入的 token 进行偏移处理
        for i in range(len(tokens)):
            tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1)

        # 对每个条件进行处理,如果条件为 None,则用零填充
        for i in range(len(conds)):
            if conds[i] is None:
                conds[i] = torch.zeros(
                    (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device
                )

        # 将处理后的 tokens 和 conds 拼接起来返回
        return torch.cat(tokens, dim=1), torch.cat(conds, dim=1)

    def prior_postprocess(self, tokens):
        """
        Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is
        shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music
        tokens.
        """
        # 获取批次大小
        batch_size = tokens.shape[0]
        # 划分 tokens 为列表,按照指定维度进行切分
        dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0])
        tokens = list(torch.split(tokens, dims, dim=1))

        # 对每个切分后的 token 进行逆向处理,将其偏移值减去
        for i in range(len(tokens)):
            bins_shift = int(self.embed_dim_shift[i])
            tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1)
            tokens[i] = torch.clamp(tokens[i], min=0)
            # 如果不屏蔽损失,模型可能生成的歌词/音符 token 可能会被 bin_shift 偏移小于0
        # 返回处理后的最后一个 token
        return tokens[-1]
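
The shift performed by `prior_preprocess`/`prior_postprocess` is just an id offset that lets lyric ids and music-token ids share one embedding table. A minimal sketch, assuming a lyric vocabulary of 80 tokens (in the model the shift comes from `embed_dim_shift`):

```
import torch

lyric_vocab_size = 80                      # assumed shift for the music tokens
lyric_tokens = torch.tensor([[5, 12, 40]])
music_tokens = torch.tensor([[0, 3, 2047]])

# preprocess: shift music ids past the lyric vocabulary and concatenate
merged = torch.cat([lyric_tokens, music_tokens + lyric_vocab_size], dim=1)
# postprocess: split off the music part and undo the shift
recovered = merged[:, lyric_tokens.shape[1]:] - lyric_vocab_size
assert torch.equal(recovered, music_tokens)
```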

    def embed_tokens(self, music_tokens_conds):
        """
        Embeds the upper level music tokens and upsamples them to provide as audio conditioning.
        """
        # 仅处理 music_tokens_conds 中指定条件级别以下的内容
        music_tokens_conds = music_tokens_conds[: self.cond_level + 1]
        audio_conditioning = None
        # 对 music_tokens_conds 和条件块进行逆向处理
        for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))):
            audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning)
        # 返回音频条件化结果
        return audio_conditioning

    def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1):
        """
        Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states.
        """
        # 如果未指定起始级别,则使用默认级别
        if start_level is None:
            start_level = self.level
        # 如果未指定结束级别,则使用默认级别
        if end_level is None:
            end_level = self.levels
        # 使用 VQVAE 编码器获取潜在状态
        with torch.no_grad():
            latent_states = self.vqvae_encoder(
                hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
            )
        # 返回潜在状态
        return latent_states
    # 使用给定的音乐令牌解码成原始音频序列
    def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1):
        """
        Upsamples the sequence of codebook vectors to raw audio.
        """
        # 如果未指定起始级别,默认使用对象的级别
        if start_level is None:
            start_level = self.level
        # 如果未指定结束级别,默认使用对象的级别数
        if end_level is None:
            end_level = self.levels
        # 使用禁用梯度环境运行以下代码
        with torch.no_grad():
            # 调用 VQ-VAE 解码器来生成输出
            output = self.vqvae_decoder(
                music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
            )
        return output

    # 获取条件信息,将音乐令牌转换为输入嵌入。将歌词与其余元数据分开,歌词令牌可以为 None。
    def get_cond(self, music_tokens_conds, metadata):
        """
        Converts the input tokens to input_embeddings. Splits the lyrics from the rest of the metadata. Lyric tokens
        can be None.
        """
        # 如果存在元数据,则从中提取标签和歌词令牌
        if metadata is not None:
            n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens
            metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:]
        else:
            # 否则设置元数据和歌词令牌为 None
            metadata, lyric_tokens = None, None
        # 根据是否有元数据条件,生成相应的元数据嵌入和位置编码
        metadata_conditioning, metadata_pos = (
            self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None)
        )
        # 根据音频条件设置,生成音频条件输入嵌入或者使用元数据位置编码
        audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos
        return audio_conditioning, metadata_conditioning, lyric_tokens

    # 对模型进行采样生成
    def sample(
        self,
        n_samples,
        music_tokens=None,
        music_tokens_conds=None,
        metadata=None,
        temp=1.0,
        top_k=0,
        top_p=0.0,
        chunk_size=None,
        sample_tokens=None,
    ):
        # The sampling body is not included in this excerpt.

    # 获取编码器状态,提取将由解码器关注的歌词编码器的最后隐藏状态。通过歌词编码器向前传播。
    def get_encoder_states(self, lyric_tokens, sample=False):
        """
        Retrieve the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through
        the lyric encoder.
        """
        # 如果存在相关歌词令牌且歌词条件为真
        if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
            # 如果需要进行采样,则将编码器转移到设备上
            if sample:
                self.encoder = self.encoder.to(lyric_tokens.device)
            # 通过编码器生成歌词编码活动
            lyric_acts = self.encoder(lyric_tokens, None, None, None)
            # 将歌词编码活动投影到输入空间
            lyric_acts = self.encoder.proj_in(lyric_acts)
            # 对编码后的结果进行最终的层归一化处理
            last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts)
        else:
            # 否则将最终隐藏状态设置为 None
            last_encoder_hidden_states = None
        return last_encoder_hidden_states

    # 获取编码器损失,计算歌词编码器的损失:下一个歌词令牌的预测。
    def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics):
        """
        Computes the loss for the lyric encoder: next lyric token prediction.
        """
        # 如果启用了歌词条件
        if self.lyric_conditioning:
            # 对最终隐藏状态进行语言模型头部处理
            last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states)
            # 计算交叉熵损失
            encoder_loss = nn.functional.cross_entropy(
                last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1)
            ) / np.log(2.0)
        else:
            # 否则将编码器损失设置为 0
            encoder_loss = torch.tensor(0.0, device=last_encoder_hidden_states.device)
        return encoder_loss
    def forward_tokens(
        self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False
    ):
        """
        Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the
        vqvae's encoding layers.
        """
        # 如果需要记录注意力权重,则设置记录
        if get_attn_weights:
            self.prior.transformer.set_record_attn(get_attn_weights)
        
        # 获取音频、元数据条件和歌词 token
        audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)

        # 如果模型是编码-解码器结构
        if self.is_encoder_decoder:
            # 预处理歌词和音乐 token,返回的 tokens 包含歌词和音乐 token,audio_conditioning 也被修改
            tokens, audio_conditioning = self.prior_preprocess(
                [lyric_tokens, music_tokens], [None, audio_conditioning]
            )
            # 使用 prior 模型进行前向传播,包括获取预测值和分离损失
            (encoder_loss, next_token_prediction_loss), preds = self.prior(
                tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds
            )
        else:
            # 获取最后一个编码器隐藏状态
            last_encoder_hidden_states = self.get_encoder_states(lyric_tokens)
            # 计算编码器损失
            encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens)
            # 使用 prior 模型进行前向传播,获取下一个 token 预测损失和预测值
            next_token_prediction_loss, preds = self.prior(
                music_tokens,
                audio_conditioning,
                metadata_conditioning,
                last_encoder_hidden_states,
                get_preds=get_preds,
            )
        
        # 计算总损失,包括编码器损失和下一个 token 预测损失
        loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims
        loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims

        # 定义需要返回的指标
        metrics = {
            "bpd": next_token_prediction_loss.clone().detach(),
            "encoder_loss": encoder_loss.clone().detach(),
            "next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
        }
        
        # 如果需要返回预测值,则加入指标中
        if get_preds:
            metrics["preds"] = preds.clone().detach()
        
        # 如果需要记录注意力权重,将保存的注意力权重返回并关闭记录
        if get_attn_weights:
            saved_attn_weights = self.prior.transformer.saved_attn_weights
            self.prior.transformer.set_record_attn(False)
            return saved_attn_weights
        else:
            # 否则返回计算得到的损失和指标
            return loss, metrics

    def forward(
        self,
        hidden_states: torch.Tensor,
        metadata: Optional[List[torch.LongTensor]] = None,
        decode: bool = False,
        get_preds: bool = False,
    ) -> List[torch.Tensor]:
        """
        Encodes the hidden states (raw audio) using the `vqvae` encoder, and then predicts the next token in the
        `forward_tokens` function. The loss is the sum of the `encoder` loss and the `decoder` loss.

        Args:
            hidden_states (`torch.Tensor`):
                Hidden states which should be raw audio
            metadata (`List[torch.LongTensor]`, *optional*):
                List containing the metadata conditioning tensor with the lyric and the metadata tokens.
            decode (`bool`, *optional*, defaults to `False`):
                Whether or not to decode the encoded tokens.
            get_preds (`bool`, *optional*, defaults to `False`):
                Whether or not to return the actual predictions of the model.
        """
        # 获取批处理的大小
        batch_size = hidden_states.shape[0]
        # 使用 `vqvae` 编码器对隐藏状态(原始音频)进行编码,得到音乐 tokens 和可能的条件
        music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size)
        # 调用 forward_tokens 函数计算损失和指标
        loss, metrics = self.forward_tokens(
            music_tokens=music_tokens,
            music_tokens_conds=music_tokens_conds,
            metadata=metadata,
            get_preds=get_preds,
        )
        # 如果需要解码,则使用 `vqvae` 解码器进行解码
        if decode:
            dequantised_states = self.decode([music_tokens, *music_tokens_conds])
        else:
            dequantised_states = None
        # 返回解码后的状态、损失和指标
        return dequantised_states, loss, metrics
class JukeboxPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 使用 JukeboxConfig 作为配置类
    config_class = JukeboxConfig
    # 基础模型前缀为 "jukebox"
    base_model_prefix = "jukebox"
    # 不支持梯度检查点
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        # 如果模块是 JukeboxPrior 或者 JukeboxVQVAE 类型,调用其 _init_weights 方法进行初始化
        if isinstance(module, JukeboxPrior) or isinstance(module, JukeboxVQVAE):
            module.apply(module._init_weights)

    def __init__(self, *inputs, **kwargs):
        # 调用父类的构造函数
        super().__init__(*inputs, **kwargs)


# JUKEBOX_SAMPLING_INPUT_DOCSTRING 是用于描述采样输入的文档字符串常量
JUKEBOX_SAMPLING_INPUT_DOCSTRING = r"""
            labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` :
                List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to
                condition the generation.
            sampling_kwargs (`Dict[Any]`):
                Various additional sampling arguments that are used by the `_sample` function. A detailed list of the
                arguments can be seen in the [`_sample`] function documentation.
"""


@add_start_docstrings(
    """The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`,
    `continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If
    you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior
    individually.
    """,
    JUKEBOX_START_DOCSTRING,
)
class JukeboxModel(JukeboxPreTrainedModel):
    _no_split_modules = ["JukeboxBlock"]

    def __init__(self, config):
        # 调用父类构造函数,并传入配置对象
        super().__init__(config)
        # 使用给定的 vqvae_config 初始化 JukeboxVQVAE 对象
        vqvae_config = config.vqvae_config
        self.vqvae = JukeboxVQVAE(vqvae_config)
        # 设置共享参数
        self.set_shared_params(config)
        # 初始化 priors 列表,每个元素为 JukeboxPrior 类的实例
        self.priors = nn.ModuleList(
            [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)]
        )

    def set_shared_params(self, model_config):
        """
        Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig`
        is nested, and is thus unreachable in the `from_dict` function
        """
        # 遍历 model_config.prior_configs 列表,并为每个配置对象设置共享参数
        for config in model_config.prior_configs:
            config.sampling_rate = model_config.sampling_rate
            config.timing_dims = model_config.timing_dims
            config.min_duration = model_config.min_duration
            config.max_duration = model_config.max_duration
            config.max_nb_genres = model_config.max_nb_genres
            config.metadata_conditioning = model_config.metadata_conditioning

    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1):
        # 调用 vqvae 对象的 decode 方法进行音乐解码
        return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks)
    # 调用 VQ-VAE 模型的 encode 方法对输入音频进行编码
    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
        return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks)

    # 将对象 obj 拆分成大小为 split_size 的 batch,总共 n_samples 个样本
    def split_batch(self, obj, n_samples, split_size):
        # 计算需要多少个 passes 才能处理完所有样本
        n_passes = (n_samples + split_size - 1) // split_size
        if isinstance(obj, torch.Tensor):  # 如果 obj 是 torch.Tensor 类型
            return torch.split(obj, split_size, dim=0)  # 在 dim=0 上拆分 Tensor
        elif isinstance(obj, list):  # 如果 obj 是 list 类型
            # 对 list 中的每个元素分别在 dim=0 上进行拆分,并将结果打包成 list 返回
            return list(zip(*[torch.split(item, split_size, dim=0) for item in obj]))
        elif obj is None:  # 如果 obj 为 None
            # 返回包含 n_passes 个 None 的列表
            return [None] * n_passes
        else:
            # 抛出类型错误异常
            raise TypeError("Unknown input type")
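
`split_batch` is plain batched slicing: with `n_samples=5` and `split_size=2` a tensor is cut into pieces of 2, 2 and 1 along dim 0, and a `None` input becomes a list of `None` of the same length.

```
import torch

n_samples, split_size = 5, 2
obj = torch.zeros(n_samples, 3)
print([chunk.shape[0] for chunk in torch.split(obj, split_size, dim=0)])  # [2, 2, 1]
n_passes = (n_samples + split_size - 1) // split_size
print([None] * n_passes)                                                  # [None, None, None]
```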

    # 在 level 层级上,从 music_tokens 中采样一个长度小于 n_ctx 的部分窗口,新增 tokens_to_sample 个新标记
    def sample_partial_window(
        self, music_tokens, labels, offset, sampling_kwargs, level, tokens_to_sample, max_batch_size
    ):
        prior = self.priors[level]  # 获取指定层级的 prior 模型
        sampled_tokens = music_tokens[level]  # 获取在指定层级的音乐 tokens
        n_ctx = prior.n_ctx  # 获取 prior 模型的上下文长度
        nb_sampled_tokens = sampled_tokens.shape[1]  # 获取已采样 tokens 的数量

        if nb_sampled_tokens < n_ctx - tokens_to_sample:
            # 如果已采样 tokens 的数量小于 n_ctx - tokens_to_sample
            sampling_kwargs["sample_tokens"] = nb_sampled_tokens + tokens_to_sample
            start = 0
        else:
            # 否则设置采样的 tokens 数量为 n_ctx
            sampling_kwargs["sample_tokens"] = n_ctx
            start = nb_sampled_tokens - n_ctx + tokens_to_sample

        # 调用 sample_single_window 方法进行单个窗口的采样
        return self.sample_single_window(music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size)

    # 在 level 层级上,从 start 位置开始采样一个长度为 n_ctx 的单个窗口
    # 从先验分布中采样一个单窗口的音乐序列片段
    def sample_single_window(self, music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size):
        # 获取当前层级的先验分布
        prior = self.priors[level]
        # 获取音乐片段的总数
        n_samples = music_tokens[0].shape[0]
        # 获取先验分布中的上下文长度
        n_ctx = prior.n_ctx
        # 计算当前片段的结束位置
        end = start + n_ctx
        # 获取已经在当前层级采样的音乐片段
        previous_sampled_tokens = music_tokens[level][:, start:end]

        # 从采样参数中获取要采样的令牌数
        sample_tokens = sampling_kwargs.get("sample_tokens", None)
        if "sample_tokens" in sampling_kwargs:
            sample_tokens = end - start

        # 计算当前条件下的令牌数量
        conditioning_tokens = previous_sampled_tokens.shape[1]
        # 计算新采样的令牌数量
        new_tokens = sample_tokens - previous_sampled_tokens.shape[1]

        # 记录采样信息日志
        logger.info(
            f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on"
            f" {conditioning_tokens} tokens"
        )

        # 如果没有新的令牌需要采样,则直接返回原始音乐令牌
        if new_tokens <= 0:
            return music_tokens

        # 获取上一层级的音乐令牌条件
        music_tokens_conds = prior.get_music_tokens_conds(music_tokens, start, end)
        # 如果没有上一层级,应该返回None!

        # 设置元数据的偏移量、采样长度和歌词令牌
        metadata = prior.get_metadata(labels, start, self.total_length, offset)

        # 将音乐令牌、音乐令牌条件和元数据拆分成批次
        music_tokens_list = self.split_batch(previous_sampled_tokens, n_samples, max_batch_size)
        music_tokens_conds_list = self.split_batch(music_tokens_conds, n_samples, max_batch_size)
        metadata_list = self.split_batch(metadata, n_samples, max_batch_size)
        tokens = []
        # 迭代处理每个批次的音乐令牌和条件
        iterator = tqdm(zip(music_tokens_list, music_tokens_conds_list, metadata_list), leave=False)
        for music_tokens_i, music_tokens_conds_i, metadata_i in iterator:
            # 确定当前使用的名称("祖先"或"主导"),基于是否有音乐令牌条件
            name = ["Ancestral", "Primed"][music_tokens_i.shape[1] == 0]
            iterator.set_description(
                f"[prior level {level}] {name} Sampling {sample_tokens} tokens out of"
                f" {self.total_length//prior.raw_to_tokens}",
                refresh=True,
            )
            # 从先验分布中采样音乐令牌
            tokens_i = prior.sample(
                n_samples=music_tokens_i.shape[0],
                music_tokens=music_tokens_i,
                music_tokens_conds=music_tokens_conds_i,
                metadata=metadata_i,
                **sampling_kwargs,
            )
            tokens.append(tokens_i)
        # 将所有采样的音乐令牌连接起来
        sampled_tokens = torch.cat(tokens, dim=0)

        # 更新音乐令牌序列,加入新的采样片段
        music_tokens_new = sampled_tokens[:, -new_tokens:]
        music_tokens[level] = torch.cat([music_tokens[level], music_tokens_new], dim=1)
        return music_tokens

    # 以指定级别、总长度和跳跃长度进行采样
    def sample_level(
        self, music_tokens, labels, offset, sampling_kwargs, level, total_length, hop_length, max_batch_size
    ):
        # 如果总长度超过当前先验模型的上下文长度
        if total_length >= self.priors[level].n_ctx:
            # 获取起始位置迭代器,根据指定的步长和先验模型的上下文长度
            iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length)
            # 对于迭代器中的每个起始位置
            for start in iterator:
                # 对单个窗口进行采样
                music_tokens = self.sample_single_window(
                    music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size
                )

        else:
            # 对部分窗口进行采样,因为总长度小于当前先验模型的上下文长度
            music_tokens = self.sample_partial_window(
                music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size
            )
        # 返回采样后的音乐 tokens
        return music_tokens
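
`get_starts` is defined elsewhere in modeling_jukebox.py; the sketch below is only an assumed equivalent that walks the sequence with `hop_length` and clamps the last window so it still ends at `total_length`:

```
# Assumed sketch of the start-offset helper (the real `get_starts` lives elsewhere
# in modeling_jukebox.py): fixed-hop windows, with the last one clamped to the end.
def window_starts(total_length, n_ctx, hop_length):
    starts = []
    for start in range(0, total_length - n_ctx + hop_length, hop_length):
        starts.append(min(start, total_length - n_ctx))
    return starts

print(window_starts(total_length=8192, n_ctx=4096, hop_length=2048))  # [0, 2048, 4096]
```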

    @torch.no_grad()
    def _sample(
        self,
        music_tokens,
        labels,
        sample_levels,
        metas=None,
        chunk_size=32,
        sampling_temperature=0.98,
        lower_batch_size=16,
        max_batch_size=16,
        sample_length_in_seconds=24,
        compute_alignments=False,
        sample_tokens=None,
        offset=0,
        save_results=True,
        sample_length=None,
    ):
        # The body of `_sample` (windowed sampling across the prior levels and optional
        # decoding/saving of the results) is not included in this excerpt.

    # Docstring decorator describing the music-token generation performed by `ancestral_sample`
    @add_start_docstrings(
        """
        Generates music tokens based on the provided `labels`. Will start at the desired prior level and automatically
        upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use
        the VQ-VAE decoder to convert the music tokens to raw audio.

        Args:
            labels (`List[torch.LongTensor]`) :
                List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
                lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
                which are used to condition the generation.
            n_samples (`int`, *optional*, default to 1) :
                Number of samples to be generated in parallel.
        """,
    )
    def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]:
        """
        Example:

        ```
        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed

        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")

        >>> lyrics = "Hey, are you awake? Can you talk to me?"
        >>> artist = "Zac Brown Band"
        >>> genre = "Country"
        >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics)
        >>> set_seed(0)
        >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)

        >>> with torch.no_grad():
        ...     model.decode(music_tokens)[:, :10].squeeze(-1)
        tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405,
            -0.0818, -0.0697]])
        ```
        """

        # 从参数中获取采样层级列表,如果没有则使用默认值(self.priors 的长度)
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        # 初始化一个空的音乐 tokens 列表,用于存储采样后的结果
        music_tokens = [
            torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors))
        ]
        # 使用 _sample 方法进行采样生成音乐 tokens
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        # 返回生成的音乐 tokens 列表
        return music_tokens

    @add_start_docstrings(
        """Generates a continuation of the previously generated tokens.

        Args:
            music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
                A sequence of music tokens which will be used as context to continue the sampling process. Should have
                `self.levels` tensors, each corresponding to the generation at a certain level.
        """,
        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
    )
    def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
        # 从参数中获取采样层级列表,如果没有则使用默认值(self.priors 的长度)
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        # 使用 _sample 方法继续生成音乐 tokens
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        # 返回生成的音乐 tokens 列表
        return music_tokens

    @add_start_docstrings(
        """Upsamples a sequence of music tokens using the prior at level `level`.

        Args:
            music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
                A sequence of music tokens which will be used as context to continue the sampling process. Should have
                `self.levels` tensors, each corresponding to the generation at a certain level.
        """,
        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
    )
    def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
        # 从参数中获取采样层级列表,如果没有则使用默认值(self.priors 的长度减一)
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1)))
        # 使用 _sample 方法上采样生成音乐 tokens
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        # 返回生成的音乐 tokens 列表
        return music_tokens
    @add_start_docstrings(
        """Generate a raw audio conditioned on the provided `raw_audio` which is used as conditioning at each of the
        generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are
        used: as conditioning for each level, which means that no ancestral sampling is required.

        Args:
            raw_audio (`List[torch.Tensor]` of length `n_samples` ) :
                A list of raw audio that will be used as conditioning information for each samples that will be
                generated.
        """,
        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
    )

    def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]:
        # Sample levels default to every prior level when not provided
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        # Move the VQ-VAE to the device of the raw audio and cast it to float
        self.vqvae.to(raw_audio.device).float()
        with torch.no_grad():
            # Encode the raw audio into music tokens at every level; these tokens are
            # used as conditioning, so no ancestral sampling is required
            music_tokens = self.vqvae.encode(
                raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0]
            )
        # Continue sampling, conditioned on the encoded music tokens
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens


.\models\jukebox\tokenization_jukebox.py

# 引入所需的库和模块
import json  # 导入处理 JSON 格式的模块
import os    # 导入操作系统相关功能的模块
import re    # 导入正则表达式模块
import unicodedata  # 导入 Unicode 数据库模块
from json.encoder import INFINITY  # 从 JSON 库中导入 INFINITY 常量
from typing import Any, Dict, List, Optional, Tuple, Union  # 导入类型提示相关的功能

import numpy as np  # 导入 NumPy 库,用于数值计算
import regex       # 导入 regex 库,支持更强大的正则表达式功能

# 从 tokenization_utils 模块导入 AddedToken 和 PreTrainedTokenizer 类
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
# 从 tokenization_utils_base 模块导入 BatchEncoding 类
from ...tokenization_utils_base import BatchEncoding
# 从 utils 模块导入 TensorType, is_flax_available, is_tf_available, is_torch_available, logging 等功能
from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
# 从 utils.generic 模块导入 _is_jax 和 _is_numpy 函数
from ...utils.generic import _is_jax, _is_numpy

# 获取 logger 对象,用于记录日志
logger = logging.get_logger(__name__)

# 定义各种文件名与其对应的词汇表文件名
VOCAB_FILES_NAMES = {
    "artists_file": "artists.json",   # 艺术家信息的 JSON 文件名
    "lyrics_file": "lyrics.json",     # 歌词信息的 JSON 文件名
    "genres_file": "genres.json",     # 音乐流派信息的 JSON 文件名
}

# 预训练词汇文件映射表
PRETRAINED_VOCAB_FILES_MAP = {
    "artists_file": {
        "jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/artists.json",  # 艺术家信息的预训练 URL
    },
    "genres_file": {
        "jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/genres.json",   # 音乐流派信息的预训练 URL
    },
    "lyrics_file": {
        "jukebox": "https://huggingface.co/ArthurZ/jukebox/blob/main/lyrics.json",   # 歌词信息的预训练 URL
    },
}

# 预训练歌词 token 大小
PRETRAINED_LYRIC_TOKENS_SIZES = {
    "jukebox": 512,   # Jukebox 模型的歌词 token 大小为 512
}

# JukeboxTokenizer 类,继承自 PreTrainedTokenizer
class JukeboxTokenizer(PreTrainedTokenizer):
    """
    构造 Jukebox 分词器。Jukebox 可以根据三种不同的输入条件进行条件化:
        - 艺术家:每个艺术家关联的唯一 ID 存储在提供的字典中。
        - 音乐流派:每种流派关联的唯一 ID 存储在提供的字典中。
        - 歌词:基于字符的分词。必须初始化使用词汇表中包含的字符列表。

    该分词器不需要训练。它应该能够处理不同数量的输入:
    因为模型的条件化可以在三种不同的查询上完成。如果未提供任何值,则将使用默认值。

    根据应该条件化模型的流派数量(`n_genres`)而定。

    参数:
        - PreTrainedTokenizer:继承自父类 PreTrainedTokenizer 的构造函数。

    示例用法:
    ```
    >>> from transformers import JukeboxTokenizer

    >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
    >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
    [tensor([[   0,    0,    0, 6785,  546,   41,   38,   30,   76,   46,   41,   49,
               40,   76,   44,   41,   27,   30]]), tensor([[  0,   0,   0, 145,   0]]), tensor([[  0,   0,   0, 145,   0]])]
    ```
    ```
    # 你可以通过在实例化这个分词器时或在调用它处理文本时传递 `add_prefix_space=True` 来避免这种行为,但由于模型不是以这种方式预训练的,可能会导致性能下降。
    
    # 提示信息

    # 如果未提供任何内容,流派和艺术家将随机选择或设置为 None。

    # 这个分词器继承自 [`PreTrainedTokenizer`],其中包含大多数主要方法。用户应参考该超类以获取有关这些方法的更多信息。

    # 然而,代码不允许这样做,只支持从各种流派组成。

    # 参数说明:
    # artists_file (`str`):
    #     包含艺术家与其ID映射的词汇文件的路径。默认文件支持 "v2" 和 "v3"。
    # genres_file (`str`):
    #     包含流派与其ID映射的词汇文件的路径。
    # lyrics_file (`str`):
    #     包含歌词分词接受字符的词汇文件的路径。
    # version (`List[str]`, 可选, 默认为 `["v3", "v2", "v2"]`) :
    #     分词器版本列表。`5b-lyrics` 的顶级优先模型使用 `v3` 而不是 `v2` 进行训练。
    # n_genres (`int`, 可选, 默认为 5):
    #     用于组合的最大流派数。
    # max_n_lyric_tokens (`int`, 可选, 默认为 512):
    #     保留的最大歌词分词数量。
    # unk_token (`str`, 可选, 默认为 `"<|endoftext|>"`):
    #     未知标记。词汇表中没有的标记将无法转换为ID,并被设置为此标记。
    """

    # 定义类级别的属性

    # 词汇文件名列表
    vocab_files_names = VOCAB_FILES_NAMES
    # 预训练词汇文件映射
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # 预训练歌词分词器的最大输入尺寸
    max_lyric_input_size = PRETRAINED_LYRIC_TOKENS_SIZES
    # 模型的输入名称列表
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        artists_file,
        genres_file,
        lyrics_file,
        version=["v3", "v2", "v2"],
        max_n_lyric_tokens=512,
        n_genres=5,
        unk_token="<|endoftext|>",
        **kwargs,
    ):
        # 如果 unk_token 是字符串,则创建一个 AddedToken 对象,保留字符串两侧空白字符
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        # 设置模型版本号
        self.version = version
        # 设置歌词最大 token 数量
        self.max_n_lyric_tokens = max_n_lyric_tokens
        # 设置流派数量
        self.n_genres = n_genres
        # 初始化未知 token 的解码器
        self._added_tokens_decoder = {0: unk_token}

        # 读取并加载艺术家编码器(JSON 格式)
        with open(artists_file, encoding="utf-8") as vocab_handle:
            self.artists_encoder = json.load(vocab_handle)

        # 读取并加载流派编码器(JSON 格式)
        with open(genres_file, encoding="utf-8") as vocab_handle:
            self.genres_encoder = json.load(vocab_handle)

        # 读取并加载歌词编码器(JSON 格式)
        with open(lyrics_file, encoding="utf-8") as vocab_handle:
            self.lyrics_encoder = json.load(vocab_handle)

        # 正则表达式模式,用于识别词汇表中的未知字符
        oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
        # 在 v2 版本中,我们的 n_vocab=80,但在 v3 中我们遗漏了 +,所以现在 n_vocab=79 个字符。
        # 如果歌词编码器长度为 79(即缺少 '+' 的词表),则更新正则表达式,把 '+' 也视为词表内字符
        if len(self.lyrics_encoder) == 79:
            oov = oov.replace(r"\-'", r"\-+'")

        # 编译正则表达式模式,用于匹配词汇表中的未知字符
        self.out_of_vocab = regex.compile(oov)
        # 创建艺术家的解码器,将编码器的键值对反转
        self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
        # 创建流派的解码器,将编码器的键值对反转
        self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
        # 创建歌词的解码器,将编码器的键值对反转
        self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
        # 调用父类的初始化方法,传递参数给父类
        super().__init__(
            unk_token=unk_token,
            n_genres=n_genres,
            version=version,
            max_n_lyric_tokens=max_n_lyric_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        # 返回总的词汇量大小,包括艺术家、流派和歌词的编码器的长度之和
        return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)

    def get_vocab(self):
        # 返回包含艺术家、流派和歌词编码器的字典
        return {
            "artists_encoder": self.artists_encoder,
            "genres_encoder": self.genres_encoder,
            "lyrics_encoder": self.lyrics_encoder,
        }

    def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
        """Converts the artist, genre and lyrics tokens to their index using the vocabulary.
        The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to
        the lyrics token sequence.
        """
        # 将艺术家标签转换为它们在编码器中的索引
        artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
        # 将流派标签转换为它们在编码器中的索引,并在需要时添加填充标记
        for genres in range(len(list_genres)):
            list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
            list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))

        # 将歌词字符转换为它们在编码器中的索引;只有第一个(顶层)先验使用完整歌词,其余两个位置为空列表
        lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
        return artists_id, list_genres, lyric_ids
    # 将字符串 lyrics 转换为标记序列(字符串),使用指定的标记器。
    # 如果是基于词汇的,按单词拆分;如果是基于子词(如BPE/SentencePieces/WordPieces),则按子词拆分。
    # 对于基于字符的词汇表,仅将歌词拆分成字符。
    def _tokenize(self, lyrics):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary.
        """
        # 仅对歌词进行拆分,如果是基于字符的词汇表,这很容易处理
        return list(lyrics)

    # 使用标记器将艺术家、流派和歌词转换为标记序列的三元组
    def tokenize(self, artist, genre, lyrics, **kwargs):
        """
        Converts three strings in a 3 sequence of tokens using the tokenizer
        """
        # 准备艺术家、流派和歌词以进行标记化
        artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
        # 将歌词转换为标记序列
        lyrics = self._tokenize(lyrics)
        return artist, genre, lyrics

    # 准备艺术家、流派和歌词以进行标记化
    def prepare_for_tokenization(
        self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False
    ) -> Tuple[str, str, str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        Args:
            artist (`str`):
                The artist name to prepare. This will mostly lower the string
            genres (`str`):
                The genre name to prepare. This will mostly lower the string.
            lyrics (`str`):
                The lyrics to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize. This is useful for NER or token classification.
        """
        # 循环遍历版本列表,进行必要的转换操作
        for idx in range(len(self.version)):
            # 如果版本为 "v3",将艺术家和流派名称转换为小写
            if self.version[idx] == "v3":
                artists[idx] = artists[idx].lower()
                genres[idx] = [genres[idx].lower()]
            else:
                # 如果版本不为 "v3",对艺术家名称进行标准化处理并添加后缀 ".v2",对流派名称进行拆分并添加后缀 ".v2"
                artists[idx] = self._normalize(artists[idx]) + ".v2"
                genres[idx] = [
                    self._normalize(genre) + ".v2" for genre in genres[idx].split("_")
                ]  # split is for the full dictionary with combined genres

        # 如果版本为 "v2",设置处理非词汇表外字符的正则表达式和词汇表
        if self.version[0] == "v2":
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
            vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
            # 创建词汇表和词汇表的索引
            self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))}
            self.vocab["<unk>"] = 0
            self.n_vocab = len(vocab) + 1
            self.lyrics_encoder = self.vocab
            self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
            self.lyrics_decoder[0] = ""
        else:
            # 如果版本不为 "v2",设置处理非词汇表外字符的正则表达式
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")

        # 运行去除文本中重音符号的函数
        lyrics = self._run_strip_accents(lyrics)
        # 替换文本中的 "\\" 为换行符 "\n"
        lyrics = lyrics.replace("\\", "\n")
        # 使用正则表达式去除歌词中词汇表外的字符,并与两个空列表一起组成三元组
        lyrics = self.out_of_vocab.sub("", lyrics), [], []
        # 返回处理后的艺术家名称、流派名称和歌词
        return artists, genres, lyrics

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # 使用 unicodedata 库规范化文本中的 Unicode 字符,去除重音符号
        text = unicodedata.normalize("NFD", text)
        output = []
        # 遍历文本中的每个字符
        for char in text:
            # 获取字符的 Unicode 分类
            cat = unicodedata.category(char)
            # 如果字符的分类为 "Mn"(非重音符号),跳过该字符
            if cat == "Mn":
                continue
            # 将符合条件的字符添加到输出列表中
            output.append(char)
        # 将输出列表中的字符连接成字符串并返回
        return "".join(output)
    # 定义一个方法,用于规范化输入的文本。这个过程适用于音乐流派和艺术家名称。

    def _normalize(self, text: str) -> str:
        """
        Normalizes the input text. This process is for the genres and the artist

        Args:
            text (`str`):
                Artist or Genre string to normalize
        """
        # 定义可接受的字符集,包括小写字母、大写字母、数字和点号
        accepted = (
            [chr(i) for i in range(ord("a"), ord("z") + 1)]
            + [chr(i) for i in range(ord("A"), ord("Z") + 1)]
            + [chr(i) for i in range(ord("0"), ord("9") + 1)]
            + ["."]
        )
        accepted = frozenset(accepted)  # 将字符集转换为不可变集合以提高性能
        pattern = re.compile(r"_+")  # 编译用于匹配多个下划线的正则表达式模式
        # 将文本转换为小写,并替换不在接受字符集中的字符为下划线
        text = "".join([c if c in accepted else "_" for c in text.lower()])
        text = pattern.sub("_", text).strip("_")  # 将多个连续的下划线替换为单个下划线,并去除首尾的下划线
        return text  # 返回规范化后的文本字符串

    # 定义一个方法,将歌词令牌列表转换为一个字符串
    def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
        return " ".join(lyrics)

    # 定义一个方法,用于将输入转换为张量(Tensor),可以选择添加批次轴
    def convert_to_tensors(
        self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
    ):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                unset, no modification is done.
            prepend_batch_axis (`bool`, *optional*, defaults to `False`):
                Whether or not to add the batch dimension during the conversion.
        """
        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
            # 如果 `tensor_type` 不是 `TensorType` 类型的实例,则转换为 `TensorType`
            tensor_type = TensorType(tensor_type)

        # Get a function reference for the correct framework
        if tensor_type == TensorType.TENSORFLOW:
            # 如果 `tensor_type` 是 `TensorType.TENSORFLOW`
            if not is_tf_available():
                # 检查 TensorFlow 是否可用,若不可用则抛出 ImportError 异常
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            import tensorflow as tf

            # 使用 TensorFlow 的 constant 函数
            as_tensor = tf.constant
            # 使用 TensorFlow 的 is_tensor 函数
            is_tensor = tf.is_tensor
        elif tensor_type == TensorType.PYTORCH:
            # 如果 `tensor_type` 是 `TensorType.PYTORCH`
            if not is_torch_available():
                # 检查 PyTorch 是否可用,若不可用则抛出 ImportError 异常
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch

            # 使用 PyTorch 的 tensor 函数
            as_tensor = torch.tensor
            # 使用 PyTorch 的 is_tensor 函数
            is_tensor = torch.is_tensor
        elif tensor_type == TensorType.JAX:
            # 如果 `tensor_type` 是 `TensorType.JAX`
            if not is_flax_available():
                # 检查 JAX 是否可用,若不可用则抛出 ImportError 异常
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            import jax.numpy as jnp  # noqa: F811

            # 使用 JAX 的 array 函数
            as_tensor = jnp.array
            # 使用自定义的 `_is_jax` 函数
            is_tensor = _is_jax
        else:
            # 默认情况下使用 NumPy 的 asarray 函数
            as_tensor = np.asarray
            # 使用自定义的 `_is_numpy` 函数
            is_tensor = _is_numpy

        # Do the tensor conversion in batch
        # 在批处理中进行张量转换

        try:
            if prepend_batch_axis:
                # 如果 `prepend_batch_axis` 为真,则在 `inputs` 前面添加一个批次维度
                inputs = [inputs]

            # 如果 `inputs` 不是张量,则使用 `as_tensor` 将其转换为张量
            if not is_tensor(inputs):
                inputs = as_tensor(inputs)
        except:  # noqa E722
            # 捕获所有异常,通常用于处理可能的数值或类型转换问题
            raise ValueError(
                "Unable to create tensor, you should probably activate truncation and/or padding "
                "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
            )

        return inputs
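
# 附注(独立演示代码,不属于原文件):下面的最小示例按上面 convert_to_tensors 的 PyTorch 分支,
# 演示 prepend_batch_axis=True 的效果——先在外层包一个批次维度,再用 torch.tensor 转换;
# 变量名 demo_inputs 为演示用的假设名称。
import torch

demo_inputs = [0, 0, 0, 6785, 546]
demo_inputs = [demo_inputs]              # 相当于 prepend_batch_axis=True
demo_inputs = torch.tensor(demo_inputs)  # PyTorch 分支中 as_tensor = torch.tensor
print(demo_inputs.shape)                 # torch.Size([1, 5])
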
    def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
        """Convert the raw string to a list of token ids

        Args:
            artist (`str`):
                Name of the artist.
            genres (`str`):
                List of genres that will be mixed to condition the audio
            lyrics (`str`, *optional*, defaults to `""`):
                Lyrics used to condition the generation
        """
        # 前三个 token 为元数据占位符(total_length、offset、sample_length,此处均为 0)
        input_ids = [0, 0, 0]
        # 将 artist 复制多份,以匹配 self.version 的长度
        artist = [artist] * len(self.version)
        # 将 genres 复制多份,以匹配 self.version 的长度
        genres = [genres] * len(self.version)

        # 使用 tokenize 方法将 artist、genres 和 lyrics 转换为 tokens
        artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
        # 将 tokens 转换为对应的 ids
        artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)

        # 初始化 attention_masks 为负无穷大
        attention_masks = [-INFINITY] * len(full_tokens[-1])
        # 根据每个版本的要求,将各个 ids 组合成 input_ids,并转换为 tensors
        input_ids = [
            self.convert_to_tensors(
                [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
            )
            for i in range(len(self.version))
        ]
        # 返回 BatchEncoding 对象,包含 input_ids 和 attention_masks
        return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Saves the tokenizer's vocabulary dictionary to the provided save_directory.

        Args:
            save_directory (`str`):
                A path to the directory in which to save the vocabulary. It will be created if it doesn't exist.

            filename_prefix (`Optional[str]`, *optional*):
                A prefix to add to the names of the files saved by the tokenizer.

        """
        # 检查 save_directory 是否存在,若不存在则记录错误信息并返回
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # 将 artists_encoder 转换为 JSON 格式并保存到指定路径
        artists_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"]
        )
        with open(artists_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.artists_encoder, ensure_ascii=False))

        # 将 genres_encoder 转换为 JSON 格式并保存到指定路径
        genres_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"]
        )
        with open(genres_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.genres_encoder, ensure_ascii=False))

        # 将 lyrics_encoder 转换为 JSON 格式并保存到指定路径
        lyrics_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"]
        )
        with open(lyrics_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False))

        # 返回保存的文件路径元组
        return (artists_file, genres_file, lyrics_file)
    def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
        """
        Converts an index (integer) in a token (str) using the vocab.

        Args:
            artists_index (`int`):
                Index of the artist in its corresponding dictionary.
            genres_index (`Union[List[int], int]`):
               Index of the genre in its corresponding dictionary. Can be a single index or a list of indices.
            lyric_index (`List[int]`):
                List of character indices, each corresponding to a character.

        Returns:
            artist (`Optional[str]`):
                Decoded artist name corresponding to artists_index.
            genres (`List[Optional[str]]`):
                List of decoded genre names corresponding to genres_index.
            lyrics (`List[Optional[str]]`):
                List of decoded characters corresponding to lyric_index.
        """
        # Retrieve artist name from artists_decoder using artists_index
        artist = self.artists_decoder.get(artists_index)
        
        # Retrieve genre names from genres_decoder for each genre index in genres_index
        genres = [self.genres_decoder.get(genre) for genre in genres_index]
        
        # Retrieve character representations from lyrics_decoder for each character index in lyric_index
        lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
        
        # Return the decoded artist name, list of decoded genres, and list of decoded characters
        return artist, genres, lyrics
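
# 附注(独立演示代码,不属于原文件):下面按上面 _normalize 的逻辑写了一个独立的小函数,
# 演示艺术家/流派字符串在 v2 分支中追加 ".v2" 后缀之前的规范化结果;_demo_normalize 为演示用的假设名称。
import re as _demo_re

def _demo_normalize(text: str) -> str:
    accepted = frozenset(
        [chr(i) for i in range(ord("a"), ord("z") + 1)]
        + [chr(i) for i in range(ord("A"), ord("Z") + 1)]
        + [chr(i) for i in range(ord("0"), ord("9") + 1)]
        + ["."]
    )
    # 小写化后,把不在接受集合中的字符替换为下划线,再合并连续下划线并去掉首尾下划线
    text = "".join([c if c in accepted else "_" for c in text.lower()])
    return _demo_re.compile(r"_+").sub("_", text).strip("_")

print(_demo_normalize("Alan Jackson"))   # alan_jackson
print(_demo_normalize("Country Rock"))   # country_rock
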

.\models\jukebox\__init__.py

# 版权声明和许可信息
# 该模块受 Apache License, Version 2.0 许可,详情请访问 http://www.apache.org/licenses/LICENSE-2.0

# 导入类型检查模块
from typing import TYPE_CHECKING

# 导入自定义异常和模块惰性加载工具函数
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# 定义模块的导入结构
_import_structure = {
    "configuration_jukebox": [
        "JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "JukeboxConfig",
        "JukeboxPriorConfig",
        "JukeboxVQVAEConfig",
    ],
    "tokenization_jukebox": ["JukeboxTokenizer"],
}

# 检查是否 Torch 可用,若不可用则抛出自定义的依赖不可用异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 Torch 可用,则添加额外的模块导入结构
    _import_structure["modeling_jukebox"] = [
        "JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST",
        "JukeboxModel",
        "JukeboxPreTrainedModel",
        "JukeboxVQVAE",
        "JukeboxPrior",
    ]

# 如果是类型检查阶段
if TYPE_CHECKING:
    # 从相应模块导入特定的类或变量
    from .configuration_jukebox import (
        JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP,
        JukeboxConfig,
        JukeboxPriorConfig,
        JukeboxVQVAEConfig,
    )
    from .tokenization_jukebox import JukeboxTokenizer

    # 再次检查 Torch 是否可用,若不可用则忽略异常
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # 若 Torch 可用,则从 modeling_jukebox 模块导入特定类或变量
        from .modeling_jukebox import (
            JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST,
            JukeboxModel,
            JukeboxPreTrainedModel,
            JukeboxPrior,
            JukeboxVQVAE,
        )

# 如果不是类型检查阶段,则执行以下操作
else:
    # 导入 sys 模块
    import sys

    # 将当前模块定义为一个惰性加载模块
    # 使用 _LazyModule 类,传入当前模块的名称、文件路径、导入结构以及模块规范
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
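
# 附注(独立演示代码,不属于原文件):借助上面的 _LazyModule 封装,用户侧的导入只会在
# 首次访问对应属性时才真正加载子模块;示例假设本地安装的 transformers 版本仍包含 Jukebox。
from transformers import JukeboxConfig   # 访问该属性时才触发 configuration_jukebox 的真实导入

demo_config = JukeboxConfig()            # 使用默认参数构造配置对象
print(demo_config.model_type)            # "jukebox"
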

.\models\kosmos2\configuration_kosmos2.py

# coding=utf-8
# 设置文件编码为UTF-8,确保支持多语言字符集

# 版权声明及许可证信息
# 版权所有 2023 Microsoft Research 和 HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证的规定,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发软件
# 没有任何明示或暗示的保证或条件。
# 有关详细信息,请参阅许可证。

""" KOSMOS-2 模型配置"""

import os
from typing import Union

# 导入预训练配置类
from ...configuration_utils import PretrainedConfig
# 导入日志记录工具
from ...utils import logging

# 获取日志记录器
logger = logging.get_logger(__name__)

# 预训练配置存档映射字典,映射模型名称到其配置文件的下载链接
KOSMOS2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/kosmos-2-patch14-224": (
        "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/config.json"
    ),
    # 查看所有 KOSMOS-2 模型的列表:https://huggingface.co/models?filter=kosmos-2
}


class Kosmos2TextConfig(PretrainedConfig):
    r"""
    这是一个配置类,用于存储 [`Kosmos2TextModel`] 的配置信息。根据指定的参数实例化 KOSMOS-2 文本解码器,
    定义模型架构。使用默认参数实例化配置对象将产生类似于 KOSMOS-2 文本解码器
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) 架构的配置。

    配置对象继承自 [`PretrainedConfig`],可用于控制模型输出。阅读 [`PretrainedConfig`] 的文档以获取更多信息。
    """
    # 定义 Kosmos2 模型的参数和默认值

    # 模型的类型,用于标识 Kosmos2 文本模型
    model_type = "kosmos_2_text_model"

    # 推断阶段需要忽略的键列表,这些键不会在推断时使用
    keys_to_ignore_at_inference = ["past_key_values"]

    # 属性映射字典,将模型参数的名称映射到 Kosmos2 模型期望的名称
    attribute_map = {
        "num_attention_heads": "attention_heads",  # 注意力头的数量
        "hidden_size": "embed_dim",  # 隐藏层的维度
        "num_hidden_layers": "layers",  # Transformer 编码器中的隐藏层数量
    }
    # 初始化函数,用于创建一个新的配置对象
    def __init__(
        self,
        vocab_size=65037,                    # 词汇表大小,默认为65037
        max_position_embeddings=2048,        # 最大位置嵌入数量,默认为2048
        embed_dim=2048,                      # 嵌入维度,默认为2048
        layers=24,                           # 层数,默认为24
        ffn_dim=8192,                        # 前馈神经网络维度,默认为8192
        attention_heads=32,                  # 注意力头数,默认为32
        activation_function="gelu",          # 激活函数,默认为"gelu"
        dropout=0.1,                         # 普通层级dropout概率,默认为0.1
        attention_dropout=0.1,               # 注意力模块dropout概率,默认为0.1
        activation_dropout=0.0,              # 激活函数dropout概率,默认为0.0
        layerdrop=0.0,                       # 层级dropout概率,默认为0.0
        layer_norm_eps=1e-5,                 # 层归一化的epsilon,默认为1e-5
        init_std=0.02,                       # 初始化标准差,默认为0.02
        scale_embedding=True,                # 是否缩放嵌入,默认为True
        use_cache=True,                      # 是否使用缓存,默认为True
        pad_token_id=1,                      # 填充标记ID,默认为1
        bos_token_id=0,                      # 开始序列标记ID,默认为0
        eos_token_id=2,                      # 结束序列标记ID,默认为2
        **kwargs,                            # 其他关键字参数
    ):
        # 调用父类的初始化方法,设置填充、开始、结束标记ID等参数
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        # 初始化配置对象的各个属性
        self.vocab_size = vocab_size                         # 设置词汇表大小属性
        self.max_position_embeddings = max_position_embeddings  # 设置最大位置嵌入数量属性
        self.embed_dim = embed_dim                           # 设置嵌入维度属性
        self.layers = layers                                 # 设置层数属性
        self.ffn_dim = ffn_dim                               # 设置前馈神经网络维度属性
        self.attention_heads = attention_heads               # 设置注意力头数属性
        self.activation_function = activation_function       # 设置激活函数属性
        self.dropout = dropout                               # 设置普通层级dropout概率属性
        self.attention_dropout = attention_dropout           # 设置注意力模块dropout概率属性
        self.activation_dropout = activation_dropout         # 设置激活函数dropout概率属性
        self.layerdrop = layerdrop                           # 设置层级dropout概率属性
        self.layer_norm_eps = layer_norm_eps                 # 设置层归一化的epsilon属性
        self.init_std = init_std                             # 设置初始化标准差属性
        self.scale_embedding = scale_embedding               # 设置是否缩放嵌入属性
        self.use_cache = use_cache                           # 设置是否使用缓存属性

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        # 将token相关参数添加到kwargs中
        cls._set_token_in_kwargs(kwargs)

        # 获取配置字典和更新后的kwargs
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # 如果加载自Kosmos2Config,则获取文本配置字典
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["text_config"]

        # 如果配置字典中存在model_type,并且与类的model_type不同,发出警告
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # 从配置字典和kwargs创建类的实例
        return cls.from_dict(config_dict, **kwargs)
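
# 附注(独立演示代码,不属于原文件):演示上面 attribute_map 的作用——通用属性名
# (hidden_size / num_hidden_layers / num_attention_heads)在读取时会被映射到
# KOSMOS-2 文本配置自身的字段(embed_dim / layers / attention_heads);字段取值为演示用的假设值。
from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2TextConfig

demo_text_config = Kosmos2TextConfig(embed_dim=1024, layers=12, attention_heads=16)
print(demo_text_config.hidden_size)          # 1024
print(demo_text_config.num_hidden_layers)    # 12
print(demo_text_config.num_attention_heads)  # 16
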
# 定义 `Kosmos2VisionConfig` 类,用于存储 `Kosmos2VisionModel` 的配置信息。
# 继承自 `PretrainedConfig`,用于控制模型的输出。详细信息请参考 `PretrainedConfig` 的文档。

class Kosmos2VisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Kosmos2VisionModel`]. It is used to instantiate a
    KOSMOS-2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
    """

    model_type = "kosmos_2_vision_model"

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # 调用父类的初始化方法,传入关键字参数

        self.hidden_size = hidden_size
        # 设置隐藏层大小

        self.intermediate_size = intermediate_size
        # 设置中间层大小

        self.num_hidden_layers = num_hidden_layers
        # 设置隐藏层数量

        self.num_attention_heads = num_attention_heads
        # 设置注意力头数量

        self.num_channels = num_channels
        # 设置通道数量

        self.patch_size = patch_size
        # 设置图像块大小

        self.image_size = image_size
        # 设置图像大小

        self.initializer_range = initializer_range
        # 设置初始化范围

        self.initializer_factor = initializer_factor
        # 设置初始化因子

        self.attention_dropout = attention_dropout
        # 设置注意力丢弃率

        self.layer_norm_eps = layer_norm_eps
        # 设置层归一化的 epsilon 值

        self.hidden_act = hidden_act
        # 设置隐藏层激活函数

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)
        # 调用类方法 _set_token_in_kwargs,设置关键字参数中的 token

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # 调用类方法 get_config_dict,获取预训练模型的配置字典和更新后的关键字参数

        # 如果从 Kosmos2Config 加载,则获取视觉配置字典
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["vision_config"]

        # 如果配置字典中存在 "model_type" 并且类具有 "model_type" 属性,并且它们不相同,发出警告
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        # 从配置字典创建类的实例,并返回
        return cls.from_dict(config_dict, **kwargs)
class Kosmos2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Kosmos2Model`]. It is used to instantiate a
    KOSMOS-2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Kosmos2TextConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Kosmos2VisionConfig`].
        latent_query_num (`int`, *optional*, defaults to 64):
            The number of latent query tokens that represent the image features used in the text decoder component.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```
    >>> from transformers import Kosmos2Config, Kosmos2Model

    >>> # Initializing a Kosmos-2 kosmos-2-patch14-224 style configuration
    >>> configuration = Kosmos2Config()

    >>> # Initializing a model (with random weights) from the kosmos-2-patch14-224 style configuration
    >>> model = Kosmos2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # 设定模型类型为 "kosmos-2"
    model_type = "kosmos-2"
    # 标志这个配置类是由多个部分组成
    is_composition = True

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        latent_query_num=64,
        **kwargs,
    ):
        # 调用父类构造函数,传入所有额外的关键字参数
        super().__init__(**kwargs)

        # 如果文本配置为空,使用默认空字典并记录日志
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Kosmos2TextConfig` with default values.")

        # 如果视觉配置为空,使用默认空字典并记录日志
        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `Kosmos2VisionConfig` with default values.")

        # 根据传入的文本配置初始化 `Kosmos2TextConfig` 对象
        self.text_config = Kosmos2TextConfig(**text_config)
        # 根据传入的视觉配置初始化 `Kosmos2VisionConfig` 对象
        self.vision_config = Kosmos2VisionConfig(**vision_config)

        # 设置 latent_query_num 属性,表示在文本解码器组件中用于表示图像特征的潜在查询标记数目
        self.latent_query_num = latent_query_num
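
# 附注(独立演示代码,不属于原文件):演示用字典初始化子配置——传入的 text_config / vision_config
# 字典会分别被转换成 Kosmos2TextConfig / Kosmos2VisionConfig 实例;字段取值是演示用的小模型假设值。
from transformers import Kosmos2Config

demo_config = Kosmos2Config(
    text_config={"layers": 2, "embed_dim": 64, "attention_heads": 2, "ffn_dim": 128},
    vision_config={"num_hidden_layers": 2, "hidden_size": 32, "num_attention_heads": 2, "intermediate_size": 64},
    latent_query_num=8,
)
print(type(demo_config.text_config).__name__)    # Kosmos2TextConfig
print(demo_config.latent_query_num)              # 8
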

.\models\kosmos2\convert_kosmos2_original_pytorch_checkpoint_to_pytorch.py

import argparse  # 导入命令行参数解析模块

from fairseq.checkpoint_utils import load_checkpoint_to_cpu  # 从fairseq库中导入加载checkpoint到CPU的函数

from transformers import Kosmos2Config, Kosmos2ForConditionalGeneration  # 从transformers库中导入Kosmos2Config和Kosmos2ForConditionalGeneration类


KEYS_TO_MODIFY_MAPPING = {
    "gpt_model.decoder.output_projection": "text_model.lm_head",  # 将"gpt_model.decoder.output_projection"映射为"text_model.lm_head"
    "gpt_model.decoder": "text_model.model",  # 将"gpt_model.decoder"映射为"text_model.model"
    "img_connector": "image_to_text_projection",  # 将"img_connector"映射为"image_to_text_projection"
    "img_model.visual.class_embedding": "vision_model.model.embeddings.class_embedding",  # 将"img_model.visual.class_embedding"映射为"vision_model.model.embeddings.class_embedding"
    "img_model.visual.positional_embedding": "vision_model.model.embeddings.position_embedding.weight",  # 将"img_model.visual.positional_embedding"映射为"vision_model.model.embeddings.position_embedding.weight"
    "img_model.visual.conv1": "vision_model.model.embeddings.patch_embedding",  # 将"img_model.visual.conv1"映射为"vision_model.model.embeddings.patch_embedding"
    "img_model.visual": "vision_model.model",  # 将"img_model.visual"映射为"vision_model.model"
    "ln_pre": "pre_layrnorm",  # 将"ln_pre"映射为"pre_layrnorm"
    "ln_post": "post_layernorm",  # 将"ln_post"映射为"post_layernorm"
    "transformer.resblocks": "encoder.layers",  # 将"transformer.resblocks"映射为"encoder.layers"
    "ts_attn": "self_attn",  # 将"ts_attn"映射为"self_attn"
    "ln_1": "layer_norm1",  # 将"ln_1"映射为"layer_norm1"
    "ln_2": "layer_norm2",  # 将"ln_2"映射为"layer_norm2"
    "c_fc": "fc1",  # 将"c_fc"映射为"fc1"
    "c_proj": "fc2",  # 将"c_proj"映射为"fc2"
}


KEYS_TO_IGNORE = [
    # 在原始代码中仅用于将权重发送到所需设备的缓冲区
    "gpt_model.decoder.embed_positions._float_tensor",
    # 在原始的KOSMOS-2代码中前向传播中从未使用过的权重
    "gpt_model.decoder.self_attn_sope.scale",
]


def rename_key(key):
    for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
        if key_to_modify in key:
            key = key.replace(key_to_modify, new_key)  # 根据映射表修改键名

    return key


def convert_kosmos2_checkpoint_to_pytorch(checkpoint_path, pytorch_dump_folder_path):
    state = load_checkpoint_to_cpu(checkpoint_path)  # 加载checkpoint到CPU
    state_dict = state["model"]  # 获取模型的state_dict
    state_dict_keys = list(state_dict.keys())  # 获取state_dict中的所有键列表

    config = Kosmos2Config()  # 创建Kosmos2Config实例
    # 为了匹配原始演示给出的结果,设置必要的配置项
    config.text_config.no_repeat_ngram_size = 3
    model = Kosmos2ForConditionalGeneration(config)  # 创建Kosmos2ForConditionalGeneration模型实例

    # 转换(通过重命名键名)
    converted_state_dict = {}
    for key in state_dict_keys:
        if key in KEYS_TO_IGNORE:
            continue  # 跳过需要忽略的键名
        renamed_key = rename_key(key)  # 根据映射重命名键名
        converted_state_dict[renamed_key] = state_dict[key]  # 更新转换后的state_dict

    # 检查权重加载
    model.load_state_dict(converted_state_dict, strict=True)  # 加载转换后的state_dict到模型
    # 保存结果
    model.save_pretrained(pytorch_dump_folder_path)  # 将模型保存为PyTorch格式的文件


if __name__ == "__main__":
    parser = argparse.ArgumentParser()  # 创建参数解析器
    # 必需参数
    parser.add_argument(
        "--kosmos2_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()  # 解析命令行参数
    convert_kosmos2_checkpoint_to_pytorch(args.kosmos2_checkpoint_path, args.pytorch_dump_folder_path)  # 执行转换函数
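
# 附注(独立演示代码,不属于原文件):用与上面 rename_key 相同的"按映射表逐项做子串替换"的逻辑,
# 演示两个典型权重名的重命名结果;这里只摘取了映射表中会用到的几项,顺序与原表一致。
_demo_mapping = {
    "gpt_model.decoder": "text_model.model",
    "img_model.visual": "vision_model.model",
    "transformer.resblocks": "encoder.layers",
    "ln_1": "layer_norm1",
}

def _demo_rename(key):
    for old, new in _demo_mapping.items():
        if old in key:
            key = key.replace(old, new)
    return key

print(_demo_rename("gpt_model.decoder.layers.0.self_attn.k_proj.weight"))
# text_model.model.layers.0.self_attn.k_proj.weight
print(_demo_rename("img_model.visual.transformer.resblocks.0.ln_1.weight"))
# vision_model.model.encoder.layers.0.layer_norm1.weight
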

.\models\kosmos2\modeling_kosmos2.py

# 设置文件编码为 UTF-8
# 版权声明,声明版权归 Microsoft Research 和 HuggingFace Inc. 团队所有
#
# 根据 Apache License, Version 2.0 许可,除非符合许可要求,否则不得使用此文件
# 您可以在以下网址获取许可证的副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件按"原样"分发,不附带任何明示或暗示的担保或条件
# 请参阅许可证以了解特定语言的权限和限制

""" PyTorch KOSMOS-2 model."""

# 导入必要的库和模块
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

# 导入与激活函数相关的模块
from ...activations import ACT2FN
# 导入不同类型的模型输出
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    CausalLMOutputWithCrossAttentions,
)
# 导入预训练模型的基类
from ...modeling_utils import PreTrainedModel
# 导入实用工具函数
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
# 导入 KOSMOS-2 的配置类
from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 文档用的配置对象
_CONFIG_FOR_DOC = Kosmos2Config

# 预训练模型存档列表
KOSMOS2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/kosmos-2-patch14-224",
    # 可以在 https://huggingface.co/models?filter=kosmos-2 查看所有 KOSMOS-2 模型
]

# 定义函数:将注意力掩码从 `[bsz, seq_len]` 扩展到 `[bsz, 1, tgt_seq_len, src_seq_len]`
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    # 扩展注意力掩码
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # 创建反向的掩码
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# 定义函数:创建用于双向自注意力的因果掩码
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
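
# 附注(独立演示代码,不属于原文件):按上面 _make_causal_mask 的核心逻辑,为长度为 3 的序列
# 手动构造因果掩码,直观看到上三角被填成 dtype 的最小值(即"看不到未来的位置")。
import torch

demo_dtype = torch.float32
demo_len = 3
demo_mask = torch.full((demo_len, demo_len), torch.finfo(demo_dtype).min)
demo_cond = torch.arange(demo_len)
demo_mask.masked_fill_(demo_cond < (demo_cond + 1).view(demo_len, 1), 0)
print(demo_mask)
# 输出形如:
# tensor([[ 0.0000e+00, -3.4028e+38, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00, -3.4028e+38],
#         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]])
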

# 从 transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids 复制过来的部分
# 定义一个函数,根据输入的 token IDs 创建位置 ID,用于替换非填充符号为它们的位置编号。位置编号从 padding_idx+1 开始计数,忽略填充符号。
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids (torch.Tensor): 输入的 token IDs
        padding_idx (int): 填充符号的索引
        past_key_values_length (int, optional): 过去键值长度,用于增量索引计算

    Returns:
        torch.Tensor: 替换后的位置 ID
    """
    # 在这里进行一系列的类型转换和转换,以确保同时支持 ONNX 导出和 XLA。
    mask = input_ids.ne(padding_idx).int()  # 创建一个掩码,标记非填充符号的位置
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask  # 计算增量索引
    return incremental_indices.long() + padding_idx  # 返回最终的位置 ID,加上 padding_idx 得到真实的位置编号
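
# 附注(独立演示代码,不属于原文件):按上面 create_position_ids_from_input_ids 的逻辑手动演算一个例子,
# padding_idx=1 时,填充位置保持为 1,非填充位置从 2 开始依次编号。
import torch

demo_input_ids = torch.tensor([[5, 6, 7, 1, 1]])
demo_padding_idx = 1
demo_mask = demo_input_ids.ne(demo_padding_idx).int()                              # [[1, 1, 1, 0, 0]]
demo_incremental = torch.cumsum(demo_mask, dim=1).type_as(demo_mask) * demo_mask   # [[1, 2, 3, 0, 0]]
print(demo_incremental.long() + demo_padding_idx)                                  # tensor([[2, 3, 4, 1, 1]])
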


KOSMOS2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Kosmos2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

KOSMOS2_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

KOSMOS2_TEXT_INPUTS_DOCSTRING = r"""
    Args:
"""

KOSMOS2_INPUTS_DOCSTRING = r"""
    Args:
"""


@dataclass
class Kosmos2ModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
    # 最后一层模型的隐藏状态,形状为(batch_size, sequence_length, hidden_size)
    last_hidden_state: torch.FloatTensor = None
    
    # 过去的键-值对,可选参数,形状为(config.n_layers, 2, batch_size, num_heads, sequence_length, embed_size_per_head),用于加速顺序解码
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    
    # 模型每一层的隐藏状态的元组,如果模型有嵌入层,则包括嵌入层输出,形状为(batch_size, sequence_length, hidden_size)
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    
    # 自注意力机制每一层的注意力权重的元组,形状为(batch_size, num_heads, sequence_length, sequence_length)
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    
    # 图像嵌入的隐藏状态,形状为(batch_size, latent_query_num, hidden_size),可选参数
    image_embeds: Optional[torch.FloatTensor] = None
    # 定义一个可选类型的变量 projection_attentions,可能是一个包含 torch.FloatTensor 的元组,初始值为 None
    projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
    
    # 定义一个变量 vision_model_output,类型为 BaseModelOutputWithPooling,初始值为 None
    vision_model_output: BaseModelOutputWithPooling = None
    
    # 定义一个方法 to_tuple,返回一个元组,包含对象所有键对应的值,但对于键为"text_model_output"和"vision_model_output"的情况,返回它们的 to_tuple() 方法的结果
    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
@dataclass
class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Model output class for `Kosmos2ForConditionalGeneration`.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
            the weighted average in the self-attention heads.
        vision_model_output(`BaseModelOutputWithPooling`, *optional*):
            The output of the [`Kosmos2VisionModel`].
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
    """
    # 定义可选的损失张量
    loss: Optional[torch.FloatTensor] = None
    # 定义空的 logits 张量
    logits: torch.FloatTensor = None
    # 定义可选的过去键值元组,包含 FloatTensor
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # 定义可选的隐藏状态元组,包含 FloatTensor
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # 定义可选的注意力元组,包含 FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # 定义可选的图像嵌入张量
    image_embeds: Optional[torch.FloatTensor] = None
    # 定义可选的投影注意力元组,包含 FloatTensor
    projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # 定义空的视觉模型输出,类型为 BaseModelOutputWithPooling
    vision_model_output: BaseModelOutputWithPooling = None

    # 转换为元组的方法,返回包含所有非 "text_model_output" 和 "vision_model_output" 的属性的元组
    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
# 从transformers.models.clip.modeling_clip.CLIPVisionEmbeddings复制而来,修改为Kosmos2
class Kosmos2VisionEmbeddings(nn.Module):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size  # 设置嵌入维度为配置文件中的隐藏大小
        self.image_size = config.image_size  # 设置图像大小为配置文件中的图像大小
        self.patch_size = config.patch_size  # 设置补丁大小为配置文件中的补丁大小

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))  # 定义类别嵌入作为可学习参数

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )  # 定义补丁嵌入为二维卷积层,用于从图像像素值生成嵌入向量

        self.num_patches = (self.image_size // self.patch_size) ** 2  # 计算图像中的补丁数量
        self.num_positions = self.num_patches + 1  # 计算位置嵌入的数量
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)  # 定义位置嵌入为一个嵌入层
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)  # 注册位置 ID,用于序列位置编码

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]  # 获取批次大小
        target_dtype = self.patch_embedding.weight.dtype  # 获取目标数据类型
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # 使用补丁嵌入层处理像素值,生成补丁嵌入向量

        # 展开补丁嵌入向量并进行维度转换
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)  # 扩展类别嵌入以适应批次大小
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)  # 连接类别嵌入和补丁嵌入,形成最终嵌入向量
        embeddings = embeddings + self.position_embedding(self.position_ids)  # 添加位置嵌入到最终嵌入向量中
        return embeddings
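
# 附注(独立演示代码,不属于原文件):按上面 Kosmos2VisionEmbeddings 的默认视觉配置
# (image_size=224, patch_size=14, hidden_size=1024)做一次形状演算,只用 torch 基础算子复现各步骤的形状。
import torch
from torch import nn

demo_patch_embedding = nn.Conv2d(3, 1024, kernel_size=14, stride=14, bias=False)
demo_pixel_values = torch.randn(2, 3, 224, 224)
demo_patch_embeds = demo_patch_embedding(demo_pixel_values)                 # (2, 1024, 16, 16)
demo_patch_embeds = demo_patch_embeds.flatten(2).transpose(1, 2)            # (2, 256, 1024)
demo_class_embeds = torch.randn(1, 1, 1024).expand(2, -1, -1)               # (2, 1, 1024)
demo_embeddings = torch.cat([demo_class_embeds, demo_patch_embeds], dim=1)  # 拼接 class token
print(demo_embeddings.shape)                                                # torch.Size([2, 257, 1024])
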


# 从transformers.models.clip.modeling_clip.CLIPAttention复制而来,修改为Kosmos2Vision
class Kosmos2VisionAttention(nn.Module):
    """来自 'Attention Is All You Need' 论文的多头注意力机制"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size  # 设置嵌入维度为配置文件中的隐藏大小
        self.num_heads = config.num_attention_heads  # 设置注意力头数为配置文件中的注意力头数
        self.head_dim = self.embed_dim // self.num_heads  # 计算每个注意力头的维度
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: "
                f"{self.num_heads})."
            )
        self.scale = self.head_dim**-0.5  # 缩放因子,用于缩放注意力分数
        self.dropout = config.attention_dropout  # 设置注意力层的dropout比例

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)  # 初始化键的投影层
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)  # 初始化值的投影层
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)  # 初始化查询的投影层
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)  # 初始化输出的投影层

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
    # 定义一个方法 `forward`,用于模型前向传播操作
    def forward(
        self,
        # 输入参数:表示模型当前隐藏状态的张量
        hidden_states: torch.Tensor,
        # 输入参数:可选的注意力掩码张量,用于指示哪些位置需要注意
        attention_mask: Optional[torch.Tensor] = None,
        # 输入参数:可选的因果注意力掩码张量,用于自回归任务的自注意力
        causal_attention_mask: Optional[torch.Tensor] = None,
        # 输入参数:是否输出注意力权重,默认为 False
        output_attentions: Optional[bool] = False,
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
class Kosmos2VisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]  # 使用配置中的激活函数
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)  # 创建全连接层 fc1
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)  # 创建全连接层 fc2

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)  # 应用 fc1
        hidden_states = self.activation_fn(hidden_states)  # 应用激活函数
        hidden_states = self.fc2(hidden_states)  # 应用 fc2
        return hidden_states


# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Kosmos2Vision
class Kosmos2VisionEncoderLayer(nn.Module):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size  # 设置嵌入维度为隐藏尺寸
        self.self_attn = Kosmos2VisionAttention(config)  # 创建 Kosmos2VisionAttention 自注意力层
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)  # 创建 LayerNorm 层1
        self.mlp = Kosmos2VisionMLP(config)  # 创建 Kosmos2VisionMLP 多层感知器
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)  # 创建 LayerNorm 层2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): mask indicating the causal nature of attention
            output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers.
        """
        residual = hidden_states  # 记录残差连接

        hidden_states = self.layer_norm1(hidden_states)  # 应用 LayerNorm 层1
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )  # 应用自注意力机制层,并返回注意力权重
        hidden_states = residual + hidden_states  # 残差连接

        residual = hidden_states  # 记录残差连接
        hidden_states = self.layer_norm2(hidden_states)  # 应用 LayerNorm 层2
        hidden_states = self.mlp(hidden_states)  # 应用 MLP 层
        hidden_states = residual + hidden_states  # 残差连接

        outputs = (hidden_states,)  # 输出结果为 hidden_states

        if output_attentions:
            outputs += (attn_weights,)  # 如果需要输出注意力权重,加入输出结果

        return outputs


# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Kosmos2Vision
class Kosmos2VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Kosmos2VisionEncoderLayer`].

    Args:
        config: Kosmos2VisionConfig
    """

    # 初始化方法,接收一个 Kosmos2VisionConfig 类型的配置对象作为参数
    def __init__(self, config: Kosmos2VisionConfig):
        # 调用父类的初始化方法
        super().__init__()
        # 将传入的配置对象保存到实例变量中
        self.config = config
        # 创建一个包含多个 Kosmos2VisionEncoderLayer 实例的列表,列表长度为 config.num_hidden_layers
        self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # 设置梯度检查点标志为 False
        self.gradient_checkpointing = False

    # 前向传播方法,接收多个参数
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
# 类定义,实现了一个类似于 `transformers.models.clip.modeling_clip.CLIPVisionTransformer` 的模型,但没有为 `forward` 方法添加文档字符串
class Kosmos2VisionTransformer(nn.Module):
    # 构造函数,接受一个 `Kosmos2VisionConfig` 类型的参数 `config`
    # 初始化父类 `nn.Module`
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        # 实例化 `Kosmos2VisionEmbeddings` 类,用于嵌入层处理
        self.embeddings = Kosmos2VisionEmbeddings(config)
        # LayerNorm 层,对嵌入向量进行归一化
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        # `Kosmos2VisionEncoder` 类,用于编码器的处理
        self.encoder = Kosmos2VisionEncoder(config)
        # 再次应用 LayerNorm 层,对输出进行归一化
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    # 前向传播函数,接受多个参数,返回一个元组或者 `BaseModelOutputWithPooling` 类型
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # 如果 `output_attentions` 未指定,则使用配置中的默认值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # 如果 `output_hidden_states` 未指定,则使用配置中的默认值
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果 `return_dict` 未指定,则使用配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 如果 `pixel_values` 为空,则抛出值错误异常
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # 通过嵌入层处理 `pixel_values`,得到隐藏状态
        hidden_states = self.embeddings(pixel_values)
        # 对隐藏状态应用预 LayerNorm 层进行归一化
        hidden_states = self.pre_layrnorm(hidden_states)

        # 将归一化后的隐藏状态传递给编码器 `self.encoder` 进行编码
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取编码器输出的最后一个隐藏状态
        last_hidden_state = encoder_outputs[0]
        # 从最后一个隐藏状态中提取池化输出,通常是第一个位置的输出
        pooled_output = last_hidden_state[:, 0, :]
        # 对池化输出应用后 LayerNorm 层进行归一化
        pooled_output = self.post_layernorm(pooled_output)

        # 如果不需要返回字典,则返回一个元组,包含最后一个隐藏状态、池化输出以及额外的编码器输出
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        # 否则,返回一个 `BaseModelOutputWithPooling` 对象,包含最后一个隐藏状态、池化输出、所有隐藏状态以及注意力权重
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


# 类定义,实现了一个类似于 `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` 的模块,但允许传递 `position_ids`
class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    # 从 transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__ 复制而来
    # 初始化函数,用于设置位置编码的参数
    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        # 设置位置编码的偏移量为2
        self.offset = 2
        # 设定位置编码的维度
        self.embedding_dim = embedding_dim
        # 可选参数:填充索引
        self.padding_idx = padding_idx
        # 调用make_weights方法生成位置编码权重
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    # 从 transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights 复制而来
    # 生成位置编码权重的方法
    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        # 调用get_embedding方法获取嵌入向量权重
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        # 如果self中已经有了weights属性,则在forward方法中将权重转换成正确的数据类型和设备
        if hasattr(self, "weights"):
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        # 将生成的权重注册为缓冲区,不持久化保存
        self.register_buffer("weights", emb_weights, persistent=False)

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding
    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        构建正弦位置编码的嵌入向量。

        该方法与tensor2tensor中的实现匹配,但与《Attention Is All You Need》中第3.5节的描述略有不同。
        """
        # Half of the dimensions hold sine terms, the other half cosine terms.
        half_dim = embedding_dim // 2
        # Log-spaced inverse frequencies: exp(-i * log(10000) / (half_dim - 1)).
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        # Outer product of positions and inverse frequencies gives the angles.
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        # Concatenate sin and cos to form the final table.
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        # Zero-pad the last column when the embedding dimension is odd.
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        # Zero out the embedding at the padding index, if one is given.
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    # Look up positional embeddings for the requested positions (no gradients are needed for the fixed table).
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
            if position_ids is None:
                # Create position ids from the token ids; padded tokens keep the padding position.
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                ).to(input_ids.device)
        else:
            bsz, seq_len = inputs_embeds.size()[:-1]
            if position_ids is None:
                # No token ids are available, so derive sequential position ids from the embeddings.
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

        # Enlarge the weight table if the requested positions exceed its current size.
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        # Gather the embeddings for each position and reshape to (batch, seq_len, embedding_dim).
        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds
    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
        """
        直接提供 embeddings。无法推断哪些是填充的,因此生成顺序的 position ids。

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        # 获取输入 embeddings 的形状,排除最后一维
        input_shape = inputs_embeds.size()[:-1]
        # 获取序列长度
        sequence_length = input_shape[1]

        # 根据序列长度、padding_idx 和设备类型,在设备上创建 long 类型的序列 tensor
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        # Broadcast to the batch shape and shift by the length of the cached past.
        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
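
# Illustration (not part of the Transformers source): the table built by `get_embedding` follows the
# tensor2tensor layout, with all sine terms in the first half of the embedding and all cosine terms in the
# second half. A minimal standalone sketch of the same formula (even embedding dimension assumed):

import math

import torch


def _sinusoidal_table(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    inv_freq = torch.exp(torch.arange(half_dim).float() * -scale)                # (half_dim,)
    angles = torch.arange(num_embeddings).float()[:, None] * inv_freq[None, :]   # (num_embeddings, half_dim)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)              # (num_embeddings, embedding_dim)


print(_sinusoidal_table(8, 6).shape)  # torch.Size([8, 6])
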
class KosmosTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`.
    def __init__(
        self,
        config,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        add_inner_attn_layernorm: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads  # dimension of each attention head

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # attention scaling factor
        self.is_decoder = is_decoder

        # Query/key/value and output projections.
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Optional LayerNorm applied to the attention output before the output projection.
        self.inner_attn_ln = None
        if add_inner_attn_layernorm:
            self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def _shape(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim)
        # Reshape the projection for multi-head attention:
        # (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    # Forward pass of the feed-forward network (`Kosmos2TextFFN`; its class header is omitted in this excerpt):
    # fc1 -> activation -> dropout -> LayerNorm -> fc2 -> dropout.
    def forward(self, hidden_states):
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states
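
# Illustration (not part of the Transformers source): the `_shape` helper of `KosmosTextAttention` above splits
# the last dimension into heads and moves the head axis in front of the sequence axis, i.e.
# (B, T, H*D) -> (B, H, T, D). A minimal sketch of the same reshape with illustrative sizes:

import torch

_B, _T, _H, _D = 2, 5, 4, 8
_proj = torch.randn(_B, _T, _H * _D)
_heads = _proj.view(_B, _T, _H, _D).permute(0, 2, 1, 3)
print(_heads.shape)  # torch.Size([2, 4, 5, 8])
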
# A single decoder layer of the KOSMOS-2 text model.
class Kosmos2TextBlock(nn.Module):
    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.embed_dim = config.embed_dim

        # Self-attention (decoder style, with the inner attention LayerNorm enabled).
        self.self_attn = KosmosTextAttention(
            config,
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            add_inner_attn_layernorm=True,
        )
        
        self.dropout = config.dropout
        # Pre-norm LayerNorm for the self-attention sub-layer.
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        # Optional cross-attention over encoder hidden states (no inner attention LayerNorm).
        if config.add_cross_attention:
            self.encoder_attn = KosmosTextAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                add_inner_attn_layernorm=False,
            )
            # Pre-norm LayerNorm for the cross-attention sub-layer.
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        # Feed-forward network and its pre-norm LayerNorm.
        self.ffn = Kosmos2TextFFN(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Decoder layer forward pass.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Keep the input for the residual connection.
        residual = hidden_states

        # Self Attention
        # The cached key/values for the uni-directional self-attention sit at positions 1, 2 of past_key_value.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Pre-norm: apply LayerNorm before self-attention.
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self-attention (with optional cached key/values).
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )

        # Dropout, then add the residual.
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None

        # Run cross-attention when encoder hidden states are provided.
        if encoder_hidden_states is not None:
            # Cross-attention layers must have been created by setting `config.add_cross_attention=True`.
            if not hasattr(self, "encoder_attn"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # Keep the input for the residual connection.
            residual = hidden_states

            # Pre-norm: apply LayerNorm before cross-attention.
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # The cached key/values for cross-attention sit at positions 3, 4 of past_key_value.
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None

            # Cross-attention over the encoder hidden states.
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )

            # Dropout, then add the residual.
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

            # Append the cross-attention key/values at positions 3, 4 of present_key_value.
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        # Keep the input for the residual connection and apply the pre-norm LayerNorm.
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # Feed Forward Network (FFN)
        hidden_states = self.ffn(hidden_states)

        # Add the residual.
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        # Optionally return the self-attention and cross-attention weights.
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        # Optionally return the present key/values for caching.
        if use_cache:
            outputs += (present_key_value,)

        return outputs
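
# Illustration (not part of the Transformers source): `Kosmos2TextBlock` uses the pre-norm residual pattern —
# LayerNorm is applied *before* each sub-layer and the residual is added afterwards, x = x + sublayer(LN(x)).
# A minimal sketch of that pattern with a stand-in sub-layer:

import torch
from torch import nn

_dim = 16
_ln = nn.LayerNorm(_dim)
_sublayer = nn.Linear(_dim, _dim)   # stands in for self-attention / cross-attention / FFN

_x = torch.randn(2, 5, _dim)
_x = _x + _sublayer(_ln(_x))        # pre-norm residual connection
print(_x.shape)                     # torch.Size([2, 5, 16])
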
    """
    Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].

    Args:
        config: Kosmos2TextConfig
    """
    
    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop

        self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0  # embedding scaling factor
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)

        self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
            num_positions=config.max_position_embeddings,
            embedding_dim=config.embed_dim,
            padding_idx=config.pad_token_id,
        )  # sinusoidal positional embeddings

        self.layers = nn.ModuleList([Kosmos2TextBlock(config) for _ in range(config.layers)])  # decoder layers
        self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)  # final LayerNorm

        self.gradient_checkpointing = False

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # Create the causal attention mask.
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # Expand the padding mask to the attention-mask shape
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )  # combine the causal mask and the padding mask

        return combined_attention_mask

    def forward_embedding(
        self,
        input_ids,
        inputs_embeds: torch.Tensor = None,
        image_embeds: torch.Tensor = None,
        img_input_mask: torch.Tensor = None,
        past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        # Fall back to token embeddings computed from `input_ids` when `inputs_embeds` is not given.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Splice the image embeddings into `inputs_embeds` at the positions marked by `img_input_mask`.
        if image_embeds is not None:
            inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
                -1, image_embeds.size(-1)
            )

        # Scale the embeddings.
        inputs_embeds = inputs_embeds * self.embed_scale

        # Add positional embeddings.
        positions = self.embed_positions(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            position_ids=position_ids,
        )
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions

        # Apply dropout (active only in training).
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states
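
# Illustration (not part of the Transformers source): `forward_embedding` splices the projected image features
# into the token embeddings at the positions marked by `img_input_mask`. A minimal sketch of that boolean-mask
# scatter with illustrative shapes:

import torch

_inputs_embeds = torch.zeros(1, 6, 4)                              # (batch, seq_len, embed_dim)
_image_embeds = torch.ones(1, 2, 4)                                # two image "slots"
_img_mask = torch.tensor([[0, 1, 1, 0, 0, 0]], dtype=torch.bool)   # where the image tokens live
_inputs_embeds[_img_mask] = _image_embeds.view(-1, 4)              # scatter image features into the sequence
print(_inputs_embeds[0, 1])                                        # tensor([1., 1., 1., 1.])
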
class Kosmos2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # The configuration class used by this model.
    config_class = Kosmos2Config
    # Gradient checkpointing is supported.
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices.
    _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]

class Kosmos2VisionModel(Kosmos2PreTrainedModel):
    config_class = Kosmos2VisionConfig
    main_input_name = "pixel_values"

    # Adapted from CLIPVisionModel.__init__ with the CLIP names replaced by Kosmos2Vision ones.
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__(config)
        # Build the vision transformer.
        self.model = Kosmos2VisionTransformer(config)
        # Initialize weights and apply final processing.
        self.post_init()

    # Adapted from CLIPVisionModel.get_input_embeddings with the CLIP names replaced by Kosmos2Vision ones.
    def get_input_embeddings(self) -> nn.Module:
        # The patch embedding acts as the input embedding of the vision model.
        return self.model.embeddings.patch_embedding

    # Decorators add the shared input docstring and the return docstring to the forward method.
    @add_start_docstrings_to_model_forward(KOSMOS2_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Kosmos2VisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        前向传播方法,接受像素值作为输入,可选输出注意力、隐藏状态和返回字典。

        Returns:
            返回模型的输出,可能是元组或带池化的基础模型输出。
        """
        return self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class Kosmos2TextModel(Kosmos2PreTrainedModel):
    config_class = Kosmos2TextConfig

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)
        # Build the text transformer.
        self.model = Kosmos2TextTransformer(config)
        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The token embeddings of the inner text transformer.
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # Decorators add the shared input docstring and the return docstring to the forward method.
    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        将输入参数传递给模型,并返回模型的输出。

        Parameters:
        - input_ids (Optional[torch.Tensor]): 输入的 token IDs 序列,默认为 None。
        - attention_mask (Optional[torch.Tensor]): 注意力遮罩张量,默认为 None。
        - image_embeds (Optional[torch.Tensor]): 图像嵌入张量,默认为 None。
        - image_embeds_position_mask (Optional[torch.Tensor]): 图像嵌入的位置遮罩张量,默认为 None。
        - encoder_hidden_states (Optional[torch.Tensor]): 编码器的隐藏状态张量,默认为 None。
        - encoder_attention_mask (Optional[torch.Tensor]): 编码器的注意力遮罩张量,默认为 None。
        - head_mask (Optional[torch.Tensor]): 头部遮罩张量,默认为 None。
        - cross_attn_head_mask (Optional[torch.Tensor]): 跨注意力头部遮罩张量,默认为 None。
        - past_key_values (Optional[List[torch.FloatTensor]]): 过去的键值对列表,默认为 None。
        - inputs_embeds (Optional[torch.Tensor]): 输入的嵌入张量,默认为 None。
        - position_ids (Optional[torch.Tensor]): 位置 ID 张量,默认为 None。
        - use_cache (Optional[bool]): 是否使用缓存,默认为 None。
        - output_attentions (Optional[bool]): 是否输出注意力权重,默认为 None。
        - output_hidden_states (Optional[bool]): 是否输出隐藏状态,默认为 None。
        - return_dict (Optional[bool]): 是否返回字典格式的输出,默认为 None。

        Returns:
        - Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 返回模型的输出,可能是一个元组或特定的输出类对象。

        """
        # 调用模型的 forward 方法,将所有参数传递给模型,并返回模型的输出结果
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
"""
The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
embeddings).
"""
# 基于KOSMOS-2的文本模型,顶部带有语言建模头部(线性层,权重与输入嵌入绑定)。

# 使用装饰器添加文档字符串到类的开头
@add_start_docstrings(
    KOSMOS2_START_DOCSTRING,  # 使用预定义的起始文档字符串
)
class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel):
    config_class = Kosmos2TextConfig
    _tied_weights_keys = ["lm_head.weight"]  # keys whose weights are tied to the input embeddings

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)

        # Text transformer plus a bias-free language modeling head projecting to the vocabulary.
        self.model = Kosmos2TextTransformer(config)
        self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:
            Depending on `return_dict`, either a tuple with `loss` and various outputs or an instance of
            `CausalLMOutputWithCrossAttentions` containing `loss`, `logits`, and other relevant model outputs.

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False  # caching is disabled when labels are provided

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,  # possibly disabled above when labels are given
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])  # project the decoder hidden states to vocabulary logits

        loss = None
        if labels is not None:
            # Move the labels to the logits' device to support model parallelism.
            labels = labels.to(lm_logits.device)
            # Shift so that tokens at position t predict the token at position t + 1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens and compute the cross-entropy loss.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
    ):
        input_shape = input_ids.shape
        # If no attention mask is provided, attend to everything: create an all-ones mask with the input shape.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        position_ids = None

        # With cached past key/values, only the last token needs a position id.
        if past_key_values is not None:
            position_ids = create_position_ids_from_input_ids(
                input_ids,
                padding_idx=self.config.pad_token_id,
                past_key_values_length=0,
            )[:, -1:]

            # Keep only the last token; earlier tokens are already encoded in the past key/values.
            input_ids = input_ids[:, -1:]
            # The image information is already encoded into the past key/values, so no image embeddings are needed.
            image_embeds = None
            image_embeds_position_mask = None
        elif image_embeds_position_mask is not None:
            # Append `False` entries to `image_embeds_position_mask` (`input_ids` grows during generation).
            batch_size, seq_len = input_ids.size()
            mask_len = image_embeds_position_mask.size()[-1]
            image_embeds_position_mask = torch.cat(
                (
                    image_embeds_position_mask,
                    torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
                ),
                dim=1,
            )

        return {
            "input_ids": input_ids,
            "image_embeds": image_embeds,
            "image_embeds_position_mask": image_embeds_position_mask,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": use_cache,
        }

    @staticmethod
    # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                # Reorder the cached states along the batch dimension according to beam_idx.
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
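
# Illustration (not part of the Transformers source): the loss in `Kosmos2TextForCausalLM.forward` above is the
# standard next-token shift — logits at position t are scored against the label at position t + 1. A minimal
# sketch with random logits and illustrative sizes:

import torch
from torch.nn import CrossEntropyLoss

_vocab = 11
_logits = torch.randn(2, 7, _vocab)                   # (batch, seq_len, vocab_size)
_labels = torch.randint(0, _vocab, (2, 7))            # (batch, seq_len)

_shift_logits = _logits[..., :-1, :].contiguous()     # drop the last position
_shift_labels = _labels[..., 1:].contiguous()         # drop the first position
_loss = CrossEntropyLoss()(_shift_logits.view(-1, _vocab), _shift_labels.view(-1))
print(_loss.item())
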
class Kosmos2ImageToTextProjection(nn.Module):
    """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""

    def __init__(self, config: Kosmos2Config):
        super().__init__()
        # Linear layer mapping the vision hidden size to the text embedding dimension.
        self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
        # Learnable latent query vectors that attend over the image features.
        self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))

        # Attention module that compresses the image features into the latent queries.
        self.x_attn = KosmosTextAttention(
            config.text_config,
            config.text_config.embed_dim,
            config.text_config.attention_heads,
            dropout=config.text_config.attention_dropout,
            is_decoder=False,
            add_inner_attn_layernorm=False,
        )

    def forward(self, features):
        # Project the image features into the text embedding space.
        hidden_states = self.dense(features)

        # shape = [batch, latent_query_num, h_dim]
        # Expand the latent queries to the batch size and use [features; queries] as the key/value states.
        latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
        key_value_states = torch.cat([hidden_states, latent_query], dim=1)

        # Attend the latent queries over the key/value states.
        hidden_states, attn_weights, _ = self.x_attn(
            hidden_states=latent_query,
            encoder_hidden_states=key_value_states,
            past_key_value=None,
            attention_mask=None,
            output_attentions=None,
        )

        # Return the compressed image features and the attention weights.
        return hidden_states, attn_weights
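
# Illustration (not part of the Transformers source): the projection above attends a fixed number of learned
# "latent query" vectors over the image features, so a variable number of image patches is compressed into
# `latent_query_num` text-space vectors. A minimal shape-only sketch of that idea; `nn.MultiheadAttention` is a
# stand-in for `KosmosTextAttention`, and all sizes are illustrative:

import torch
from torch import nn

_num_patches, _vision_dim, _text_dim, _latent_query_num = 257, 1024, 2048, 64

_dense = nn.Linear(_vision_dim, _text_dim)
_latent_query = nn.Parameter(torch.randn(_latent_query_num, _text_dim))
_attn = nn.MultiheadAttention(_text_dim, num_heads=8, batch_first=True)

_features = torch.randn(2, _num_patches, _vision_dim)
_hidden = _dense(_features)                                              # (2, 257, text_dim)
_queries = _latent_query.unsqueeze(0).expand(_hidden.size(0), -1, -1)    # (2, 64, text_dim)
_key_value = torch.cat([_hidden, _queries], dim=1)                       # (2, 321, text_dim)
_out, _ = _attn(_queries, _key_value, _key_value)
print(_out.shape)                                                        # torch.Size([2, 64, 2048])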


@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2Model(Kosmos2PreTrainedModel):
    config_class = Kosmos2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)

        # Text model, vision model, and the image-to-text projection.
        self.text_model = Kosmos2TextModel(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The text model's token embeddings act as the input embeddings.
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC)
    # Forward pass of the full model; every argument is optional.
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
"""
KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
language model.
"""
@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
    language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel):
    config_class = Kosmos2Config
    main_input_name = "pixel_values"
    # Keys whose weights are tied to the input embeddings.
    _tied_weights_keys = ["text_model.lm_head.weight"]

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)

        # Text model (causal LM), vision model, and the image-to-text projection.
        self.text_model = Kosmos2TextForCausalLM(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)

        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.text_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.text_model.set_output_embeddings(new_embeddings)

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # See the input/output docstrings above for the detailed parameters and return values (body omitted in this excerpt).
        pass

    def generate(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Generate text (and, for grounded prompts, location/bounding-box tokens) from an image and an optional text prompt.
        # Allow the `inputs` argument (as required by `GenerationMixin`).
        inputs = kwargs.pop("inputs", None)
        # Passing both `inputs` and `pixel_values` is ambiguous and therefore not allowed.
        if pixel_values is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed."
                f"Make sure to either pass `inputs` or pixel_values=..."
            )
        # Fall back to `inputs` when `pixel_values` is not given.
        if pixel_values is None and inputs is not None:
            pixel_values = inputs

        # Compute the image embeddings from the pixel values when they are not provided.
        if image_embeds is None:
            vision_model_output = self.vision_model(pixel_values)
            # The whole `last_hidden_state` goes through `post_layernorm` instead of only the pooled output.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # Normalize the features and project them into the text embedding space.
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        # Delegate generation to the text model, conditioning on the image embeddings.
        output = self.text_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            **kwargs,
        )

        # Return the generated token ids.
        return output
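
# A hedged end-to-end usage sketch (not part of the source). It assumes the public
# "microsoft/kosmos-2-patch14-224" checkpoint and that the processor returns `pixel_values`, `input_ids`,
# `attention_mask` and `image_embeds_position_mask`; adjust the checkpoint, prompt and keys as needed:

import requests
from PIL import Image

from transformers import AutoProcessor, Kosmos2ForConditionalGeneration

model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<grounding>An image of"

inputs = processor(text=prompt, images=image, return_tensors="pt")
generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    image_embeds_position_mask=inputs["image_embeds_position_mask"],
    max_new_tokens=64,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)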