
Transformers 源码解析(六十)


Processor class for InstructBLIP. Largely a copy of Blip2Processor, with the addition of a tokenizer for the Q-Former.

import os
from typing import List, Optional, Union

# Importing necessary modules from the package
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from import AutoTokenizer

class InstructBlipProcessor(ProcessorMixin):
    r"""
    Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
    processor.

    [`InstructBlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the
    docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.

    Args:
        image_processor (`BlipImageProcessor`):
            An instance of [`BlipImageProcessor`]. The image processor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
        qformer_tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
    """

    # Only these two components are passed to ProcessorMixin; the Q-Former tokenizer is handled by the
    # save_pretrained / from_pretrained overrides below.
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "BlipImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, qformer_tokenizer):
        super().__init__(image_processor, tokenizer)

        # Not listed in `attributes`, so it is stored as a plain attribute.
        self.qformer_tokenizer = qformer_tokenizer

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare text and/or images for the model.

        Text is encoded twice, with identical arguments:
        - with `tokenizer`, whose outputs are merged into the result under their usual keys (e.g. `input_ids`);
        - with `qformer_tokenizer`, whose outputs are stored as `qformer_input_ids` and, when the tokenizer produced
          an attention mask, `qformer_attention_mask`.

        Images are processed with [`BlipImageProcessor.__call__`] and merged into the same [`BatchFeature`].

        Raises:
            ValueError: If neither `images` nor `text` is provided.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least images or text.")

        encoding = BatchFeature()

        if text is not None:
            # Both tokenizers receive exactly the same arguments so their outputs line up.
            tokenizer_kwargs = dict(
                text=text,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_token_type_ids=return_token_type_ids,
                return_length=return_length,
                verbose=verbose,
                return_tensors=return_tensors,
                **kwargs,
            )
            text_encoding = self.tokenizer(**tokenizer_kwargs)
            encoding.update(text_encoding)

            qformer_text_encoding = self.qformer_tokenizer(**tokenizer_kwargs)
            encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            # The mask is absent when the caller disables it (e.g. return_attention_mask=False); skip it rather
            # than raising KeyError.
            if "attention_mask" in qformer_text_encoding:
                encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

        if images is not None:
            image_encoding = self.image_processor(images, return_tensors=return_tensors)
            encoding.update(image_encoding)

        return encoding

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in first-seen order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # overwrite to save the Q-Former tokenizer in a separate folder
    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the processor to `save_directory`; the Q-Former tokenizer goes to its `qformer_tokenizer` subfolder.

        Raises:
            ValueError: If `save_directory` is an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer")
        self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path)
        return super().save_pretrained(save_directory, **kwargs)

    # overwrite to load the Q-Former tokenizer from a separate folder
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the processor, reading the Q-Former tokenizer from the `qformer_tokenizer` subfolder."""
        qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        # `__init__` takes (image_processor, tokenizer, qformer_tokenizer): the Q-Former tokenizer comes last.
        args.append(qformer_tokenizer)
        return cls(*args)


# 导入类型检查模块
from typing import TYPE_CHECKING

# 导入自定义的异常类和模块加载延迟处理类
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Lazy-import table: submodule name -> public names it exports.
# NOTE(review): the contents of the configuration list and the closing brackets are not visible in this listing.
_import_structure = {
    "configuration_instructblip": [
    "processing_instructblip": ["InstructBlipProcessor"],

# Register the torch-only modeling submodule, guarded by a torch availability check.
# NOTE(review): the `try:` header appears to be missing. As shown, the assignment below sits directly under
# `except OptionalDependencyNotAvailable:`, i.e. it would run only when torch is NOT available; the usual
# pattern is `except ...: pass` followed by an `else:` branch holding this assignment — confirm.
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    _import_structure["modeling_instructblip"] = [

# Static type-checking branch: import the public names eagerly so type checkers can resolve them
# (presumably under an `if TYPE_CHECKING:` header that is not visible here).
    # Configuration classes.
    from .configuration_instructblip import (
    # Processor class.
    from .processing_instructblip import InstructBlipProcessor

        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # Modeling classes (require torch). NOTE(review): same apparent `except`/`else` inversion as above.
        from .modeling_instructblip import (

# Runtime branch: replace this module in sys.modules with a _LazyModule built from `_import_structure`
# (presumably deferring submodule imports until first access).
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)


import os
from typing import List, Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging

# Module-level logger.
logger = logging.get_logger(__name__)

# Map of Jukebox checkpoint names to their pretrained configuration locations.
# NOTE(review): the assignment header and closing brace are missing from this listing, and both values are empty
# strings as shown (presumably stripped URLs) — confirm against the upstream file.
    "openai/jukebox-5b-lyrics": "",
    "openai/jukebox-1b-lyrics": "",

# Per-layer attention-type schedule for the large lyrics model, looked up by `large_separated_enc_dec_w_lyrics`.
# NOTE(review): the assignment header (presumably `_LARGE_ATTENTION = [`) and the closing bracket are missing from
# this listing. As shown the list holds 59 entries while the original lookup used `layer % 79` — confirm the
# full list length.
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "block_attn",
    "transpose_block_attn",
    "prev_block_attn",
    "cross_attention",
# Three-step cycle of block / transpose-block / previous-block attention (indexed by `layer % 3`).
_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
# Single dense attention type.
# NOTE(review): spelled "dense_attention" here but "dense_attn" in _PrimePrimeDenseAttention — confirm which name
# the attention implementation expects.
_FullDenseAttention = ["dense_attention"]
# Two prime attentions followed by a dense attention (indexed by `layer % 3`).
_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]

def full_dense_attention(layer):
    """Attention-type selector that ignores `layer` and always returns the first dense entry."""
    # `layer` is accepted only so this selector has the same call shape as its siblings.
    dense_patterns = _FullDenseAttention
    return dense_patterns[0]

def raw_column_previous_row_attention(layer):
    """Cycle through block / transpose-block / previous-block attention with a period of three layers."""
    cycle_position = layer % 3
    return _RawColumnPreviousRowAttention[cycle_position]

def large_separated_enc_dec_w_lyrics(layer):
    """Return the attention type for `layer` from the `_LARGE_ATTENTION` schedule, wrapping around its end."""
    # Wrap by the schedule's real length instead of a hard-coded 79: identical when the list has 79 entries, and
    # it can never index past the end if the list is shorter.
    return _LARGE_ATTENTION[layer % len(_LARGE_ATTENTION)]

def enc_dec_with_lyrics(layer):
    """
    Pick the attention type for `layer`.

    Every 16th layer (index 15, 31, ...) draws from the prime/prime/dense schedule; all other layers draw from the
    block / transpose-block / previous-block cycle. Both schedules are indexed by `layer % 3`.
    """
    schedule = _PrimePrimeDenseAttention if layer % 16 == 15 else _RawColumnPreviousRowAttention
    return schedule[layer % 3]

# Mapping from attention-pattern name to the selector that returns the attention type for a layer index.
# Presumably `JukeboxPriorConfig.attention_pattern` names one of these keys (its default, "enc_dec_with_lyrics",
# is one of them). NOTE(review): the assignment header and closing brace are missing from this listing — confirm.
    "full_dense_attention": full_dense_attention,
    "raw_column_previous_row_attention": raw_column_previous_row_attention,
    "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics,
    "enc_dec_with_lyrics": enc_dec_with_lyrics,

# Configuration holding the hyper-parameters of a single JukeboxPrior.
class JukeboxPriorConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
    `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the top level prior from the
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Every keyword argument of `__init__` is stored unchanged as an attribute of the same name, with one exception:
    `encoder_config` may be a dict or a `JukeboxPriorConfig`. Either way it is stored as a nested
    `JukeboxPriorConfig`, or as `None` when not given. Any remaining keyword arguments are forwarded to
    [`PretrainedConfig`].
    """

    model_type = "jukebox_prior"
    # Generic Transformers attribute names, aliased to the names this config actually stores. The class never
    # sets `n_positions` or `n_head`; it stores `n_ctx` and `n_heads`.
    attribute_map = {
        "max_position_embeddings": "n_ctx",
        "num_attention_heads": "n_heads",
    }

    def __init__(
        self,
        act_fn="quick_gelu",
        level=0,
        alignment_head=2,
        alignment_layer=68,
        attention_multiplier=0.25,
        attention_pattern="enc_dec_with_lyrics",
        attn_dropout=0,
        attn_res_scale=False,
        blocks=64,
        conv_res_scale=None,
        num_layers=72,
        emb_dropout=0,
        encoder_config=None,
        encoder_loss_fraction=0.4,
        hidden_size=2048,
        init_scale=0.2,
        is_encoder_decoder=True,
        lyric_vocab_size=80,
        mask=False,
        max_duration=600,
        max_nb_genres=1,
        merged_decoder=True,
        metadata_conditioning=True,
        metadata_dims=[604, 7898],
        min_duration=0,
        mlp_multiplier=1.0,
        music_vocab_size=2048,
        n_ctx=6144,
        n_heads=2,
        nb_relevant_lyric_tokens=384,
        res_conv_depth=3,
        res_conv_width=128,
        res_convolution_multiplier=1,
        res_dilation_cycle=None,
        res_dilation_growth_rate=1,
        res_downs_t=[3, 2, 2],
        res_strides_t=[2, 2, 2],
        resid_dropout=0,
        sampling_rate=44100,
        spread=None,
        timing_dims=64,
        zero_out=False,
        **kwargs,
    ):
        self.act_fn = act_fn
        self.alignment_head = alignment_head
        self.alignment_layer = alignment_layer
        self.attention_multiplier = attention_multiplier
        self.attention_pattern = attention_pattern
        self.attn_dropout = attn_dropout
        self.attn_res_scale = attn_res_scale
        self.blocks = blocks
        self.conv_res_scale = conv_res_scale
        self.num_layers = num_layers
        self.emb_dropout = emb_dropout
        self.music_vocab_size = music_vocab_size
        # The nested encoder config usually arrives as a plain dict (e.g. from JSON); config objects are kept as-is.
        if encoder_config is None:
            self.encoder_config = None
        elif isinstance(encoder_config, JukeboxPriorConfig):
            self.encoder_config = encoder_config
        else:
            self.encoder_config = JukeboxPriorConfig(**encoder_config)
        self.encoder_loss_fraction = encoder_loss_fraction
        self.init_scale = init_scale
        self.is_encoder_decoder = is_encoder_decoder
        self.lyric_vocab_size = lyric_vocab_size
        self.level = level
        self.mask = mask
        self.max_duration = max_duration
        self.max_nb_genres = max_nb_genres
        self.merged_decoder = merged_decoder
        self.metadata_conditioning = metadata_conditioning
        self.metadata_dims = metadata_dims
        self.min_duration = min_duration
        self.mlp_multiplier = mlp_multiplier
        self.n_ctx = n_ctx
        self.n_heads = n_heads
        self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
        self.res_conv_depth = res_conv_depth
        self.res_conv_width = res_conv_width
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_cycle = res_dilation_cycle
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_downs_t = res_downs_t
        self.res_strides_t = res_strides_t
        self.resid_dropout = resid_dropout
        self.sampling_rate = sampling_rate
        self.spread = spread
        self.timing_dims = timing_dims
        self.hidden_size = hidden_size
        self.zero_out = zero_out
        # PretrainedConfig.__init__ assigns `is_encoder_decoder` from its own kwargs (default False). Forward ours
        # explicitly so the base initializer does not reset it.
        super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
    ) -> "PretrainedConfig":
        """
        Load the configuration of the prior at index `level`.

        If the checkpoint holds a full `JukeboxConfig` (`model_type == "jukebox"`), the sub-config for `level` is
        extracted first. It is read from the `prior_{level}` key when present, otherwise from the
        `prior_config_list` entry written by `JukeboxConfig.to_dict`.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "jukebox":
            prior_key = f"prior_{level}"
            if prior_key in config_dict:
                config_dict = config_dict[prior_key]
            else:
                config_dict = config_dict["prior_config_list"][level]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
# Configuration holding the hyper-parameters of the JukeboxVQVAE.
class JukeboxVQVAEConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a
    `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VQVAE from
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function of the model.
        nb_discrete_codes (`int`, *optional*, defaults to 2048):
            Number of discrete codes of the VQVAE.
        commit (`float`, *optional*, defaults to 0.02):
            Commit loss multiplier.
        conv_input_shape (`int`, *optional*, defaults to 1):
            Number of audio channels.
        conv_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
        embed_dim (`int`, *optional*, defaults to 64):
            Embedding dimension of the codebook vectors.
        hop_fraction (`List[float]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
            Fraction of non-intersecting window used when continuing the sampling process.
        levels (`int`, *optional*, defaults to 3):
            Number of hierarchical levels used in the VQVAE.
        lmu (`float`, *optional*, defaults to 0.99):
            Exponential moving average coefficient used in the codebook update.
        multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            Depth and width multipliers used for each level.
        res_conv_depth (`int`, *optional*, defaults to 4):
            Depth of the encoder and decoder blocks.
        res_conv_width (`int`, *optional*, defaults to 32):
            Width of the encoder and decoder blocks.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle value used in the `JukeboxResnet`.
        res_dilation_growth_rate (`int`, *optional*, defaults to 3):
            ResNet dilation growth rate used in the VQVAE.
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rate for each level of the hierarchical VQ-VAE.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Stride used for each level of the hierarchical VQ-VAE.
        sample_length (`int`, *optional*, defaults to 1058304):
            Maximum input shape of the VQVAE.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scale.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """

    model_type = "jukebox_vqvae"

    def __init__(
        self,
        act_fn="relu",
        nb_discrete_codes=2048,
        commit=0.02,
        conv_input_shape=1,
        conv_res_scale=False,
        embed_dim=64,
        hop_fraction=[0.125, 0.5, 0.5],
        levels=3,
        lmu=0.99,
        multipliers=[2, 1, 1],
        res_conv_depth=4,
        res_conv_width=32,
        res_convolution_multiplier=1,
        res_dilation_cycle=None,
        res_dilation_growth_rate=3,
        res_downs_t=[3, 2, 2],
        res_strides_t=[2, 2, 2],
        sample_length=1058304,
        init_scale=0.2,
        zero_out=False,
        **kwargs,
    ):
        self.hop_fraction = hop_fraction
        self.conv_input_shape = conv_input_shape
        self.sample_length = sample_length

        # VQVAE parameters (all used)
        self.levels = levels
        self.embed_dim = embed_dim
        self.nb_discrete_codes = nb_discrete_codes
        self.res_conv_width = res_conv_width
        self.res_conv_depth = res_conv_depth
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_dilation_cycle = res_dilation_cycle
        self.multipliers = multipliers
        self.res_downs_t = res_downs_t
        self.res_strides_t = res_strides_t
        self.lmu = lmu
        self.commit = commit
        self.conv_res_scale = conv_res_scale
        self.act_fn = act_fn
        self.init_scale = init_scale
        self.zero_out = zero_out
        # Remaining kwargs (generic PretrainedConfig fields) go to the base class instead of being dropped.
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the VQVAE configuration, extracting `vqvae_config` when the checkpoint holds a full `JukeboxConfig`."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A full JukeboxConfig nests the VQVAE settings under "vqvae_config".
        if config_dict.get("model_type") == "jukebox":
            config_dict = config_dict["vqvae_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
class JukeboxConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
    yield a similar configuration to that of
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling =
    (5, 3) and strides = (2, 2) will downsample the audio by 2**5 = 32 to get the first level of codes, and by
    2**(5 + 3) = 256 to get the second level codes. This is mostly true for training the top level prior and the
    upsamplers.

    Args:
        vqvae_config (`JukeboxVQVAEConfig` or `dict`, *optional*):
            Configuration for the `JukeboxVQVAE` model. Defaults are used when omitted.
        prior_config_list (`List[JukeboxPriorConfig]` or `List[dict]`, *optional*):
            List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
            When omitted, `nb_priors` configs are built from `prior_{i}` keyword arguments (or defaults).
        nb_priors (`int`, *optional*, defaults to 3):
            Number of prior models that will sequentially sample tokens. Each prior is a conditional autoregressive
            (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
            trained using a top prior and 2 upsampler priors.
        sampling_rate (`int`, *optional*, defaults to 44100):
            Sampling rate of the raw audio.
        timing_dims (`int`, *optional*, defaults to 64):
            Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
            layer. The timing embedding layer converts the absolute and relative position in the currently sampled
            audio to a tensor of length `timing_dims` that will be added to the music tokens.
        min_duration (`int`, *optional*, defaults to 0):
            Minimum duration of the audios to generate.
        max_duration (`float`, *optional*, defaults to 600.0):
            Maximum duration of the audios to generate.
        max_nb_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres that can be used to condition a single sample.
        metadata_conditioning (`bool`, *optional*, defaults to `True`):
            Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
            duration.

    Example:

    ```python
    >>> from transformers import JukeboxModel, JukeboxConfig

    >>> # Initializing a Jukebox configuration
    >>> configuration = JukeboxConfig()

    >>> # Initializing a model from the configuration
    >>> model = JukeboxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "jukebox"

    def __init__(
        self,
        vqvae_config=None,
        prior_config_list=None,
        nb_priors=3,
        sampling_rate=44100,
        timing_dims=64,
        min_duration=0,
        max_duration=600.0,
        max_nb_genres=5,
        metadata_conditioning=True,
        **kwargs,
    ):
        if vqvae_config is None:
            vqvae_config = {}
            logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")

        # Accept a ready-made config object as well as the plain dict produced by `to_dict` / JSON.
        if isinstance(vqvae_config, JukeboxVQVAEConfig):
            self.vqvae_config = vqvae_config
        else:
            self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)

        if prior_config_list is not None:
            self.prior_configs = [
                config if isinstance(config, JukeboxPriorConfig) else JukeboxPriorConfig(**config)
                for config in prior_config_list
            ]
        else:
            # Older serialized configs store each prior under its own `prior_{i}` key.
            self.prior_configs = []
            for prior_idx in range(nb_priors):
                prior_config = kwargs.pop(f"prior_{prior_idx}", None)
                if prior_config is None:
                    prior_config = {}
                    logger.info(
                        f"prior_{prior_idx}'s  config is None. Initializing the JukeboxPriorConfig list with default"
                        " values."
                    )
                self.prior_configs.append(JukeboxPriorConfig(**prior_config))

        self.hop_fraction = self.vqvae_config.hop_fraction

        # Metadata conditioning parameters.
        self.nb_priors = nb_priors
        self.max_nb_genres = max_nb_genres
        self.sampling_rate = sampling_rate
        self.timing_dims = timing_dims
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.metadata_conditioning = metadata_conditioning

        super().__init__(**kwargs)

    @classmethod
    def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
        r"""
        Instantiate a [`JukeboxConfig`] (or a derived class) from a list of prior configurations and a VQVAE
        configuration.

        Returns:
            [`JukeboxConfig`]: An instance of a configuration object
        """
        prior_config_list = [config.to_dict() for config in prior_configs]
        # Pass the VQVAE settings as `vqvae_config`, the name `__init__` reads. Any other keyword would land in
        # `**kwargs`, and the model would silently get a default VQVAE.
        return cls(prior_config_list=prior_config_list, vqvae_config=vqvae_config.to_dict(), **kwargs)

    def to_dict(self):
        """Serialize to a dict, replacing the `prior_configs` objects with a `prior_config_list` of plain dicts."""
        result = super().to_dict()
        result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
        return result


import argparse  # 导入处理命令行参数的模块
import json  # 导入处理 JSON 数据的模块
import os  # 导入操作系统相关功能的模块
from pathlib import Path  # 导入处理路径相关操作的模块

import requests  # 导入发送 HTTP 请求的模块
import torch  # 导入 PyTorch 深度学习框架

from transformers import JukeboxConfig, JukeboxModel  # 导入 Jukebox 模型相关类
from transformers.utils import logging  # 导入日志记录模块

logging.set_verbosity_info()  # Emit INFO-level messages from the transformers logger.
logger = logging.get_logger(__name__)  # Logger for this script.

PREFIX = ""  # Prepended to each checkpoint path when downloading; empty here (presumably a stripped URL).
# Per-model lists of checkpoint files to download and convert; the first entry is treated as the VQVAE.
# NOTE(review): the assignment header (presumably `MODEL_MAPPING = {`, which convert_openai_checkpoint reads),
# the file lists and the closing brackets are missing from this listing — confirm.
    "jukebox-1b-lyrics": [
    "jukebox-5b-lyrics": [

def replace_key(key):
    """
    Map one OpenAI Jukebox parameter name onto the Transformers naming scheme.

    Rules run in a fixed order. Some only rewrite `key` and let later rules see the result; the others return the
    rewritten key immediately, so at most one of those applies. All rewrites use `str.replace`, i.e. every
    occurrence of the matched substring is replaced.
    """
    # Residual conv layers: only deeply nested names (more than ten dot-separated parts) qualify, and at most one
    # of these suffixes can match.
    if len(key.split(".")) > 10:
        for old_suffix, new_suffix in (
            (".model.1.bias", ".conv1d_1.bias"),
            (".model.1.weight", ".conv1d_1.weight"),
            (".model.3.bias", ".conv1d_2.bias"),
            (".model.3.weight", ".conv1d_2.weight"),
        ):
            if key.endswith(old_suffix):
                key = key.replace(old_suffix, new_suffix)
                break

    # Rewrites that fall through to the rules below.
    if "conditioner_blocks.0." in key:
        key = key.replace("conditioner_blocks.0", "conditioner_blocks")
    if "prime_prior" in key:
        key = key.replace("prime_prior", "encoder")
    if ".emb." in key and not any(word in key for word in ("total", "absolute", "relative")):
        key = key.replace(".emb.", ".")

    # Any name ending in "k" returns here with ".k" renamed to ".codebook" (unchanged if it has no ".k").
    if key.endswith("k"):
        return key.replace(".k", ".codebook")
    if "y_emb." in key:
        return key.replace("y_emb.", "metadata_embedding.")

    if "x_emb.emb." in key:
        key = key.replace("0.x_emb.emb", "embed_tokens")

    # First matching (trigger, old, new) rule wins. Order matters: e.g. "prime_state_ln" must be checked
    # before "_ln".
    for trigger, old, new in (
        ("prime_state_ln", "prime_state_ln", "encoder.final_layer_norm"),
        (".ln", ".ln", ".layer_norm"),
        ("_ln", "_ln", "_layer_norm"),
        ("prime_state_proj", "prime_state_proj", "encoder.proj_in"),
        ("prime_x_out", "prime_x_out", "encoder.lm_head"),
        ("prior.x_out", "x_out", "fc_proj_out"),
        ("x_emb", "x_emb", "embed_tokens"),
    ):
        if trigger in key:
            return key.replace(old, new)

    return key
def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
    """
    Rename the keys of an OpenAI Jukebox state dict to the Transformers module layout.

    Args:
        state_dict: Checkpoint dict whose keys are to be rewritten (presumably parameter name -> tensor).
        model_state_dict: State dict of the target `JukeboxModel` (presumably used to check the new keys exist).
        key_prefix: Prefix for converted keys; the caller passes "vqvae" or "priors.{n}".
        mapping: Dict shared across calls by the caller (presumably records original -> converted names).

    NOTE(review): in this listing none of the compiled patterns are used, the three ResNet `re.compile(` calls
    are unterminated, and the function returns an empty dict regardless of its inputs. The key-rewriting loop
    appears to be missing — confirm against the upstream script. The patterns also use unescaped ".", which
    matches any character.
    """
    new_dict = {}
    # Local import of the regular-expression module used for the key patterns.
    import re

    # Encoder blocks: input convolution.
    re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
    # Encoder blocks: ResNet layers.
    re_encoder_block_resnet = re.compile(
    # Encoder blocks: output projection.
    re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")

    # Decoder blocks: output convolution.
    re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
    # Decoder blocks: ResNet layers.
    re_decoder_block_resnet = re.compile(
    # Decoder blocks: input projection.
    re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")

    # Prior conditioner blocks: output convolution.
    re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)")
    # Prior conditioner blocks: ResNet layers.
    re_prior_cond_resnet = re.compile(
    # Prior conditioner blocks: input projection.
    re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)")

    # NOTE(review): as shown, this always returns the empty dict created above.
    return new_dict

def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
    """
    Copy/paste/tweak model's weights to our Jukebox structure.

    Every checkpoint file listed in `MODEL_MAPPING` for the model is downloaded into `pytorch_dump_folder_path`
    (files that already exist are reused). Each file's keys are renamed to the Transformers layout, loaded into a
    freshly built `JukeboxModel`, and the model is saved together with a `mapping.json` of new key -> original key.

    Args:
        model_name (`str`):
            Either a bare checkpoint name (e.g. `"jukebox-5b-lyrics"`) or a hub id such as
            `"openai/jukebox-5b-lyrics"`. Only the part after the last `/` is used to look up `MODEL_MAPPING`;
            the full value is passed to `JukeboxConfig.from_pretrained`.
        pytorch_dump_folder_path (`str`):
            Directory used both as the download cache and as the output directory.

    Returns:
        `List[dict]`: the converted prior state dicts. The VQ-VAE state dict is popped off and loaded first.
    """
    # Use the same MODEL_MAPPING key for downloading and for converting. The download loop used to index with the
    # raw model name, which raised KeyError for "org/name" style ids.
    checkpoint_name = model_name.split("/")[-1]
    model_to_convert = MODEL_MAPPING[checkpoint_name]

    os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    # Download every checkpoint file that is not cached locally yet.
    for file in model_to_convert:
        local_path = f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"
        if not os.path.isfile(local_path):
            r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
            # Fail loudly instead of saving an HTTP error page as a checkpoint.
            r.raise_for_status()
            with open(local_path, "wb") as checkpoint_file:
                checkpoint_file.write(r.content)

    config = JukeboxConfig.from_pretrained(model_name)
    model = JukeboxModel(config)

    # One converted state dict per checkpoint file (VQ-VAE first, then the priors).
    weight_dict = []
    # new key -> original key, filled in by fix_jukebox_keys and dumped to mapping.json.
    mapping = {}

    for i, dict_name in enumerate(model_to_convert):
        # NOTE(review): torch.load unpickles the file; only run this on trusted checkpoints.
        old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}")["model"]

        new_dic = {}
        for k, value in old_dic.items():
            # Rename only the ".b"/".w" suffix; str.replace would also rewrite every other "b"/"w" in the key.
            if k.endswith(".b"):
                new_dic[k[: -len(".b")] + ".bias"] = value
            elif k.endswith(".w"):
                new_dic[k[: -len(".w")] + ".weight"] = value
            elif "level_2" not in dict_name and "cond.model." in k:
                new_dic[k.replace(".blocks.", ".model.")] = value
            else:
                new_dic[k] = value

        # File 0 is the VQ-VAE; files 1..3 map to priors 2..0.
        key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}"
        new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
        weight_dict.append(new_dic)

    vqvae_state_dict = weight_dict.pop(0)
    model.vqvae.load_state_dict(vqvae_state_dict)
    # weight_dict now holds priors.2, priors.1, priors.0 (in that order).
    for i in range(len(weight_dict)):
        model.priors[i].load_state_dict(weight_dict[2 - i])

    with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
        json.dump(mapping, txtfile)

    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    return weight_dict
if __name__ == "__main__":
    # Command-line entry point: convert one OpenAI Jukebox checkpoint to the Transformers format.
    parser = argparse.ArgumentParser()

    # Command-line parameters (both have defaults).
    parser.add_argument(
        "--model_name",
        default="jukebox-5b-lyrics",
        type=str,
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="jukebox-5b-lyrics-converted",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    args = parser.parse_args()

    convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path)


"""PyTorch Jukebox model."""

import math  # 导入数学库
import os  # 导入操作系统相关功能
from typing import List, Optional, Tuple  # 导入类型提示相关的模块

import numpy as np  # 导入NumPy库
import torch  # 导入PyTorch库
import torch.nn.functional as F  # 导入PyTorch中的函数模块
from torch import nn  # 导入PyTorch的神经网络模块
from torch.nn import LayerNorm as FusedLayerNorm  # 导入PyTorch的归一化层模块

from ...activations import ACT2FN  # 导入自定义的激活函数
from ...modeling_utils import PreTrainedModel  # 导入预训练模型基类
from ...utils import add_start_docstrings, logging  # 导入工具函数和日志模块
from ...utils.logging import tqdm  # 导入进度条显示模块
from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig  # 导入配置文件

logger = logging.get_logger(__name__)  # module-level logger obtained from the package's logging utilities

# NOTE(review): only a dangling "See all Jukebox models at" comment was here; the list or link it introduced is missing.

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is not modified; a filtered copy is returned.

    Args:
        logits (`torch.Tensor`):
            Logits distribution; filtering is applied along the last dimension (vocabulary size).
        top_k (`int`, *optional*, defaults to 0):
            When `top_k > 0`, keep only the `top_k` tokens with the highest logits (top-k filtering).
        top_p (`float`, *optional*, defaults to 0.0):
            When `top_p > 0.0`, keep the smallest set of highest-probability tokens whose cumulative probability
            exceeds `top_p` (nucleus filtering).
        filter_value (`float`, *optional*, defaults to `-inf`):
            Value written into the filtered-out positions.

    Returns:
        `torch.Tensor`: the filtered logits, same shape as `logits`.
    """
    logits = logits.clone()
    # Safety check: top_k can never exceed the vocabulary size.
    top_k = min(top_k, logits.size(-1))

    if top_k > 0:
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens once the cumulative probability exceeds the threshold...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ...but shift right by one so the first token crossing the threshold is kept (the top token always survives).
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the removal mask from sorted order back to the original vocabulary positions.
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration):
    """
    Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be
    returned. If the provided token sequence is smaller, it will be left-padded with zeros; otherwise a window of
    `max_n_lyric_tokens` tokens centred on the lyric position matching the middle of the generated segment is
    returned. This *focuses* on the most relevant tokens (in time) for the sequence.

    Args:
        full_tokens (`torch.LongTensor`):
            Token ids of the entire lyrics with a leading batch dimension; only `full_tokens[0]` is used.
        max_n_lyric_tokens (`int`):
            Number of lyric tokens to return.
        total_length (`int`):
            Total expected length of the music (not all of it is generated, see duration), in samples.
        offset (`int`):
            Starting sample in the music. If the offset is greater than 0, the lyrics window is shifted to take it
            into account.
        duration (`int`):
            Expected duration of the generated music, in samples. The duration has to be smaller than the total
            length, which represents the overall length of the signal.

    Returns:
        `Tuple[torch.LongTensor, List[int]]`: the selected tokens with shape `(1, max_n_lyric_tokens)` and, for each
        of them, its index in `full_tokens[0]` (`-1` for padding positions).
    """
    full_tokens = full_tokens[0]
    n_tokens = len(full_tokens)

    if n_tokens < max_n_lyric_tokens:
        n_padding = max_n_lyric_tokens - n_tokens
        padding = torch.zeros(n_padding, dtype=torch.long, device=full_tokens.device)
        tokens = torch.cat([padding, full_tokens], dim=0)
        indices = [-1] * n_padding + list(range(n_tokens))
    else:
        half = max_n_lyric_tokens // 2
        # Lyric position proportional to the centre of the generated segment.
        midpoint = int(n_tokens * (offset + duration / 2.0) / total_length)
        # Clamp so the whole window fits; `max_n_lyric_tokens - half` (not `half`) keeps odd sizes exact.
        midpoint = min(max(midpoint, half), n_tokens - (max_n_lyric_tokens - half))
        start = midpoint - half
        end = start + max_n_lyric_tokens
        tokens = full_tokens[start:end]
        indices = list(range(start, end))
    return tokens.unsqueeze(dim=0), indices

def get_starts(total_length, n_ctx, hop_length):
    """
    Split `total_length` into windows of size `n_ctx` separated by `hop_length` samples.

    Returns the list of window start positions. The last window is moved back so that it ends exactly at
    `total_length`, giving it a full `n_ctx` of context.
    """
    starts = []
    for start in range(0, total_length - n_ctx + hop_length, hop_length):
        if start + n_ctx >= total_length:
            # The last hop could be smaller; make it n_ctx long to maximise context.
            start = total_length - n_ctx
        starts.append(start)
    return starts

def get_alignment(music_tokens, labels, prior, config):
    """
    Compute a lyric-to-music alignment from the attention weights of the top-level prior.

    The top level of `music_tokens` is scanned with windows of `prior.n_ctx` tokens (hop given by
    `config.hop_fraction`). For every window the attention weights of layer `config.prior_alignment_layer[0]`,
    head `config.prior_alignment_head[0]`, are collected. They are then scattered onto the lyric-token positions
    returned by `prior.get_metadata`. Where windows overlap, earlier windows overwrite later ones.

    Args:
        music_tokens (`List[torch.LongTensor]`):
            Music tokens per level; `music_tokens[prior.levels - 1]` is used and indexed as `(batch, length)`.
        labels:
            Labels passed to `prior.get_metadata`; lyric tokens are read from `labels[0, 3:]`.
        prior:
            The top-level prior.
        config:
            The Jukebox configuration.

    Returns:
        `List[np.ndarray]`: one array of shape `(original_length, len(labels[0, 3:]))` per batch item.
    """
    level = prior.levels - 1  # top level used
    n_ctx = prior.n_ctx
    tokens = music_tokens[level]
    batch_size, total_length = tokens.shape[0], tokens.shape[1]
    if total_length < n_ctx:
        # Right-pad with zeros so at least one full window exists; the padding is cropped off at the end.
        padding_length = n_ctx - total_length
        tokens = torch.cat(
            [tokens, torch.zeros(batch_size, padding_length, dtype=tokens.dtype, device=tokens.device)], dim=1
        )
        total_length = tokens.shape[1]
    else:
        padding_length = 0

    hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx)
    alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0]
    attn_layers = {alignment_layer}
    alignment_hops = {}
    indices_hops = {}
    starts = get_starts(total_length, n_ctx, hop_length)

    for start in tqdm(starts, desc="Computing lyric to music alignment "):
        end = start + n_ctx
        metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0)
        # Run the prior one batch item at a time to bound memory use.
        tokens_bs = torch.chunk(tokens, batch_size, dim=0)
        metadata_bs = torch.chunk(metadata, batch_size, dim=0)
        w_hops = []
        for tokens_i, metadata_i in zip(tokens_bs, metadata_bs):
            w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers)
            w_hops.append(w_hop[0][:, alignment_head])
            del w_hop
        weights = torch.cat(w_hops, dim=0)
        del w_hops
        alignment_hop = weights.float().cpu().numpy()
        del weights

        # Per the original notes: alignment_hop is (bs, n_ctx, nb_relevant_lyric_tokens) and indices_hop is a
        # list of bs index lists, each of length nb_relevant_lyric_tokens.
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop

    # Combine the per-window attention into full-range attention, placing each window at its lyric indices.
    alignments = []
    for item in range(batch_size):
        # NOTE(review): lyrics are always read from labels[0] although each item may have different lyrics.
        full_tokens = labels[0, 3:]
        # One extra column receives index -1 (padding positions) and is dropped below.
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        for start in reversed(starts):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            alignment[start:end, indices] = alignment_hop
        # Remove the token padding and the last (padding) lyric column.
        alignment = alignment[: total_length - padding_length, :-1]
        alignments.append(alignment)
    return alignments
def save_temp_audio(fname, lvl, metas, aud):
    """
    Save each item of an audio batch as a `.npy` file inside the directory `fname`.

    The audio is clamped to [-1, 1] and moved to the CPU first.

    Args:
        fname (`str`): Target directory.
        lvl (`int`): Level number, used in the file names.
        metas (iterable of `dict` or `None`):
            One dict per item whose three values are unpacked, in order, as artist, genre and lyrics. The file names
            then contain those values (only the first 5 lyric characters). When `None`, files are named
            `lvl_{lvl}-sample-{i}`.
        aud (`torch.Tensor`): Audio batch, first dimension = items.
    """
    aud = torch.clamp(aud, -1, 1).cpu().numpy()
    # Materialise once: re-listing per item was quadratic and broke when `metas` is a one-shot iterator.
    metas = list(metas) if metas is not None else None
    for i in range(aud.shape[0]):
        if metas is not None:
            artists, genres, lyrics = metas[i].values()
            path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}"
            np.save(path, aud[i])
        else:
            np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i])

def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t):
    """
    Build an attention mask of shape `(1, 1, query_length, key_value_length)`.

    Args:
        mask (`str` or `None`): One of `"autoregressive"`, `"summary"`, `"prime"`, or `None` (no mask).
        query_length (`int`): Number of query positions.
        key_value_length (`int`): Number of key/value positions.
        blocks (`int`): Number of blocks (only used by `"summary"`).
        spread: Unused here; kept for interface compatibility.
        device: Device on which the mask is created.
        sample (`bool`): Whether we are sampling; selects how the diagonal offset is computed.
        sample_t (`int`): Current sampling position (only used when `sample` is true).

    Returns:
        `torch.Tensor` or `None`: `None` when `mask` is `None` or `query_length == 1`.

    Raises:
        ValueError: if `mask` is not one of the supported types.
    """
    if mask is None or query_length == 1:
        return None
    # Diagonal offset that aligns the last query with the last visible key.
    offset = sample_t - query_length if sample else max(key_value_length - query_length, 0)
    if mask == "autoregressive":
        # Lower-triangular (shifted by `offset`): each query sees keys up to its own position.
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    elif mask == "summary":
        # Masked summary: take the causal mask at block granularity, drop the last block, and prepend one fully
        # visible block.
        mask = torch.ones(query_length, query_length, device=device).tril()
        mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :]
        mask = (
            torch.nn.functional.pad(
                mask,
                (0, 0, 1, 0),
                value=1,
            )
            .contiguous()
            .view(query_length, key_value_length)
        )
    elif mask == "prime":
        # Same lower-triangular construction as "autoregressive".
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    else:
        raise ValueError(f"Unknown mask type {mask!r}; expected 'autoregressive', 'summary' or 'prime'.")
    return mask.view(1, 1, query_length, key_value_length)

class JukeboxConv1D(nn.Module):
    """
    Position-wise linear layer: `hidden_states @ weight + bias` over the last dimension.

    `weight` has shape `(input_width, output_width)` and is allocated with `torch.empty`, i.e. uninitialised; it must
    be initialised elsewhere. `bias` starts at zero.
    """

    def __init__(self, input_width, output_width):
        super().__init__()
        self.input_width = input_width
        self.output_width = output_width
        weight = torch.empty(input_width, output_width)
        bias = torch.zeros(output_width)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

    def forward(self, hidden_states):
        size_out = (*hidden_states.size()[:-1], self.output_width)
        # Flatten all leading dims, compute bias + x @ W in one fused call, then restore the leading dims.
        # `reshape` (instead of `view`) also accepts non-contiguous inputs.
        hidden_states = torch.addmm(
            self.bias.type_as(hidden_states),
            hidden_states.reshape(-1, hidden_states.size(-1)),
            self.weight.type_as(hidden_states),
        )
        hidden_states = hidden_states.view(*size_out)
        return hidden_states

class JukeboxResConv1DBlock(nn.Module):
    """
    Residual block: `x + res_scale * conv1d_2(relu(conv1d_1(relu(x))))`.

    `conv1d_1` is a dilated kernel-3 convolution to `res_convolution_multiplier * conv_width` channels, with
    dilation `res_dilation_growth_rate ** depth`. `conv1d_2` is a 1x1 convolution back to `conv_width` channels.
    Padding equals the dilation, so the sequence length is preserved.
    """

    def __init__(self, config, conv_width, depth=1, res_scale=1.0):
        super().__init__()
        hidden_dim = config.res_convolution_multiplier * conv_width
        dilation = config.res_dilation_growth_rate**depth
        padding = dilation

        self.res_scale = res_scale
        self.activation = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation)
        self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0)

    def forward(self, hidden_states):
        residuals = hidden_states
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv1d_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv1d_2(hidden_states)
        return residuals + self.res_scale * hidden_states
class JukeboxResnet1D(nn.Module):
    """
    Stack of `n_depth` `JukeboxResConv1DBlock`s applied sequentially.

    Block `d` gets dilation depth `d`, or `d % config.res_dilation_cycle` when a cycle is configured. With
    `config.conv_res_scale`, each residual branch is scaled by `1 / sqrt(n_depth)`. `reverse_dilation` reverses the
    block order (largest dilation first).
    """

    def __init__(self, config, conv_width, n_depth, reverse_dilation=False):
        super().__init__()
        self.dilation_cycle = config.res_dilation_cycle
        res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth)

        blocks = []
        for depth in range(n_depth):
            block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle
            blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale))

        if reverse_dilation:
            blocks = blocks[::-1]
        self.resnet_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        for block in self.resnet_block:
            hidden_states = block(hidden_states)
        return hidden_states

class JukeboxEncoderConvBlock(nn.Module):
    """
    One encoder level: `down_t` times (strided Conv1d + `JukeboxResnet1D`), then a kernel-3 projection to
    `config.embed_dim` channels.

    Each strided convolution uses kernel `2 * stride_t`, stride `stride_t` and padding `stride_t // 2`. The first one
    maps `embed_dim` -> `hidden_dim`; the following ones map `hidden_dim` -> `hidden_dim`.
    """

    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t):
        super().__init__()
        blocks = []
        filter_t = stride_t * 2
        pad_t = stride_t // 2
        for i in range(down_t):  # empty when down_t <= 0
            blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t))
            blocks.append(JukeboxResnet1D(config, hidden_dim, depth))

        self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1)
        self.downsample_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        for block in self.downsample_block:
            hidden_states = block(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        return hidden_states

class JukeboxEncoder(nn.Module):
    """
    Chain of `levels` `JukeboxEncoderConvBlock`s applied one after another.

    The first level consumes `config.conv_input_shape` channels; later levels consume `config.embed_dim`.
    `forward` returns the output of every level, in order.
    """

    def __init__(self, config, width, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()

        for i, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            in_channels = config.conv_input_shape if i == 0 else config.embed_dim
            self.level_blocks.append(JukeboxEncoderConvBlock(config, in_channels, width, depth, down_t, stride_t))

    def forward(self, hidden_states):
        all_hidden_states = []
        for level in range(self.levels):
            level_block = self.level_blocks[level]
            hidden_states = level_block(hidden_states)
            all_hidden_states.append(hidden_states)
        return all_hidden_states

class JukeboxDecoderConvBock(nn.Module):
    """
    One decoder level: a kernel-3 projection `embed_dim` -> `hidden_dim`, then `down_t` times
    (`JukeboxResnet1D` + `ConvTranspose1d`).

    Each transposed convolution uses kernel `2 * stride_t`, stride `stride_t` and padding `stride_t // 2`. The last
    one maps back to `embed_dim` channels. With `down_t == 0` the block is an identity (no parameters).
    """

    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        blocks = []
        if down_t > 0:
            filter_t = stride_t * 2
            pad_t = stride_t // 2
            self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1)
            for i in range(down_t):
                blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation))
                out_channels = hidden_dim if i < down_t - 1 else embed_dim
                blocks.append(nn.ConvTranspose1d(hidden_dim, out_channels, filter_t, stride_t, pad_t))
        else:
            # `forward` always calls `proj_in`; previously it was undefined here and forward raised AttributeError.
            # nn.Identity adds no parameters, so state dicts are unchanged.
            self.proj_in = nn.Identity()
        self.upsample_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        hidden_states = self.proj_in(hidden_states)
        for block in self.upsample_block:
            hidden_states = block(hidden_states)
        return hidden_states
class JukeboxDecoder(nn.Module):
    """
    Decoder applying `levels` `JukeboxDecoderConvBock`s from the top level down, followed by an output convolution
    `config.embed_dim` -> `config.conv_input_shape` channels.
    """

    def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for _, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            self.level_blocks.append(
                JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t)
            )

        self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1)

    def forward(self, hidden_states, all_levels=True):
        """
        Decode starting from `hidden_states[-1]`.

        When `all_levels` is true, after decoding level `l > 0` the next-lower input `hidden_states[l - 1]` is added
        before decoding continues.
        """
        hidden_state = hidden_states[-1]

        for level in reversed(range(self.levels)):
            level_block = self.level_blocks[level]
            hidden_state = level_block(hidden_state)

            if level != 0 and all_levels:
                hidden_state = hidden_state + hidden_states[level - 1]

        hidden_state = self.out(hidden_state)
        return hidden_state

class JukeboxBottleneckBlock(nn.Module):
    """
    Single-level vector-quantisation bottleneck.

    The codebook (`nb_discrete_codes` x `codebook_width`) is stored in the `codebook` buffer. When codebook updates
    are requested, it is initialised from the first batch. After that it is updated from exponential moving averages
    (decay `mu = config.lmu`) of per-code sums and counts. Codes whose running count falls below `threshold` are
    replaced by randomly sampled input vectors.
    """

    def __init__(self, config: "JukeboxVQVAEConfig"):
        super().__init__()
        self.nb_discrete_codes = config.nb_discrete_codes
        # Previously written as `config.embed_dim = config.lmu`, which overwrote the config and set a wrong width.
        self.codebook_width = config.embed_dim
        self.mu = config.lmu  # EMA decay of the codebook statistics
        self.threshold = 1.0
        self.init = False
        self.codebook_sum = None
        self.codebook_elem = None
        self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width))

    def _tile(self, hidden_states):
        """
        Make sure there are at least `nb_discrete_codes` rows to sample from.

        If `hidden_states` (rows x width) has fewer rows than there are codes, it is repeated enough times and
        Gaussian noise with std `0.01 / sqrt(width)` is added.
        """
        dim, embed_width = hidden_states.shape
        if dim < self.nb_discrete_codes:
            n_repeats = (self.nb_discrete_codes + dim - 1) // dim
            std = 0.01 / np.sqrt(embed_width)
            hidden_states = hidden_states.repeat(n_repeats, 1)
            hidden_states = hidden_states + torch.randn_like(hidden_states) * std
        return hidden_states

    def init_codebook(self, hidden_states):
        """Initialise the codebook with a random subset of the (tiled) input vectors and reset the EMA statistics."""
        nb_discrete_codes = self.nb_discrete_codes
        self.init = True
        codes = self._tile(hidden_states)
        self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]
        self.codebook_sum = self.codebook
        self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device)

    def update_codebook(self, hidden_states, latent_states):
        """
        EMA update of the codebook from the flattened inputs `hidden_states` (rows x width) and their assigned code
        indices `latent_states`.

        Returns:
            `dict` with `entropy` (of the current code-usage distribution), `used_curr` (codes used in this batch),
            `usage` (codes whose running count is above the threshold) and `dk` (normalised change of the codebook).
        """
        mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes
        with torch.no_grad():
            # One-hot assignment matrix: (nb_discrete_codes, n_vectors).
            latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device)
            latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1)

            _codebook_sum = torch.matmul(latent_states_onehot, hidden_states)
            _codebook_elem = latent_states_onehot.sum(dim=-1)  # nb_discrete_codes
            codes = self._tile(hidden_states)
            _random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]

            # Update the centres with exponential moving averages.
            old_codebook = self.codebook
            self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum
            self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem  # nb_discrete_codes
            usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float()

            norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view(
                nb_discrete_codes, 1
            )
            # Keep the EMA centre for used codes, re-seed unused ones with random input vectors.
            self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook
            _codebook_prob = _codebook_elem / torch.sum(_codebook_elem)  # prob of each bin
            entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))  # how diverse the usage is
            used_curr = (_codebook_elem >= self.threshold).sum()
            usage = torch.sum(usage)
            # Root-mean-square change of the codebook entries (not a KL divergence).
            dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
        return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}

    def preprocess(self, hidden_states):
        """
        Flatten `(batch, channels, time)` to `(batch * time, channels)` and compute a normalisation statistic.

        If `channels == 2 * codebook_width`, the two halves are summed, and the statistic is the sum of both
        halves' statistics.

        Raises:
            ValueError: if `channels` is neither `codebook_width` nor `2 * codebook_width`.
        """
        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if hidden_states.shape[-1] == self.codebook_width:
            prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape))
        elif hidden_states.shape[-1] == 2 * self.codebook_width:
            x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
            prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
                torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
            )
            hidden_states = x1 + x2
        else:
            # Previously this fell through to an UnboundLocalError on `prenorm`.
            raise ValueError(
                f"Expected {self.codebook_width} or {2 * self.codebook_width} channels, got {hidden_states.shape[-1]}."
            )

        return hidden_states, prenorm

    def postprocess(self, latent_states, dequantised_states, x_shape):
        """Reshape code indices to `(batch, time)` and dequantised vectors to `(batch, width, time)`."""
        batch_size, time = x_shape
        dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous()
        latent_states = latent_states.view(batch_size, time)
        return latent_states, dequantised_states

    def quantise(self, latent_states):
        """Return the index of the nearest codebook entry for every row, and the mean squared distance to it."""
        codebook_weights = self.codebook.t()
        # Squared Euclidean distance expanded as |x|^2 - 2 x.c + |c|^2: (n_vectors, nb_discrete_codes).
        distance = (
            torch.sum(latent_states**2, dim=-1, keepdim=True)
            - 2 * torch.matmul(latent_states, codebook_weights)
            + torch.sum(codebook_weights**2, dim=0, keepdim=True)
        )
        min_distance, music_tokens = torch.min(distance, dim=-1)
        fit = torch.mean(min_distance)
        return music_tokens, fit

    def dequantise(self, music_tokens):
        """Look up the codebook vectors of `music_tokens`."""
        dequantised_states = F.embedding(music_tokens, self.codebook)
        return dequantised_states

    def encode(self, latent_states):
        """Quantise `(batch, channels, time)` inputs to code indices of shape `(batch, time)`."""
        samples, _, seq_len = latent_states.shape
        latent_states, _ = self.preprocess(latent_states)
        music_tokens, _ = self.quantise(latent_states)
        music_tokens = music_tokens.view(samples, seq_len)
        return music_tokens

    def decode(self, music_tokens):
        """Map code indices `(batch, time)` back to vectors of shape `(batch, codebook_width, time)`."""
        samples, seq_len = music_tokens.shape
        dequantised_states = self.dequantise(music_tokens)
        dequantised_states = (
            dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous()
        )
        return dequantised_states

    def forward(self, hidden_states, update_codebook=True):
        """
        Quantise `(batch, channels, time)` inputs.

        Returns:
            Tuple of code indices `(batch, time)`, dequantised states `(batch, width, time)` (straight-through
            gradient), the commitment loss, and a metrics dict (`fit`, `pn`, plus the update metrics when
            `update_codebook` is true).
        """
        samples, _, seq_len = hidden_states.shape

        hidden_states, prenorm = self.preprocess(hidden_states)

        # Initialise the codebook from the first batch when updates are enabled.
        if update_codebook and not self.init:
            self.init_codebook(hidden_states)

        music_tokens, fit = self.quantise(hidden_states)
        dequantised_states = self.dequantise(music_tokens)

        if update_codebook:
            update_metrics = self.update_codebook(hidden_states, music_tokens)
        else:
            update_metrics = {}

        # Mean squared distance between inputs and their (detached) quantised values.
        commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape)

        # Straight-through estimator: forward uses the quantised values, gradients flow to the inputs.
        dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()

        music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len))
        return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics)
# 导入 PyTorch 的 nn 模块
import torch.nn as nn

class JukeboxBottleneck(nn.Module):
    """One `JukeboxBottleneckBlock` per level; level `i` of every input list is handled by `level_blocks[i]`."""

    def __init__(self, config, levels):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for _ in range(self.levels):
            self.level_blocks.append(JukeboxBottleneckBlock(config))

    def encode(self, raw_audio):
        """Quantise each level's latent states with that level's block."""
        music_tokens = [
            level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio)
        ]
        return music_tokens

    def decode(self, music_tokens, start_level=0, end_level=None):
        """Dequantise `music_tokens`, whose first entry belongs to `start_level`."""
        if end_level is None:
            end_level = self.levels
        quantised_audio = [
            level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens)
        ]
        return quantised_audio

    def forward(self, input_audio):
        """
        Quantise every level of `input_audio`.

        The codebooks are updated only in training mode. Metrics are also collected only in training mode (the
        metrics list stays empty otherwise). In eval mode the quantised states are detached.
        """
        music_tokens, quantised_states, commit_losses, metrics = [], [], [], []
        for level in range(self.levels):
            # Same level-to-block pairing as `encode`/`decode` (previously `level_blocks[-level - 1]`, which fed
            # each level to another level's codebook).
            level_block = self.level_blocks[level]
            hidden_states = input_audio[level]
            sampled_tokens, quantised_state, commit_loss, metric = level_block(
                hidden_states, update_codebook=self.training
            )
            music_tokens.append(sampled_tokens)
            if not self.training:
                # Make sure the encoder weights are not changed by the straight-through estimator.
                quantised_state = quantised_state.detach()
            quantised_states.append(quantised_state)
            commit_losses.append(commit_loss)
            if self.training:
                metrics.append(metric)
        return music_tokens, quantised_states, commit_losses, metrics

# Shared introduction used with `add_start_docstrings` for the Jukebox model classes.
JUKEBOX_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config (`JukeboxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# 使用 @add_start_docstrings 装饰器添加额外的文档信息到 JukeboxVQVAE 类
    """The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams, Sam
Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](

# 定义 JukeboxVQVAE 类,继承自 PreTrainedModel
class JukeboxVQVAE(PreTrainedModel):
    config_class = JukeboxVQVAEConfig  # 设置 config_class 属性
    # 设置基础模型前缀为 "vqvae"
    base_model_prefix = "vqvae"

    # 初始化权重的函数,用于初始化模块的权重
    def _init_weights(self, module):
        # 如果模块是 nn.Embedding 类型,例如 embed_tokens
        if isinstance(module, nn.Embedding):
            # 初始化权重为正态分布,均值为 0,标准差为 0.02 乘以配置参数中的初始化比例
  , std=0.02 * self.config.init_scale)
        # 如果模块是 JukeboxConv1D 类型
        elif isinstance(module, JukeboxConv1D):
            # 根据配置参数决定是否将权重初始化为零,否则初始化为正态分布
            if self.config.zero_out:
      , std=0.02 * self.config.init_scale)
        # 如果模块是 JukeboxResConv1DBlock 类型,并且配置参数中指定了 zero_out 为 True
        elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
            # 将第二个卷积层的权重和偏置初始化为零
        # 如果模块是 nn.LayerNorm 类型
        if isinstance(module, nn.LayerNorm):
            # 将偏置初始化为零
            # 将权重初始化为全 1
        # 如果模块是 nn.Linear 类型,并且有偏置项
        if isinstance(module, nn.Linear) and module.bias is not None:
            # 将偏置项初始化为零

    # 初始化函数,接受一个 JukeboxVQVAEConfig 类型的配置参数
    def __init__(self, config: JukeboxVQVAEConfig):
        # 调用父类的初始化方法,传入配置参数
        # 获取配置参数中的 res_downs_t 和 res_strides_t
        downs_t = config.res_downs_t
        strides_t = config.res_strides_t
        # 如果配置参数中没有指定 sample_length
        if not config.sample_length:
            # 计算 downsamples 数组,每个元素为 stride**down 的结果
            downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
            # 计算 top_raw_to_tokens,即 downsamples 的乘积
            top_raw_to_tokens =
            # 根据采样率和 top_raw_to_tokens 计算 sample_length
            config.sample_length = (
                config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens
            ) * top_raw_to_tokens
            # 将 sample_length 转换为整数类型
            config.sample_length = config.sample_length.astype(int)

        # 设置一些模型参数,从配置中获取
        self.nb_discrete_codes = config.nb_discrete_codes
        self.commit = config.commit
        self.sample_length = config.sample_length

        # 计算 downsamples 数组和 hop_lengths 数组
        self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
        self.hop_lengths = np.cumprod(self.downsamples)
        self.levels = levels = config.levels
        # 计算 music_tokens_shapes 数组
        self.music_tokens_shapes = [
            (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels)

        # 设置 multipliers 数组,如果配置中没有指定,则全部设置为 1
        self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels

        # 初始化 encoders 和 decoders,都是 nn.ModuleList 类型
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            # 计算当前层的宽度和深度
            width = config.res_conv_width * self.multipliers[level]
            depth = config.res_conv_depth * self.multipliers[level]
            # 分别创建 JukeboxEncoder 和 JukeboxDecoder 并加入到 encoders 和 decoders 中
                JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
                JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])

        # 初始化 bottleneck 层
        self.bottleneck = JukeboxBottleneck(config, levels)
    # 解码函数,将音乐编码 `music_tokens` 解码为原始音频表示
    def _decode(self, music_tokens, start_level=0, end_level=None):
        # 如果未指定结束级别,则使用最大级别
        if end_level is None:
            end_level = self.levels
        # 使用瓶颈网络进行解码,获取潜在状态
        latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level)
        # 只使用最低级别的解码器
        decoder, dequantised_state = self.decoders[start_level], latent_states[0:1]
        # 使用解码器对去量化状态进行解码
        dequantised_state = decoder(dequantised_state, all_levels=False)
        # 调整维度顺序,将时间轴移至第二个维度
        dequantised_state = dequantised_state.permute(0, 2, 1)
        return dequantised_state

    # 解码函数,将音乐编码 `music_tokens` 解码为原始音频表示,支持批处理
    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor:
        将输入的 `music_tokens` 解码为它们的 `raw_audio` 表示。

            music_tokens (`torch.LongTensor`):
                音乐编码的张量,通过使用码本将其解码为原始音频。每个音乐编码应该是码本中相应 `code` 向量的索引。
            start_level (`int`, *optional*):
                解码过程开始的级别。默认为 0。
            end_level (`int`, *optional*):
                解码过程结束的级别。默认为 None。
            bs_chunks (int, *optional*):

            `torch.Tensor`: 解码后的原始音频张量。
        # 将音乐编码分块,以便并行处理
        token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens]
        dequantised_states = []
        for i in range(bs_chunks):
            music_tokens_i = [chunks[i] for chunks in token_chunks]
            # 调用 `_decode` 函数进行解码
            dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level)
        # 拼接所有解码后的状态张量
        return, dim=0)

    # 编码函数,将原始音频 `raw_audio` 编码为音乐编码 `music_tokens`
    def _encode(self, raw_audio, start_level=0, end_level=None):
        # 编码
        if end_level is None:
            end_level = self.levels
        # 调整输入音频的维度顺序,确保正确的输入格式
        input_audio = raw_audio.permute(0, 2, 1).float()
        latent_states = []
        # 遍历所有级别的编码器,获取潜在状态
        for level in range(self.levels):
            encoder = self.encoders[level]
            latent_state = encoder(input_audio)
            latent_states.append(latent_state[-1])  # 仅保留每级别最后一个潜在状态
        # 使用瓶颈网络对潜在状态进行编码,得到音乐编码 `music_tokens`
        music_tokens = self.bottleneck.encode(latent_states)
        return music_tokens[start_level:end_level]
    # 将输入音频分割成若干块,每块作为一个处理单元
    audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0)
    # 初始化一个空列表,用于存储每个音频块的离散表示
    music_tokens_list = []
    # 遍历每个音频块
    for chunk_i in audio_chunks:
        # 调用内部方法 `_encode` 对当前音频块进行编码,生成其离散表示
        music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level)
        # 将编码后的离散表示添加到列表中
    # 将每个音频块的离散表示进行合并,按照维度0连接在一起,形成最终的音乐表示
    music_tokens = [, dim=0) for music_tokens_level in zip(*music_tokens_list)]
    # 返回整个音乐表示
    return music_tokens

# 生成指定数量的音乐样本的离散表示
def sample(self, n_samples):
    # 为每个离散表示形状生成随机整数,表示从0到nb_discrete_codes之间的离散码
    music_tokens = [
        torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu")
        for music_tokens_shape in self.music_tokens_shapes
    # 调用解码方法,将生成的离散表示解码为音频样本
    return self.decode(music_tokens)
    # Encode/Decode
    input_audio = raw_audio.permute(0, 2, 1).float()
    # 将输入音频数据重新排列维度,使其符合模型要求的格式,并转换为浮点数类型

    latent_states = []
    # 创建空列表,用于存储每个级别的潜在状态

    for level in range(self.levels):
        # 遍历所有编码器级别
        encoder = self.encoders[level]
        # 获取当前级别的编码器
        latent_state = encoder(input_audio)
        # 对输入音频进行编码,得到潜在状态
        # 将编码后的潜在状态加入列表中,取最后一个状态

    _, music_tokens, commit_losses, _ = self.bottleneck(latent_states)
    # 使用瓶颈模型处理潜在状态,得到音乐编码、commit loss 等结果

    dequantised_states = []
    # 创建空列表,用于存储每个级别的反量化状态

    for level in range(self.levels):
        # 遍历所有解码器级别
        decoder = self.decoders[level]
        # 获取当前级别的解码器
        dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False)
        # 使用解码器解码音乐编码得到反量化状态
        dequantised_states.append(dequantised_state.permute(0, 2, 1))
        # 将反量化状态重新排列维度并加入列表中

    commit_loss = sum(commit_losses)
    # 计算总的 commit loss
    loss = self.commit * commit_loss
    # 根据 commit 系数计算最终损失值

    return dequantised_states, loss
    # 返回解码后的状态列表及计算得到的损失值
class JukeboxMLP(nn.Module):
    """Position-wise feed-forward block: conv projection up, activation, conv projection down, dropout."""

    def __init__(self, config):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        embed_dim = config.hidden_size
        # Inner width is `mlp_multiplier` times the model width.
        hidden_dim = int(config.mlp_multiplier * embed_dim)

        self.c_fc = JukeboxConv1D(embed_dim, hidden_dim)
        self.c_proj = JukeboxConv1D(hidden_dim, embed_dim)
        self.act = ACT2FN[config.act_fn]
        self.dropout = nn.Dropout(config.resid_dropout)

    def forward(self, hidden_states):
        """Apply up-projection, activation, down-projection and residual dropout to `hidden_states`."""
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

class JukeboxLayerNorm(FusedLayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        # 初始化函数,定义Jukebox模型的LayerNorm层
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        # 计算输入张量的总元素个数(维度的乘积)
        self.width =
        # 计算能够处理的最大元素个数,限制为65535 * self.width
        self.max_numel = 65535 * self.width

    def forward(self, input):
        # Jukebox模型LayerNorm层的前向传播函数
        if input.numel() > self.max_numel:
            # 如果输入张量的元素个数超过self.max_numel,使用PyTorch的layer_norm函数处理
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
            # 否则调用父类FusedLayerNorm的forward方法处理
            return super().forward(input).type_as(input)

class JukeboxAttention(nn.Module):
    """
    Multi-head attention supporting the Jukebox factored attention patterns.

    `attn_func` selects both the attention routine (dense, block, transpose-block, previous-block, summary,
    summary-spread, cross or prime attention) and the matching query/key/value producer, which also manages the
    key/value cache used during incremental sampling.
    """

    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.n_heads = config.n_heads
        self.dropout = config.attn_dropout
        hidden_dim = int(config.attention_multiplier * self.embed_dim)

        self.head_dim = hidden_dim // config.n_heads
        self.n_ctx = n_ctx
        self.hidden_dim = hidden_dim
        # Applied to both queries and keys, so the logits end up scaled by head_dim ** -0.5 overall.
        self.scale = self.head_dim**-0.25
        self.mask = config.mask

        if attn_func == "cross_attention":
            # Queries come from the hidden states; keys/values from the encoder states (see decode_qkv).
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim)
            self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2)
        else:
            # Fused query/key/value projection, split into three in the *_qkv methods.
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3)

        self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        # A sequence of length seq_len is factored as [blocks, seq_len // blocks].
        self.attn_func = attn_func
        if attn_func == "cross_attention":
            self.qkv = self.decode_qkv
        elif attn_func == "prime_attn":
            self.qkv = self.prime_qkv
        else:
            self.qkv = self.factored_qkv

        # attention routine and the mask type passed to `get_mask`
        ATTENTION_MAP = {
            "dense_attn": (self.dense_attn, "autoregressive"),
            "block_attn": (self.block_attn, "autoregressive"),
            "transpose_block_attn": (self.transpose_block_attn, "autoregressive"),
            "prev_block_attn": (self.prev_block_attn, None),
            "summary_attn": (self.summary_attn, "summary"),
            "summary_spread_attn": (self.summary_spread_attn, "summary"),
            "cross_attention": (self.dense_attn, None),
            "prime_attn": (self.prime_attn, "prime"),
        }
        self.attn, self.attn_mask = ATTENTION_MAP[attn_func]

        self.blocks = config.blocks
        self.spread = config.spread
        if self.blocks is not None:
            self.block_ctx = self.n_ctx // self.blocks

        # Incremental-sampling state: number of positions seen so far and the cached keys/values.
        self.sample_t = 0
        self.cache = {}
        self.encoder_len = config.nb_relevant_lyric_tokens  # length of the encoder input ids
        self.record_attn = False

    def _attn(self, query_states, key_states, value_states, sample):
        """Scaled (optionally masked) softmax attention over already-split heads."""
        scale = self.scale
        if self.training:
            attention_weight = torch.matmul(query_states * scale, key_states * scale)
        else:
            attention_weight = torch.matmul(query_states, key_states)
            attention_weight.mul_(scale * scale)
        attn_weight_type = attention_weight.dtype
        # Softmax is computed in float32 and cast back afterwards.
        attention_weight = attention_weight.float()
        if self.mask:
            # Generate the mask hiding positions that must not be attended to.
            # This can take a lot of memory for dense attention, so it may be cached.
            mask = get_mask(
                self.attn_mask,
                query_states.size(-2),
                key_states.size(-1),
                self.blocks,
                self.spread,
                attention_weight.device,
                sample,
                self.sample_t,
            )
            if mask is not None:
                attention_weight = attention_weight * mask + -1e9 * (1 - mask)
        attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type)
        if self.record_attn:
            self.attention_prob = attention_prob
            if self.attn_func == "prime_attn":
                # only keep music queries and lyrics keys/values
                self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len]
        attention_prob = self.attn_dropout(attention_prob)
        context_states = torch.matmul(attention_prob, value_states)
        return context_states

    def merge_heads(self, hidden_states):
        """Inverse of `split_heads` for queries/values: (b, heads, seq, d) -> (b, seq, heads * d)."""
        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1))
        return hidden_states.view(*new_hidden_states_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, hidden_states, is_key=False):
        """Split the last axis into `n_heads`; keys come out transposed for the q @ k matmul."""
        new_hidden_states_shape = (
            *hidden_states.size()[:-1],
            self.n_heads,
            hidden_states.size(-1) // self.n_heads,
        )
        hidden_states = hidden_states.view(*new_hidden_states_shape)  # in Tensorflow implem: fct split_states
        if is_key:
            return hidden_states.permute(0, 2, 3, 1)
        else:
            return hidden_states.permute(0, 2, 1, 3)

    def dense_attn(self, query, key, value, sample):
        """Full attention of `query` over `key`/`value`."""
        query = self.split_heads(query)
        key = self.split_heads(key, is_key=True)
        value = self.split_heads(value)
        context_states = self._attn(query, key, value, sample)
        context_states = self.merge_heads(context_states)
        return context_states

    def block_attn(self, query, key, value, sample):
        """Attention restricted to contiguous blocks of `block_ctx` positions."""
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)
            if query_length < seq_len:
                # Only the last query_length keys/values are relevant to these queries.
                seq_len = query_length
                key = key[:, -seq_len:].contiguous()
                value = value[:, -seq_len:].contiguous()
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def transpose_block_attn(self, query, key, value, sample):
        """Attention across blocks: each position attends to the same offset in every block."""
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            block_len = (seq_len - 1) % block_ctx
            key = key[:, block_len::block_ctx, :]
            value = value[:, block_len::block_ctx, :]
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim)
            query = query.transpose(1, 2).contiguous()
            query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim)

            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            key = key.transpose(1, 2).contiguous()
            key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            value = value.transpose(1, 2).contiguous()
            value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            # Attend, then undo the transpose to restore the original position order.
            block_attn = self.dense_attn(query, key, value, sample)
            block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim)
            block_attn = block_attn.transpose(1, 2).contiguous()
            block_attn = block_attn.view(batch_size, query_length, embed_dim)

            return block_attn

    def prev_block_attn(self, query, key, value, sample):
        """Each block attends to the previous block (the first block attends to zeros)."""
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            block = (seq_len - 1) // block_ctx
            prev_l = (block - 1) * block_ctx
            if block > 0:
                key = key[:, prev_l : prev_l + block_ctx, :]
                value = value[:, prev_l : prev_l + block_ctx, :]
            else:
                key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
                value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)

            # Shift the blocks right by one (dropping the last, zero-padding the first).
            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0))
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0))
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            if query_length < seq_len:
                # Keep only the key/value blocks aligned with the trailing query blocks.
                nb_query_blocks = query_length // block_ctx
                nb_key_blocks = seq_len // block_ctx
                seq_len = query_length
                key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

                value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def summary_attn(self, query, key, value, sample):
        """Attend to the last position of every previous block (not implemented for sampling)."""
        blocks = self.blocks
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            raise NotImplementedError
        else:
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            key = torch.nn.functional.pad(key, (0, 0, 1, 0))  # batch_size, blocks, embed_dim

            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            value = torch.nn.functional.pad(value, (0, 0, 1, 0))  # batch_size, blocks, embed_dim
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def summary_spread_attn(self, query, key, value, sample):
        """Attend to the last `spread` positions of every previous block (not implemented for sampling)."""
        blocks = self.blocks
        spread = self.spread

        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            raise NotImplementedError
        else:
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous()
            # Flatten back to 3-D so split_heads sees (batch, positions, embed).
            key = key.view(batch_size, blocks * spread, embed_dim)

            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous()
            value = value.view(batch_size, blocks * spread, embed_dim)

            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def prime_attn(self, query, key, value, sample):
        """Dense attention restricted to the first `_encoder_len` keys/values."""
        encoder_len = self._encoder_len
        key = key[:, :encoder_len]
        value = value[:, :encoder_len]
        return self.dense_attn(query, key, value, sample)

    def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """
        Split the fused projection into query/key/value; during sampling, update and trim the cache.

        Returns `(query, key, value, sample)`. `sample` is turned off when several positions are processed at
        once, so the caller then runs the non-sampling attention path.
        """
        curr_ctx = hidden_states.shape[1]
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")

        query, key, value = hidden_states.chunk(3, dim=2)
        if sample:
            self.sample_t += curr_ctx
            key, value = self._append_cache(key, value)
            l_cache = self._suff_cache_len()
            if self._cache_len() > l_cache:
                # Keep only the trailing part of the cache that this attention pattern needs.
                self._slice_cache(-l_cache)
            if curr_ctx > 1:
                if self.attn_func != "dense_attn":
                    query = self._pad_to_block_ctx(query, query=True)
                    key = self._pad_to_block_ctx(key)
                    value = self._pad_to_block_ctx(value)
                sample = False
            else:
                key = self.cache["key"]
                value = self.cache["value"]
        return query, key, value, sample

    def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """Like `factored_qkv`, but the cache is filled up to and capped at `_encoder_len` positions."""
        curr_ctx = hidden_states.shape[1]
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")
        query, key, value = hidden_states.chunk(3, dim=2)
        if sample:
            if self._cache_len() < self._encoder_len:
                self._append_cache(key, value)
            if self._cache_len() > self._encoder_len:
                self._slice_cache(0, self._encoder_len)
            key, value = self.cache["key"], self.cache["value"]
            self.sample_t += curr_ctx
        return query, key, value, sample

    def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """Cross attention: queries from `hidden_states`, keys/values projected from the encoder states."""
        curr_ctx = hidden_states.shape[1]
        query = hidden_states
        if sample:
            if self.sample_t == 0:
                # Encoder keys/values are computed once at the first sampling step and then reused.
                self.cache["key"], self.cache["value"] = self.c_enc_kv(
                    last_encoder_hidden_states.type_as(hidden_states)
                ).chunk(2, dim=2)
            key, value = self.cache["key"], self.cache["value"]
            self.sample_t += curr_ctx
        else:
            key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2)
        return query, key, value, sample

    def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """Project, attend with the configured pattern, crop to the current context and project back."""
        curr_ctx = hidden_states.shape[1]
        hidden_states = self.c_attn(hidden_states)
        query, key, value, sample = self.qkv(
            hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
        )
        attention_scores = self.attn(query, key, value, sample)
        if attention_scores.shape[1] != curr_ctx:
            # Remove the block padding added by _pad_to_block_ctx.
            offset = self._offset(curr_ctx)
            attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous()
        attention_scores = self.c_proj(attention_scores)
        return self.resid_dropout(attention_scores)

    @property
    def _encoder_len(self):
        """`encoder_len` rounded up past the next multiple of `blocks` (always strictly greater)."""
        encoder_len = self.encoder_len
        encoder_blocks = (encoder_len // self.blocks) + 1
        return encoder_blocks * self.blocks

    def _offset(self, curr_ctx):
        """Position of the first current token inside its block (0 for dense attention)."""
        if self.attn_func == "dense_attn":
            return 0
        return (self.sample_t - curr_ctx) % self.block_ctx

    def _pad_to_block_ctx(self, hidden_states, query=False):
        """Zero-pad along dim 1 so the length is a multiple of `block_ctx`; queries are also left-offset."""
        seq_len = hidden_states.shape[1]
        offset = self._offset(seq_len) if query else 0
        n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx
        pad = n_blocks * self.block_ctx - seq_len - offset
        if pad == 0 and offset == 0:
            return hidden_states
        else:
            return F.pad(hidden_states, (0, 0, offset, pad))

    def _cache_len(self):
        """Number of cached positions (0 when the cache is empty)."""
        return 0 if "key" not in self.cache else self.cache["key"].shape[1]

    def _suff_cache_len(self):
        """
        Number of trailing cached positions this attention pattern still needs.

        Precondition:
            key and value are appended with the current context and self.sample_t reflects the 1-indexed sample
            location in the context.

        Only the branch for the active `attn_func` is evaluated, so patterns that never use `block_ctx` (dense,
        transpose-block, prime) work even when `blocks` is None. Unknown names raise `KeyError` as before.
        """
        attn_func = self.attn_func
        if attn_func in ("dense_attn", "transpose_block_attn"):
            return self.sample_t
        if attn_func == "block_attn":
            return (self.sample_t - 1) % self.block_ctx + 1
        if attn_func == "prev_block_attn":
            if self.sample_t <= self.block_ctx:
                return self.sample_t
            # the current partial block plus one full previous block
            return (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx
        if attn_func == "cross_attn":
            return self.encoder_len
        if attn_func == "prime_attn":
            return min(self.sample_t, self._encoder_len)
        raise KeyError(attn_func)

    def _slice_cache(self, start, end=None):
        """Keep only cached positions `[start:end]` along dim 1."""
        self.cache["key"] = self.cache["key"][:, start:end]
        self.cache["value"] = self.cache["value"][:, start:end]

    def _append_cache(self, key, value):
        """Append `key`/`value` to the cache along dim 1 and return the full cached tensors."""
        if "key" not in self.cache:
            self.cache["key"] = key
            self.cache["value"] = value
        else:
            old_key, old_value = key, value
            key = torch.cat([self.cache["key"], old_key], dim=1)
            value = torch.cat([self.cache["value"], old_value], dim=1)
            # Drop references to the old tensors before storing the concatenated ones.
            del self.cache["key"]
            del self.cache["value"]
            del old_key
            del old_value
            self.cache["key"] = key
            self.cache["value"] = value
        return self.cache["key"], self.cache["value"]

    def del_cache(self):
        """Reset the sampling position and clear the key/value cache."""
        self.sample_t = 0
        if "key" in self.cache:
            del self.cache["key"]
        if "value" in self.cache:
            del self.cache["value"]
        self.cache = {}
class JukeboxBlock(nn.Module):
    """
    One transformer layer of the Jukebox prior: layer norm + attention, then layer norm + MLP, with the residual
    contribution optionally scaled by `1 / config.num_layers`.

    Args:
        config: prior configuration (reads `hidden_size`, `num_layers`, `attn_res_scale`).
        n_ctx (`int`): context length passed to the attention module.
        attn_func (`str`, *optional*, defaults to `"dense_attn"`): name of the attention function for this layer.
    """

    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        super().__init__()
        self.width = config.hidden_size
        self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func)

        self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size)
        self.mlp = JukeboxMLP(config)
        self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size)
        # Residual scaling: 1 / num_layers when `attn_res_scale` is enabled, otherwise no scaling.
        self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0
        # Kept so the layer stack can decide whether to pass encoder states to this layer.
        self.attn_func = attn_func

    def forward(self, hidden_states, last_encoder_hidden_states, sample=False):
        """
        Args:
            hidden_states: input activations.
            last_encoder_hidden_states: encoder states forwarded to the attention module (may be `None`).
            sample (`bool`, *optional*, defaults to `False`): forwarded to the attention module.

        Returns:
            The layer output: `residuals + res_scale * (attention_output + mlp_output)`.
        """
        residuals = hidden_states
        hidden_states = self.layer_norm_0(hidden_states)
        hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample)

        output_states = self.layer_norm_1(residuals + hidden_states)
        output_states = self.mlp(output_states)
        if self.res_scale == 1.0:
            # Skip the multiplication when no scaling is configured.
            output = residuals + hidden_states + output_states
        else:
            output = residuals + self.res_scale * (hidden_states + output_states)
        return output

class JukeboxLayerStack(nn.Module):
    """
    Stack of `config.num_layers` `JukeboxBlock`s. The attention function of each layer is chosen by
    `ATTENTION_PATTERNS[config.attention_pattern](layer_index)`.

    Args:
        config: prior configuration.
        n_ctx (`int`): context length of every block.
    """

    def __init__(self, config, n_ctx):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = config.hidden_size
        self.num_layers = config.num_layers
        self.blocks = config.blocks
        self.attention_pattern = config.attention_pattern
        # Per-block context length; only defined when the config splits the context into blocks.
        if self.blocks is not None:
            self.block_ctx = n_ctx // self.blocks
        self.encoder_len = config.nb_relevant_lyric_tokens
        self.n_heads = config.n_heads

        # Map each layer index to the name of its attention function.
        attention_pattern = ATTENTION_PATTERNS[self.attention_pattern]
        self._attn_mods = nn.ModuleList(
            [JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth)) for depth in range(self.num_layers)]
        )

        self.saved_attn_weights = []

    def set_record_attn(self, record_attn):
        """
        Choose which layers record attention into `self.saved_attn_weights` during `forward`.

        Args:
            record_attn (`Union[bool,set]`):
                Either a set of layer indices indicating which layers to record, or a boolean indicating whether to
                record all layers.
        """

        def _should_record_attn(layer_idx):
            if isinstance(record_attn, bool):
                return record_attn
            return layer_idx in record_attn

        for i, layer in enumerate(self._attn_mods):
            layer.attn.record_attn = _should_record_attn(i)

        # Recording disabled everywhere (False or an empty set): drop anything recorded previously.
        if not record_attn:
            self.saved_attn_weights = []

    def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """
        Run `hidden_states` through every block. Only cross-attention layers receive `last_encoder_hidden_states`;
        all other layers get `None`.
        """
        for attn_layer in self._attn_mods:
            # Attention functions are named "cross_attn" elsewhere (e.g. the cache-length table); the previous
            # comparison against "cross_attention" never matched. Both spellings are accepted.
            if attn_layer.attn_func in ("cross_attn", "cross_attention"):  # attend to the lyrics
                hidden_states = attn_layer(
                    hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
                )
            else:
                hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample)
            if attn_layer.attn.record_attn:
                # NOTE(review): this line was lost in this copy and is restored as appending the layer's `c_attn`
                # projection weight; confirm against upstream whether the attention softmax was intended.
                self.saved_attn_weights.append(attn_layer.attn.c_attn.weight)
        return hidden_states

    def del_cache(self):
        """Clear the key/value cache of every attention layer."""
        for attn_layer in self._attn_mods:
            attn_layer.attn.del_cache()
class JukeboxPositionalEmbedding(nn.Module):
    """
    Learned positional embedding table.

    Args:
        embed_dim (`int`): number of positions (first dimension of the table).
        width (`int`): size of each position vector (second dimension of the table).
    """

    def __init__(self, embed_dim, width):
        # nn.Module must be initialized before a Parameter can be assigned as an attribute.
        super().__init__()
        # Uninitialized storage; values are expected to be filled by the model's weight initialization.
        self.pos_emb = nn.Parameter(torch.empty((embed_dim, width)))

    def forward(self):
        """Return the full `(embed_dim, width)` positional embedding parameter."""
        return self.pos_emb

class JukeboxConditionalAutoregressive(nn.Module):
    # JukeboxConditionalAutoregressive 类定义,继承自 nn.Module
    # NOTE(review): the parameter list below was lost in this copy and is reconstructed from the Args section of
    # the docstring (same names, order and documented defaults); confirm against callers.
    def __init__(
        self,
        config,
        n_ctx=None,
        embed_dim=None,
        audio_conditioning=False,
        metadata_conditioning=False,
        is_encoder=False,
    ):
        """
        Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly
        set for each configuration.

        Args:
            config (`JukeboxPriorConfig`):
                Model configuration class with all the parameters of the model. Initializing with a config file does
                not load the weights associated with the model, only the configuration. Check out the
                [`~PreTrainedModel.from_pretrained`] method to load the model weights.
            n_ctx (`int`, *optional*):
                Number of tokens or lyrics tokens provided in a single pass. Defaults to `config.n_ctx`.
            embed_dim (`int`, *optional*):
                Either equals to the dimension of the codebook, or the sum of n_vocab (lyrics) and codebook dimension,
                if the model combines lyrics and music tokens, or simply n_vocab if the model is a separate encoder.
                Defaults to `config.music_vocab_size`.
            audio_conditioning (`bool`, *optional*, defaults to `False`):
                Whether or not the prior supports conditioning on audio.
            metadata_conditioning (`bool`, *optional*, defaults to `False`):
                Whether or not the prior supports conditioning on artist, genres, lyrics, and timing.
            is_encoder (`bool`, *optional*, defaults to `False`):
                Whether the model is an encoder only model.
        """
        super().__init__()
        self.width = config.hidden_size
        self.num_layers = config.num_layers
        self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx
        self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size

        self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size)
        self.embed_tokens_dropout = nn.Dropout(config.emb_dropout)

        self.metadata_conditioning = metadata_conditioning
        self.audio_conditioning = audio_conditioning

        # Without metadata conditioning, a learned start token is used as the first input position.
        if not metadata_conditioning:
            self.start_token = nn.Parameter(torch.empty((1, config.hidden_size)))

        self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size)
        self.pos_emb_dropout = nn.Dropout(config.emb_dropout)

        self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx)
        self.is_encoder = is_encoder
        self.encoder_len = config.nb_relevant_lyric_tokens

        if config.merged_decoder:
            self.add_cond_after_transformer = False
            self.share_embed_tokens_fc_proj_out = False
        else:
            self.add_cond_after_transformer = True
            self.share_embed_tokens_fc_proj_out = True

        # Encoder-only models have no output projection or loss.
        if not is_encoder:
            self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False)
            # Tie the output projection to the input embedding when configured to share them.
            if self.share_embed_tokens_fc_proj_out:
                self.fc_proj_out.weight = self.embed_tokens.weight
            self.loss = torch.nn.CrossEntropyLoss()
    # Training/scoring forward pass: embeds `tokens` (shifted right by one position with a start/metadata
    # embedding in front), adds positional embedding and audio conditioning, runs the transformer, and (unless this
    # is an encoder) returns a cross-entropy loss in bits (divided by ln 2).
    # NOTE(review): the parameter list and docstring quotes of this signature appear to have been lost in this copy.
    # The body reads `tokens`, `audio_conditioning`, `metadata_conditioning`, `last_encoder_hidden_states`,
    # `get_sep_loss`, `get_preds` and `get_acts` -- confirm names, order and defaults against upstream.
    def forward(
            tokens (`torch.tensor`):
                Can represent music tokens, lyrics tokens or both, depending on the configuration.
        # Preprocess.
        batch_size = tokens.shape[0]  # batch size
        with torch.no_grad():
            tokens = tokens.view(batch_size, -1).long()  # flatten to (batch_size, seq_len) integer tokens

        if not self.audio_conditioning:
            # No audio conditioning: use an all-zero conditioning tensor that broadcasts over positions.
            # NOTE(review): the remaining arguments / closing parenthesis of this call appear to be missing.
            audio_conditioning = torch.zeros(
                (batch_size, 1, self.width),

        target = tokens  # targets for the next-token loss
        hidden_states = self.embed_tokens(tokens)  # token embeddings
        # Shift by 1, and fill in start token
        # NOTE(review): the function name before `[:, -1:]` (presumably a concatenation) appears to be missing.
        hidden_states =[:, -1:], hidden_states[:, :-1]), dim=1)  # rotate right by one position; slot 0 is overwritten below
        if self.metadata_conditioning:
            hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width)  # first position = metadata embedding
            # NOTE(review): an `else:` appears to be missing before the next line; as written it always overwrites slot 0.
            hidden_states[:, 0] = self.start_token  # otherwise: learned start token

        hidden_states = (
            self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning
        )  # token embedding (dropout) + positional embedding (dropout) + audio conditioning

        hidden_states = self.transformer(
            hidden_states, last_encoder_hidden_states=last_encoder_hidden_states
        )  # run the transformer layer stack

        if self.add_cond_after_transformer:  # configurations that also add conditioning after the transformer
            hidden_states = hidden_states + audio_conditioning

        activations = hidden_states  # transformer output before projection (returned when get_acts is set)
        if self.is_encoder:
            return hidden_states  # encoders return their hidden states directly

        hidden_states = self.fc_proj_out(hidden_states)  # project to logits over embed_dim classes
        loss_fn = nn.CrossEntropyLoss()  # a fresh loss module (self.loss from __init__ is not used here)

        if get_sep_loss:
            # Separate losses for the lyric part (first encoder_len positions) and the music part (the rest).
            lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim)
            token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim)

            lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0)  # lyric loss in bits
            music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0)  # music-token loss in bits

            loss = (lyric_loss, music_token_loss)  # tuple of (lyric loss, music-token loss)
            # NOTE(review): an `else:` appears to be missing before the next line; as written it overwrites the tuple.
            loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0)  # overall loss in bits

        if get_preds:
            return loss, hidden_states  # loss and logits
        elif get_acts:
            return loss, activations  # loss and pre-projection activations
            # NOTE(review): an `else:` appears to be missing before the next line (unreachable as written).
            return loss, None  # loss only
    def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning):
        """
        Build the input embedding for one ancestral-sampling step.

        Args:
            sample_t (`int`): index of the position being sampled.
            n_samples (`int`): number of samples in the batch.
            tokens: tokens sampled at the previous step; ignored when `sample_t == 0`.
            audio_conditioning: either per-position conditioning of shape `(n_samples, n_ctx, width)`, from which the
                `sample_t` slice is taken, or a tensor that is added as-is.
            metadata_conditioning: used as the first-step embedding when `self.metadata_conditioning` is set.

        Returns:
            `(hidden_states, cond)`: the step embedding and the conditioning that was added to it.
        """
        if sample_t == 0:
            # First step: there is no previous token, so start from the metadata embedding or the start token.
            hidden_states = torch.empty(
                n_samples,
                1,
                self.width,
                dtype=self.embed_tokens.weight.dtype,
                device=self.embed_tokens.weight.device,
            )
            if self.metadata_conditioning:
                hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width)
            else:
                hidden_states[:, 0] = self.start_token
        else:
            hidden_states = self.embed_tokens(tokens)

        if audio_conditioning.shape == (n_samples, self.n_ctx, self.width):
            cond = audio_conditioning[:, sample_t : sample_t + 1, :]
        else:
            cond = audio_conditioning
        # Positional-embedding dropout is not applied here (it is the identity at eval time).
        hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond
        return hidden_states, cond
        # NOTE(review): the `def sample(...)` header for this body appears to be missing from this copy. The body
        # reads `n_samples`, `sample_tokens`, `audio_conditioning`, `metadata_conditioning`,
        # `last_encoder_hidden_states`, `temp`, `top_k`, `top_p` and `get_preds` -- confirm the signature upstream.
        # Ancestral sampling: generate `sample_tokens` tokens one position at a time, feeding each sampled token back.
        # Default to a full context window when the number of tokens to sample is not given.
        if sample_tokens is None:
            sample_tokens = self.n_ctx

        # Without audio conditioning, use an all-zero conditioning tensor with the dtype of the first MLP weight.
        if not self.audio_conditioning:
            # NOTE(review): the closing of this call appears to be missing in this copy.
            audio_conditioning = torch.zeros(
                (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype

        # Sampling runs without gradient tracking.
        with torch.no_grad():
            sampled_tokens = []
            tokens = None
            if get_preds:
                preds = []

            # Progress bar over sampling steps (note: `iter` shadows the builtin inside this scope).
            iter = tqdm(range(0, sample_tokens), leave=False)
            for sample_t in iter:
                iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True)
                # Embedding and conditioning for the current position.
                # NOTE(review): the closing parenthesis of this call appears to be missing.
                hidden_states, cond = self.get_emb(
                    sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning

                # One incremental transformer step (sample=True).
                # NOTE(review): the closing parenthesis of this call appears to be missing.
                hidden_states = self.transformer(
                    hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
                # Some configurations add the conditioning again after the transformer.
                if self.add_cond_after_transformer:
                    hidden_states = hidden_states + cond
                # Project to logits.
                hidden_states = self.fc_proj_out(hidden_states)  # Predictions
                # When requested, the raw logits of each step are kept.
                # NOTE(review): the body of this `if` (presumably appending to `preds`) appears to be missing.
                if get_preds:
                # Temperature scaling followed by top-k / top-p filtering of the logits.
                hidden_states = hidden_states / temp
                hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
                # Draw the next token from the filtered distribution.
                # NOTE(review): the step that stores `tokens` into `sampled_tokens` appears to be missing.
                tokens = torch.distributions.Categorical(logits=hidden_states).sample()

            del tokens
            # Clear the transformer's key/value cache.
            # NOTE(review): the cache-clearing call itself appears to be missing here.

            # Concatenate the sampled tokens along the sequence dimension.
            # NOTE(review): the function name / first argument of the next two concatenations appear to be missing.
            tokens =, dim=1)
            if get_preds:
                preds =, dim=1)
        # Return tokens and per-step logits when requested.
        if get_preds:
            return tokens, preds
        # Otherwise return only the tokens.
        # NOTE(review): an `else:` appears to be missing before the next line.
            return tokens

    def split_chunks(self, length, chunk_size):
        """
        Split `length` items into consecutive chunks of at most `chunk_size`.

        Every chunk holds `chunk_size` items except the last one, which holds the remainder (note that for
        `length == 0` this still returns `[chunk_size]`).

        Returns:
            `List[int]`: the size of each chunk, in order.
        """
        num_chunks = (length + chunk_size - 1) // chunk_size
        last_chunk = (length - 1) % chunk_size + 1
        sizes = [chunk_size for _ in range(num_chunks - 1)]
        sizes.append(last_chunk)
        return sizes

    def primed_sample(
# Conditioner that embeds upper-level music tokens and upsamples them (via one decoder conv block) to produce
# audio conditioning for a lower-level prior.
# NOTE(review): in this copy the docstring quotes, `super().__init__()` and the upsampler constructor arguments
# appear to be missing; restore them from upstream before use.
class JukeboxMusicTokenConditioner(nn.Module):
    The `JukeboxMusicTokenConditioner` takes music tokens as an input (coresponding to the codes of the VQVAE's
    codebook) and upsamples it using a single layer of decoder convolution block (the same is used in the VQVAE).

    def __init__(self, config, level):
        # Embedding table for music tokens: music_vocab_size entries of size hidden_size.
        self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size)
        # NOTE: this mutates the caller's config object (sets config.embed_dim) so the decoder block reads the
        # right vocabulary size.
        config.embed_dim = config.music_vocab_size  # setting correct argument for the `JukeboxDecoder`

        # Upsampling decoder convolution block.
        self.upsampler = JukeboxDecoderConvBock(
        # Layer normalization applied to the upsampled hidden states.
        self.layer_norm = JukeboxLayerNorm(config.hidden_size)

    # Embed `music_tokens`, add optional conditioning, upsample, and layer-normalize.
    # The documented "LongTensor" type for raw_audio_conditionning is doubtful since it is added to embeddings;
    # presumably it is a float tensor broadcastable to the embedding output -- TODO confirm.
    def forward(self, music_tokens, raw_audio_conditionning=None):
            music_tokens (`torch.LongTensor`):
                Music tokens form the uper level in range(nb_discrete_codes)
            raw_audio_conditionning (`torch.LongTensor`, *optional*):
                Audio used when primed sampling, raw audio information that conditions the generation
        # No conditioning given: add 0.0 (a no-op) instead.
        if raw_audio_conditionning is None:
            raw_audio_conditionning = 0.0
        # Embedding lookup requires integer indices.
        music_tokens = music_tokens.long()
        hidden_states = self.embed_tokens(music_tokens)
        hidden_states = hidden_states + raw_audio_conditionning

        # Swap the last two dims so channels come second for the conv block
        # (assumes hidden_states is (batch, seq, hidden) -- TODO confirm), then swap back.
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states = self.upsampler(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 1)
        # Normalize the upsampled hidden states and return them.
        hidden_states = self.layer_norm(hidden_states)
        # Return the normalized hidden states
        return hidden_states

class JukeboxRangeEmbedding(nn.Module):
    """
    The `JukeboxRangeEmbedding` interpolates the given [pos_start, pos_end] to obtain an equivalent of a time
    positional embedding of length `n_time`.

    Binning process: each position is mapped from [pos_min, pos_max) to [0, 1), then to [0, embed_dim), then floored
    to a bin index in [0, ..., embed_dim - 1]. NOTE: the interval is open on the right, so
    pos_min <= pos < pos_max, not pos <= pos_max.

    Args:
        n_time (`int`): number of interpolated positions produced per input.
        embed_dim (`int`): number of bins (rows of the embedding table).
        range (`Tuple[float, float]`): `(pos_min, pos_max)` bounds of the input positions.
        out_width (`int`): size of each embedding vector.
        clamp (`bool`, *optional*, defaults to `False`): clamp `pos_end` into `[pos_min, pos_max]`.
    """

    def __init__(self, n_time, embed_dim, range, out_width, clamp=False):
        super().__init__()
        self.emb = nn.Embedding(embed_dim, out_width)
        self.n_time = n_time
        self.embed_dim = embed_dim
        self.pos_min, self.pos_max = range
        self.clamp = clamp

    def forward(self, pos_start, pos_end=None):
        """
        Args:
            pos_start: 2-D tensor of start positions in `[pos_min, pos_max)`.
            pos_end: end positions; required when `n_time != 1` (used for the interpolation).

        Returns:
            Embeddings looked up from the bin index of every (interpolated) position.

        Raises:
            TypeError: if `pos_start` is not 2-D or has values outside `[pos_min, pos_max)`.
        """
        if len(pos_start.shape) != 2:
            raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}")
        # Both bounds must hold. The previous `not A and B` parsed as `(not A) and B` and let values >= pos_max
        # through, which then indexed past the end of the embedding table.
        if not ((self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all()):
            raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}")

        pos_start = pos_start.float()
        if pos_end is not None:
            if self.clamp:
                pos_end = pos_end.clamp(self.pos_min, self.pos_max)
            pos_end = pos_end.float()

        # Interpolate so that [pos_start, ..., pos_end] maps to n_time positions.
        n_time = self.n_time
        if n_time != 1:
            interpolation = (
                torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time
            )
            position = pos_start + (pos_end - pos_start) * interpolation
        else:
            position = pos_start

        normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min)
        bins_ = (self.embed_dim * normalised_position).floor().long().detach()
        return self.emb(bins_)
class JukeboxLabelConditioner(nn.Module):
    """
    Embeds the metadata labels: artist and genres into a start embedding and, when `include_time_signal` is set,
    total length / absolute position / relative position into a positional embedding.

    Args:
        config: prior configuration (reads `hidden_size`, `timing_dims`, `sampling_rate`, `metadata_dims`,
            `n_ctx`, `max_nb_genres` and, with time signal, `min_duration` / `max_duration`).
        include_time_signal (`bool`): whether to build the timing embeddings.
    """

    def __init__(self, config, include_time_signal):
        super().__init__()

        embed_dim = config.hidden_size
        timing_dims = config.timing_dims
        sampling_rate = config.sampling_rate
        nb_genres, nb_artists = config.metadata_dims
        music_tokens_shape = config.n_ctx

        self.max_nb_genres = config.max_nb_genres
        self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim)
        self.artist_emb = nn.Embedding(nb_artists, embed_dim)
        self.include_time_signal = include_time_signal
        if self.include_time_signal:
            total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate)
            absolute_pos_range = (0.0, config.max_duration * sampling_rate)
            relative_pos_range = (0.0, 1.0)
            self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim)
            self.absolute_pos_emb = JukeboxRangeEmbedding(
                music_tokens_shape, timing_dims, absolute_pos_range, embed_dim
            )
            self.relative_pos_emb = JukeboxRangeEmbedding(
                music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True
            )

    def forward(self, metadata):
        """
        Args:
            metadata: 2-D tensor whose columns are `[total_length, offset, length, artist, genre_0, genre_1, ...]`;
                empty genre slots hold -1.

        Returns:
            `(start_emb, pos_emb)`: the artist + genre embedding (one position) and the summed timing embedding, or
            `None` for `pos_emb` when the time signal is disabled.
        """
        total_length = metadata[:, 0:1]
        offset = metadata[:, 1:2]
        length = metadata[:, 2:3]
        artist = metadata[:, 3:4]
        genre = metadata[:, 4:]

        # Start embedding (one position).
        artist_emb = self.artist_emb(artist)
        # Empty genre slots are -1: clamp them to a valid index and mask them out of the bag-of-genres sum.
        mask = (genre >= 0).float().unsqueeze(2)
        genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)
        start_emb = genre_emb + artist_emb

        # Positional embedding (n_ctx positions).
        if self.include_time_signal:
            start, end = offset, offset + length
            total_length = total_length.float()
            start = start.float()
            end = end.float()
            pos_emb = (
                self.total_length_emb(total_length)
                + self.absolute_pos_emb(start, end)
                + self.relative_pos_emb(start / total_length, end / total_length)
            )
        else:
            pos_emb = None
        return start_emb, pos_emb
    # Configuration class associated with this model.
    config_class = JukeboxPriorConfig

    def _init_weights(self, module):
        """
        Initialize the weights of `module` according to its type, scaled by `self.config.init_scale`.

        - `nn.Embedding`, `JukeboxConv1D`, `lm_head`: normal(0, 0.02 * init_scale).
        - Positional / range embeddings, `start_token`: normal(0, 0.01 * init_scale).
        - With `config.zero_out`, `JukeboxConv1D` weights and the second conv of `JukeboxResConv1DBlock` are zeroed.
        - `nn.LayerNorm`: weight 1, bias 0. `nn.Linear` bias: 0.
        """
        init_scale = self.config.init_scale

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
        elif isinstance(module, JukeboxConv1D):
            if self.config.zero_out:
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
        elif isinstance(module, JukeboxPositionalEmbedding):
            module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale)
        elif isinstance(module, JukeboxRangeEmbedding):
            module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale)
        elif isinstance(module, JukeboxConditionalAutoregressive):
            # Independent checks: the previous elif chain skipped `start_token` whenever `lm_head` was present,
            # leaving it as uninitialized `torch.empty` storage.
            if hasattr(module, "lm_head"):
                module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
            if hasattr(module, "start_token"):
                module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale)
        elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
            module.conv1d_2.weight.data.zero_()
            module.conv1d_2.bias.data.zero_()
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
        def get_metadata(self, labels, start, total_length, offset, get_indices=False):
            """
            Build the metadata tensor for the current window from `labels` (which is not modified).

            Column 0 is set to `total_length`, column 2 to this level's `sample_length`, and column 1 to the token
            offset `int(offset * raw_to_tokens) + int(start * raw_to_tokens)`. Lyric tokens are then reduced to the
            relevant ones via `self.set_metadata_lyric_tokens`.

            Returns:
                `metadata`, or `(metadata, indices)` when `get_indices` is `True`.
            """
            metadata = labels.clone()
            metadata[:, 0] = total_length
            # Set sample_length to match this level.
            metadata[:, 2] = int(self.sample_length)
            # Offset expressed in tokens.
            metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens)
            # The metadata carries the full lyric token list; keep only the relevant part.
            metadata, indices = self.set_metadata_lyric_tokens(metadata)
            if get_indices:
                return metadata, indices
            else:
                return metadata
        def set_metadata_lyric_tokens(self, labels):
            """
            Processes the full labels to only retrieve the relevant lyric tokens and keep the metadata conditioning
            tokens.

            Returns:
                `(labels, indices_list)`: the first `4 + max_nb_genres` metadata columns followed by
                `nb_relevant_lyric_tokens` lyric tokens, and the per-row indices returned by
                `get_relevant_lyric_tokens`; or `(labels, None)` unchanged when no lyric tokens are used.
            """
            if self.nb_relevant_lyric_tokens > 0:
                tokens_list = torch.zeros(
                    (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device
                )
                indices_list = []  # index of each selected character in the original lyric array, per row
                n_metadata_cols = 4 + self.metadata_embedding.max_nb_genres
                for idx in range(labels.shape[0]):
                    # NOTE(review): every iteration passes the lyric columns of *all* rows; whether the helper
                    # only looks at one row is not visible here -- confirm for batches with differing lyrics.
                    full_tokens = labels.clone()[:, n_metadata_cols:]
                    total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2]
                    tokens, indices = get_relevant_lyric_tokens(
                        full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration
                    )
                    tokens_list[idx, :] = tokens
                    indices_list.append(indices)

                return (
                    torch.cat((labels[:, :n_metadata_cols], tokens_list), dim=-1),
                    indices_list,
                )
            else:
                return labels, None
        def get_music_tokens_conds(self, music_tokens, start, end):
            """
            Extracts current level's conditioning music tokens.

            Takes the upper level's tokens (`music_tokens[self.level - 1]`), keeps the window
            `[start // cond_downsample, end // cond_downsample)`, and right-pads it with zeros up to
            `n_ctx // cond_downsample` positions when it is shorter.

            Returns:
                A one-element list with the conditioning tokens, or `None` for level 0 (no upper level).
            """
            if self.level == 0:
                return None

            downsample = self.cond_downsample
            # The window slice used to be assigned to an unused name, so the whole upper-level sequence was
            # returned regardless of start/end; it is now the conditioning that gets returned.
            music_tokens_cond = music_tokens[self.level - 1][:, start // downsample : end // downsample]
            missing_cond_len = self.n_ctx // downsample - music_tokens_cond.shape[-1]
            if missing_cond_len > 0:
                # Pad every batch row, keeping the tokens' dtype and device.
                init_cond = torch.zeros(
                    music_tokens_cond.shape[0],
                    missing_cond_len,
                    dtype=music_tokens_cond.dtype,
                    device=music_tokens_cond.device,
                )
                music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long()
            return [music_tokens_cond]
    def prior_preprocess(self, tokens, conds):
        """
        Shifts the input tokens to account for the dictionary merge. `embed_dim_shift` gives by how much each group
        of tokens should be shifted (for the music tokens it is `lyric_vocab_size`).

        Args:
            tokens (`List[torch.Tensor]`): token groups; updated in place with their shifted, flattened versions.
            conds (`List[Optional[torch.Tensor]]`): conditioning per group; `None` entries are replaced in place by
                zeros of shape `(batch_size, input_shapes[i], width)`.

        Returns:
            `(tokens, conds)` each concatenated along dimension 1.
        """
        batch_size = tokens[0].shape[0]
        for i in range(len(tokens)):
            tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1)

        for i in range(len(conds)):
            if conds[i] is None:
                # NOTE(review): the zeros take the tokens' dtype (integer), not a float dtype.
                conds[i] = torch.zeros(
                    (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device
                )

        return torch.cat(tokens, dim=1), torch.cat(conds, dim=1)

    def prior_postprocess(self, tokens):
        """
        Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is
        shared, `embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music tokens
        (the last group).
        """
        batch_size = tokens.shape[0]
        # Split the merged sequence back into its first group (input_shapes[0] positions) and the rest.
        dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0])
        tokens = list(torch.split(tokens, dims, dim=1))

        for i in range(len(tokens)):
            bins_shift = int(self.embed_dim_shift[i])
            tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1)
            # If the loss is not masked, the model may have generated tokens from the other group, which become
            # negative after un-shifting; clamp them to 0.
            tokens[i] = torch.clamp(tokens[i], min=0)
        return tokens[-1]

    def embed_tokens(self, music_tokens_conds):
        """
        Embeds the upper level music tokens and upsamples them to provide as audio conditioning.

        Only the first `cond_level + 1` entries of `music_tokens_conds` are considered, and they are paired with
        the single `self.conditioner_blocks` module, so at most one conditioner call is made.

        Returns:
            The conditioner output, or `None` when `music_tokens_conds` is empty.
        """
        music_tokens_conds = music_tokens_conds[: self.cond_level + 1]
        audio_conditioning = None
        for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))):
            audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning)
        return audio_conditioning

    def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1):
        """
        Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states.

        Args:
            hidden_states: raw audio to encode.
            start_level (`int`, *optional*): defaults to `self.level`.
            end_level (`int`, *optional*): defaults to `self.levels`.
            bs_chunks (`int`, *optional*, defaults to 1): forwarded to the VQVAE encoder.
        """
        if start_level is None:
            start_level = self.level
        if end_level is None:
            end_level = self.levels
        # Encoding is never part of the training graph here.
        with torch.no_grad():
            latent_states = self.vqvae_encoder(
                hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
            )
        return latent_states
    def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1):
        """
        Upsamples the sequence of codebook vectors to a raw audio.

        Args:
            music_tokens: music tokens to decode.
            start_level (`int`, *optional*): defaults to `self.level`.
            end_level (`int`, *optional*): defaults to `self.levels`.
            bs_chunks (`int`, *optional*, defaults to 1): forwarded to the VQVAE decoder.
        """
        if start_level is None:
            start_level = self.level
        if end_level is None:
            end_level = self.levels
        # Decoding is never part of the training graph here.
        with torch.no_grad():
            output = self.vqvae_decoder(
                music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
            )
        return output

    def get_cond(self, music_tokens_conds, metadata):
        """
        Converts the input tokens to input_embeddings. Splits the lyrics from the rest of the metadata. Lyric tokens
        can be None.

        The last `nb_relevant_lyric_tokens` columns of `metadata` are the lyric tokens; the rest are labels.

        Returns:
            `(audio_conditioning, metadata_conditioning, lyric_tokens)`. `audio_conditioning` is the embedded
            `music_tokens_conds` when audio conditioning is enabled, otherwise the metadata positional embedding.
        """
        if metadata is not None:
            n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens
            metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:]
        else:
            metadata, lyric_tokens = None, None
        metadata_conditioning, metadata_pos = (
            self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None)
        )
        audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos
        return audio_conditioning, metadata_conditioning, lyric_tokens

    # Sample tokens from the prior.
    def sample(
        # NOTE(review): the parameter list and body of this method are missing from this copy and must be
        # restored from the upstream implementation (it is invoked elsewhere as `prior.sample(...)`).

    # Run the lyric encoder and return its final hidden states for the decoder to attend to.
    def get_encoder_states(self, lyric_tokens, sample=False):
        """
        Retrieve the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through
        the lyric encoder.

        Args:
            lyric_tokens: lyric token ids fed to `self.encoder`.
            sample (`bool`, *optional*, defaults to `False`): when set, first move the encoder to the device of
                `lyric_tokens`.

        Returns:
            The normalized encoder states, or `None` when there are no relevant lyric tokens or lyric conditioning is
            disabled.
        """
        if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
            if sample:
                self.encoder = self.encoder.to(lyric_tokens.device)
            lyric_acts = self.encoder(lyric_tokens, None, None, None)
            lyric_acts = self.encoder.proj_in(lyric_acts)
            last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts)
        else:
            last_encoder_hidden_states = None
        return last_encoder_hidden_states

    # Lyric-encoder loss: next lyric token prediction.
    def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics):
        """
        Computes the loss for the lyric encoder: next lyric token prediction.

        Args:
            last_encoder_hidden_states: output of `get_encoder_states`; `None` when lyric conditioning is off.
            target_lyrics: target lyric token ids (flattened with `.view(-1)`), or `None`.

        Returns:
            `torch.Tensor`: the cross entropy divided by ln(2) when `self.lyric_conditioning` is set, otherwise a
            scalar zero tensor.
        """
        if self.lyric_conditioning:
            logits = self.encoder.lm_head(last_encoder_hidden_states)
            # Dividing the natural-log cross entropy by ln(2) expresses it in bits.
            encoder_loss = nn.functional.cross_entropy(
                logits.view(-1, self.encoder_dim), target_lyrics.view(-1)
            ) / np.log(2.0)
        else:
            # `get_encoder_states` returns None in this configuration, so take the device from whichever
            # tensor is available (or use the default device) instead of dereferencing None.
            reference = last_encoder_hidden_states if last_encoder_hidden_states is not None else target_lyrics
            device = reference.device if reference is not None else None
            encoder_loss = torch.tensor(0.0, device=device)
        return encoder_loss
    def forward_tokens(
        self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False
    ):
        """
        Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the
        vqvae's encoding layers.

        Args:
            music_tokens: tokens to model at this prior's level.
            music_tokens_conds (*optional*): conditioning tokens forwarded to `get_cond` (not modified).
            metadata (*optional*): metadata tensor forwarded to `get_cond`.
            get_preds (`bool`, defaults to `False`): also store the prior's predictions in `metrics["preds"]`.
            get_attn_weights (`bool`, defaults to `False`): record attention weights and return them instead of
                the loss.

        Returns:
            `(loss, metrics)`, or the saved attention weights when `get_attn_weights` is set.
        """
        if get_attn_weights:
            self.prior.transformer.set_record_attn(get_attn_weights)
        audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)

        if self.is_encoder_decoder:
            # Preprocessing merges lyric and music tokens into one sequence and adapts the audio conditioning.
            tokens, audio_conditioning = self.prior_preprocess(
                [lyric_tokens, music_tokens], [None, audio_conditioning]
            )
            (encoder_loss, next_token_prediction_loss), preds = self.prior(
                tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds
            )
        else:
            last_encoder_hidden_states = self.get_encoder_states(lyric_tokens)
            encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens)
            next_token_prediction_loss, preds = self.prior(
                music_tokens,
                audio_conditioning,
                metadata_conditioning,
                last_encoder_hidden_states,
                get_preds=get_preds,
            )
        # Weighted sum of the encoder loss and the next-token loss, both normalized by `total_loss_dims`.
        loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims
        loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims

        metrics = {
            "bpd": next_token_prediction_loss.clone().detach(),
            "encoder_loss": encoder_loss.clone().detach(),
            "next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
        }
        if get_preds:
            metrics["preds"] = preds.clone().detach()
        if get_attn_weights:
            saved_attn_weights = self.prior.transformer.saved_attn_weights
            # Turn attention recording back off before handing the weights back.
            self.prior.transformer.set_record_attn(False)
            return saved_attn_weights
        return loss, metrics
    def forward(
        self,
        hidden_states: torch.Tensor,
        metadata=None,
        decode: bool = False,
        get_preds: bool = False,
    ) -> List[torch.Tensor]:
        """
        Encode the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens`
        function. The loss is the sum of the `encoder` loss and the `decoder` loss.

        Args:
            hidden_states (`torch.Tensor`):
                Hidden states which should be raw audio
            metadata (`List[torch.LongTensor]`, *optional*):
                List containing the metadata conditioning tensor with the lyric and the metadata tokens.
            decode (`bool`, *optional*, defaults to `False`):
                Whether or not to decode the encoded to tokens.
            get_preds (`bool`, *optional*, defaults to `False`):
                Whether or not to return the actual predictions of the model.

        Returns:
            `(dequantised_states, loss, metrics)`; `dequantised_states` is `None` unless `decode` is set.
        """
        batch_size = hidden_states.shape[0]
        # The first encoded level is modeled; the remaining ones serve as conditioning.
        music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size)
        loss, metrics = self.forward_tokens(
            music_tokens=music_tokens,
            music_tokens_conds=music_tokens_conds,
            metadata=metadata,
            get_preds=get_preds,
        )
        if decode:
            dequantised_states = self.decode([music_tokens, *music_tokens_conds])
        else:
            dequantised_states = None
        return dequantised_states, loss, metrics
class JukeboxPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class for Jukebox models.
    config_class = JukeboxConfig
    # Prefix of the base model.
    base_model_prefix = "jukebox"
    # Gradient checkpointing is not supported.
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize priors and VQ-VAEs with their own `_init_weights`; other modules are left untouched."""
        if isinstance(module, (JukeboxPrior, JukeboxVQVAE)):
            module.apply(module._init_weights)

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

            labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` :
                List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to
                condition the generation.
            sampling_kwargs (`Dict[Any]`):
                Various additional sampling arguments that are used by the `_sample` function. A detailed list of the
                arguments can be seen in the [`_sample`] function documentation.

    """The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`,
    `continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If
    you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior individually.
class JukeboxModel(JukeboxPreTrainedModel):
    _no_split_modules = ["JukeboxBlock"]

    def __init__(self, config):
        """Build the VQ-VAE and one `JukeboxPrior` per level described by `config`."""
        super().__init__(config)
        vqvae_config = config.vqvae_config
        self.vqvae = JukeboxVQVAE(vqvae_config)
        # Copy the shared settings onto every prior config before the priors are built from them.
        self.set_shared_params(config)
        self.priors = nn.ModuleList(
            [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)]
        )

    def set_shared_params(self, model_config):
        """
        Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig`
        is nested, and is thus unreachable in the `from_dict` function.

        Copies the shared attributes of `model_config` onto every entry of `model_config.prior_configs`, in place.
        """
        shared_attributes = (
            "sampling_rate",
            "timing_dims",
            "min_duration",
            "max_duration",
            "max_nb_genres",
            "metadata_conditioning",
        )
        for prior_config in model_config.prior_configs:
            for attribute in shared_attributes:
                setattr(prior_config, attribute, getattr(model_config, attribute))

    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1):
        """Decode `music_tokens` to raw audio by delegating to the VQ-VAE's `decode`."""
        vqvae = self.vqvae
        return vqvae.decode(music_tokens, start_level, end_level, bs_chunks)
    # Encode raw audio into music tokens with the VQ-VAE.
    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
        encoded = self.vqvae.encode(input_audio, start_level, end_level, bs_chunks)
        return encoded

    # Split `obj` into batches of at most `split_size` samples (out of `n_samples`).
    def split_batch(self, obj, n_samples, split_size):
        """
        Split `obj` along dimension 0 into chunks of at most `split_size` samples.

        Args:
            obj (`torch.Tensor`, `list` of `torch.Tensor`, or `None`): the data to split.
            n_samples (`int`): total number of samples; only used to size the result for `None`.
            split_size (`int`): maximum number of samples per chunk.

        Returns:
            For a tensor, the tuple returned by `torch.split`; for a list, a list of tuples whose i-th tuple holds
            the i-th chunk of every list element; for `None`, one `None` per chunk.

        Raises:
            TypeError: if `obj` has any other type.
        """
        if isinstance(obj, torch.Tensor):
            return torch.split(obj, split_size, dim=0)
        if isinstance(obj, list):
            return list(zip(*[torch.split(item, split_size, dim=0) for item in obj]))
        if obj is None:
            # Ceil division: number of chunks needed to cover all samples.
            n_passes = (n_samples + split_size - 1) // split_size
            return [None] * n_passes
        raise TypeError("Unknown input type")

    # Sample `tokens_to_sample` new tokens at `level` using a window that may be shorter than `n_ctx`.
    def sample_partial_window(
        self, music_tokens, labels, offset, sampling_kwargs, level, tokens_to_sample, max_batch_size
    ):
        """
        Choose the window for sampling `tokens_to_sample` new tokens at `level` and delegate to
        `sample_single_window`.

        Mutates `sampling_kwargs["sample_tokens"]` to the window's token count.
        """
        prior = self.priors[level]
        sampled_tokens = music_tokens[level]
        n_ctx = prior.n_ctx
        nb_sampled_tokens = sampled_tokens.shape[1]

        if nb_sampled_tokens < n_ctx - tokens_to_sample:
            # Existing plus new tokens fit in one context window starting at 0.
            sampling_kwargs["sample_tokens"] = nb_sampled_tokens + tokens_to_sample
            start = 0
        else:
            # Use a full window that ends right after the tokens to be sampled.
            sampling_kwargs["sample_tokens"] = n_ctx
            start = nb_sampled_tokens - n_ctx + tokens_to_sample

        return self.sample_single_window(music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size)

    # Sample a single window of `n_ctx` tokens at `level`, starting at `start`.
    def sample_single_window(self, music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size):
        """
        Sample one window of `prior.n_ctx` tokens at `level` starting at `start`, and append the newly sampled tokens
        to `music_tokens[level]`. The `music_tokens` list is updated in place and returned.
        """
        prior = self.priors[level]
        n_samples = music_tokens[0].shape[0]
        n_ctx = prior.n_ctx
        end = start + n_ctx
        # Tokens already sampled at this level inside the window.
        previous_sampled_tokens = music_tokens[level][:, start:end]

        sample_tokens = sampling_kwargs.get("sample_tokens", None)
        # NOTE(review): whenever the key is present its value is replaced by the full window length, and when it is
        # absent `sample_tokens` stays None (the subtraction below would then fail). Kept as-is to preserve
        # behavior; confirm whether this condition is intended.
        if "sample_tokens" in sampling_kwargs:
            sample_tokens = end - start

        conditioning_tokens = previous_sampled_tokens.shape[1]
        new_tokens = sample_tokens - conditioning_tokens

        logger.info(
            f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on"
            f" {conditioning_tokens} tokens"
        )

        if new_tokens <= 0:
            # Nothing new to sample.
            return music_tokens

        # Conditioning tokens from the level above (presumably None at the top level; confirm in prior).
        music_tokens_conds = prior.get_music_tokens_conds(music_tokens, start, end)
        # Metadata for this window: offset, sample length and lyric tokens.
        metadata = prior.get_metadata(labels, start, self.total_length, offset)

        music_tokens_list = self.split_batch(previous_sampled_tokens, n_samples, max_batch_size)
        music_tokens_conds_list = self.split_batch(music_tokens_conds, n_samples, max_batch_size)
        metadata_list = self.split_batch(metadata, n_samples, max_batch_size)
        tokens = []
        iterator = tqdm(zip(music_tokens_list, music_tokens_conds_list, metadata_list), leave=False)
        for music_tokens_i, music_tokens_conds_i, metadata_i in iterator:
            # Label is "Primed" when the chunk has no previously sampled tokens, "Ancestral" otherwise.
            name = ["Ancestral", "Primed"][music_tokens_i.shape[1] == 0]
            iterator.set_description(
                f"[prior level {level}] {name} Sampling {sample_tokens} tokens out of"
                f" {self.total_length//prior.raw_to_tokens}",
                refresh=True,
            )
            tokens_i = prior.sample(
                n_samples=music_tokens_i.shape[0],
                music_tokens=music_tokens_i,
                music_tokens_conds=music_tokens_conds_i,
                metadata=metadata_i,
                **sampling_kwargs,
            )
            tokens.append(tokens_i)
        sampled_tokens = torch.cat(tokens, dim=0)

        # Keep only the newly generated tail and append it to the tokens already present at this level.
        music_tokens_new = sampled_tokens[:, -new_tokens:]
        music_tokens[level] = torch.cat([music_tokens[level], music_tokens_new], dim=1)
        return music_tokens

    # Sample a whole level, window by window (hop_length apart) or via one partial window.
    def sample_level(
        self, music_tokens, labels, offset, sampling_kwargs, level, total_length, hop_length, max_batch_size
    ):
        """
        Sample `total_length` tokens at `level`.

        When `total_length` reaches the prior's `n_ctx`, full windows starting at the positions produced by
        `get_starts(total_length, n_ctx, hop_length)` are sampled one after another; otherwise a single partial
        window is used.

        Returns:
            The updated `music_tokens`.
        """
        n_ctx = self.priors[level].n_ctx
        if total_length >= n_ctx:
            for start in get_starts(total_length, n_ctx, hop_length):
                music_tokens = self.sample_single_window(
                    music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size
                )
        else:
            music_tokens = self.sample_partial_window(
                music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size
            )
        return music_tokens

    # NOTE(review): the parameter list and body of `_sample` are missing from this copy (only the `def` line
    # survives). Other methods of this class call it as
    # `self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)`; restore the implementation.
    def _sample(
        # 添加文档字符串作为函数注释,描述生成音乐 tokens 的过程
            Generates music tokens based on the provided `labels`. Will start at the desired prior level and automatically
            upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use
            the VQ-VAE decoder to convert the music tokens to raw audio.

                labels (`List[torch.LongTensor]`) :
                    List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
                    lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
                    which are used to condition the generation.
                n_samples (`int`, *optional*, defaults to 1) :
                    Number of samples to be generated in parallel.
    def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]:
        """
        Sample music tokens starting from empty token sequences at every level, conditioned on `labels`.

        `sampling_kwargs` may contain `sample_levels` (defaults to every prior level); it is removed and the
        remaining keyword arguments are forwarded to `_sample`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed

        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")

        >>> lyrics = "Hey, are you awake? Can you talk to me?"
        >>> artist = "Zac Brown Band"
        >>> genre = "Country"
        >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics)
        >>> set_seed(0)
        >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)

        >>> with torch.no_grad():
        ...     model.decode(music_tokens)[:, :10].squeeze(-1)
        tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405,
            -0.0818, -0.0697]])
        ```
        """
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        # One empty (n_samples, 0) LongTensor per level, on the same device as the labels.
        music_tokens = [
            torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors))
        ]
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens

        """Generates a continuation of the previously generated tokens.

            music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
                A sequence of music tokens which will be used as context to continue the sampling process. Should have
                `self.levels` tensors, each corresponding to the generation at a certain level.
    def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
        """
        Continue sampling from the given `music_tokens`.

        `sampling_kwargs` may contain `sample_levels` (defaults to every prior level); it is removed and the
        remaining keyword arguments are forwarded to `_sample`.
        """
        all_levels = list(range(len(self.priors)))
        levels_to_sample = sampling_kwargs.pop("sample_levels", all_levels)
        return self._sample(music_tokens, labels, levels_to_sample, **sampling_kwargs)

        """Upsamples a sequence of music tokens using the prior at level `level`.

            music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
                A sequence of music tokens which will be used as context to continue the sampling process. Should have
                `self.levels` tensors, each corresponding to the generation at a certain level.
    def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
        """
        Upsample `music_tokens` with the lower-level priors.

        `sampling_kwargs` may contain `sample_levels` (defaults to every level except the last one); it is removed
        and the remaining keyword arguments are forwarded to `_sample`.
        """
        default_levels = list(range(len(self.priors) - 1))
        levels_to_sample = sampling_kwargs.pop("sample_levels", default_levels)
        upsampled = self._sample(music_tokens, labels, levels_to_sample, **sampling_kwargs)
        return upsampled
        """Generate a raw audio conditioned on the provided `raw_audio` which is used as conditioning at each of the
        generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are
        used as conditioning for each level, which means that no ancestral sampling is required.

            raw_audio (`List[torch.Tensor]` of length `n_samples` ) :
                A list of raw audio that will be used as conditioning information for each sample that will be
                generated.

这是一个装饰器函数,用于给 `primed_sample` 方法添加文档字符串。文档字符串描述了函数的作用、参数和返回值。

    def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]:
        """
        Generate music tokens conditioned on `raw_audio`.

        The audio is first encoded with the VQ-VAE into tokens for every prior level. Sampling then continues from
        those tokens with `_sample`. `sampling_kwargs` may contain `sample_levels` (defaults to every prior level);
        it is removed and the remaining keyword arguments are forwarded to `_sample`.

        Returns:
            The list of music token tensors produced by `_sample`.
        """
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        # Move the VQ-VAE to the device of `raw_audio` and cast it to float before encoding.
        self.vqvae.to(raw_audio.device).float()
        with torch.no_grad():
            # Encode levels 0..len(self.priors), chunking by the batch size of `raw_audio`.
            music_tokens = self.vqvae.encode(
                raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0]
            )
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens

返回生成的音乐 token 列表。


# 引入所需的库和模块
import json  # 导入处理 JSON 格式的模块
import os    # 导入操作系统相关功能的模块
import re    # 导入正则表达式模块
import unicodedata  # 导入 Unicode 数据库模块
from json.encoder import INFINITY  # 从 JSON 库中导入 INFINITY 常量
from typing import Any, Dict, List, Optional, Tuple, Union  # 导入类型提示相关的功能

import numpy as np  # 导入 NumPy 库,用于数值计算
import regex       # 导入 regex 库,支持更强大的正则表达式功能

# 从 tokenization_utils 模块导入 AddedToken 和 PreTrainedTokenizer 类
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
# 从 tokenization_utils_base 模块导入 BatchEncoding 类
from ...tokenization_utils_base import BatchEncoding
# 从 utils 模块导入 TensorType, is_flax_available, is_tf_available, is_torch_available, logging 等功能
from ...utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
# 从 utils.generic 模块导入 _is_jax 和 _is_numpy 函数
from ...utils.generic import _is_jax, _is_numpy

# 获取 logger 对象,用于记录日志
logger = logging.get_logger(__name__)

# Default file name of each vocabulary file argument.
VOCAB_FILES_NAMES = {
    "artists_file": "artists.json",  # artist -> id mapping
    "lyrics_file": "lyrics.json",  # lyric character -> id mapping
    "genres_file": "genres.json",  # genre -> id mapping
}

# Download locations of the pretrained vocabulary files.
# NOTE(review): the URLs are empty strings in this copy; confirm the intended locations.
PRETRAINED_VOCAB_FILES_MAP = {
    "artists_file": {
        "jukebox": "",
    },
    "genres_file": {
        "jukebox": "",
    },
    "lyrics_file": {
        "jukebox": "",
    },
}

# Maximum lyric input size of the pretrained tokenizer.
PRETRAINED_LYRIC_TOKENS_SIZES = {
    "jukebox": 512,
}

# JukeboxTokenizer: tokenizer for the artist, genre and lyrics conditioning inputs, built on PreTrainedTokenizer.
class JukeboxTokenizer(PreTrainedTokenizer):
    """
    Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs:
        - Artists: a unique id is associated with each artist from the provided dictionary.
        - Genres: a unique id is associated with each genre from the provided dictionary.
        - Lyrics: character-based tokenization. Must be initialized with the list of characters that are inside the
          vocabulary.

    ```python
    >>> from transformers import JukeboxTokenizer

    >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
    >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
    [tensor([[   0,    0,    0, 6785,  546,   41,   38,   30,   76,   46,   41,   49,
               40,   76,   44,   41,   27,   30]]), tensor([[  0,   0,   0, 145,   0]]), tensor([[  0,   0,   0, 145,   0]])]
    ```

    <Tip>

    If nothing is provided, the genres and the artist will either be selected randomly or set to None.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Only composing from various genres is supported.

    Args:
        artists_file (`str`):
            Path to the vocabulary file which contains a mapping between artists and ids. The default file supports
            both "v2" and "v3".
        genres_file (`str`):
            Path to the vocabulary file which contains a mapping between genres and ids.
        lyrics_file (`str`):
            Path to the vocabulary file which contains the accepted characters for the lyrics tokenization.
        version (`List[str]`, *optional*, defaults to `["v3", "v2", "v2"]`):
            List of the tokenizer versions. The `5b-lyrics` top-level prior model was trained using `v3` instead of
            `v2`.
        n_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres to use for composition.
        max_n_lyric_tokens (`int`, *optional*, defaults to 512):
            Maximum number of lyric tokens to keep.
        unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
    """

    # Names of the vocabulary files.
    vocab_files_names = VOCAB_FILES_NAMES
    # Download map of the pretrained vocabulary files.
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Maximum lyric input size of the pretrained tokenizer.
    max_lyric_input_size = PRETRAINED_LYRIC_TOKENS_SIZES
    # Names of the model inputs.
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        artists_file,
        genres_file,
        lyrics_file,
        version=None,
        max_n_lyric_tokens=512,
        n_genres=5,
        unk_token="<|endoftext|>",
        **kwargs,
    ):
        """
        Load the artist, genre and lyric vocabularies (JSON files) and build their inverse mappings.

        Args:
            artists_file (`str`): path to the JSON artist -> id mapping.
            genres_file (`str`): path to the JSON genre -> id mapping.
            lyrics_file (`str`): path to the JSON lyric character -> id mapping.
            version (`List[str]`, *optional*): tokenizer versions; `None` means `["v3", "v2", "v2"]`.
            max_n_lyric_tokens (`int`, *optional*, defaults to 512): maximum number of lyric tokens to keep.
            n_genres (`int`, *optional*, defaults to 5): maximum number of genres used for composition.
            unk_token (`str` or `AddedToken`, *optional*, defaults to `"<|endoftext|>"`): the unknown token.
            kwargs: forwarded to `PreTrainedTokenizer.__init__`.
        """
        # A `None` sentinel avoids sharing one mutable default list between instances.
        if version is None:
            version = ["v3", "v2", "v2"]
        # Wrap a plain-string unk token so the whitespace around it is not stripped.
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        self.version = version
        self.max_n_lyric_tokens = max_n_lyric_tokens
        self.n_genres = n_genres
        self._added_tokens_decoder = {0: unk_token}

        with open(artists_file, encoding="utf-8") as vocab_handle:
            self.artists_encoder = json.load(vocab_handle)

        with open(genres_file, encoding="utf-8") as vocab_handle:
            self.genres_encoder = json.load(vocab_handle)

        with open(lyrics_file, encoding="utf-8") as vocab_handle:
            self.lyrics_encoder = json.load(vocab_handle)

        # Negated character class: matches runs of characters outside the lyric vocabulary.
        oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
        # In v2 the lyric vocabulary had 80 entries; in v3 "+" was left out, giving 79. For a 79-entry vocabulary,
        # "+" is added to the characters accepted by the regex.
        if len(self.lyrics_encoder) == 79:
            oov = oov.replace(r"\-'", r"\-+'")

        self.out_of_vocab = regex.compile(oov)
        # Inverse (id -> token) mappings.
        self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
        self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
        self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
        super().__init__(
            unk_token=unk_token,
            n_genres=n_genres,
            version=version,
            max_n_lyric_tokens=max_n_lyric_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Total vocabulary size: the sum of the artist, genre and lyric encoder sizes."""
        return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)

    def get_vocab(self):
        """Return the artist, genre and lyric encoders keyed by their attribute names."""
        return {
            "artists_encoder": self.artists_encoder,
            "genres_encoder": self.genres_encoder,
            "lyrics_encoder": self.lyrics_encoder,
        }

    def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
        """Converts the artist, genre and lyrics tokens to their index using the vocabulary.
        The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to
        the lyrics token sequence.
        # 将艺术家标签转换为它们在编码器中的索引
        artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
        # 将流派标签转换为它们在编码器中的索引,并在需要时添加填充标记
        for genres in range(len(list_genres)):
            list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
            list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))

        # 将歌词字符转换为它们在编码器中的索引,每个歌词位置(如 total_length、offset、duration)提供相应的歌词
        lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
        return artists_id, list_genres, lyric_ids
    # 将字符串 lyrics 转换为标记序列(字符串),使用指定的标记器。
    # 如果是基于词汇的,按单词拆分;如果是基于子词(如BPE/SentencePieces/WordPieces),则按子词拆分。
    # 对于基于字符的词汇表,仅将歌词拆分成字符。
    def _tokenize(self, lyrics):
        Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary.
        # 仅对歌词进行拆分,如果是基于字符的词汇表,这很容易处理
        return list(lyrics)

    # Tokenize artist, genre and lyrics in one call.
    def tokenize(self, artist, genre, lyrics, **kwargs):
        """
        Converts three strings in a 3 sequence of tokens using the tokenizer.

        The inputs are first prepared with `prepare_for_tokenization`; only the lyrics are then split with
        `_tokenize`. Extra keyword arguments are accepted and ignored.

        Returns:
            `(artist, genre, lyrics_tokens)`.
        """
        artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
        lyrics = self._tokenize(lyrics)
        return artist, genre, lyrics

    # Normalize artists, genres and lyrics before tokenization.
    def prepare_for_tokenization(
        self, artists: List[str], genres: List[str], lyrics: str, is_split_into_words: bool = False
    ) -> Tuple[List[str], List[List[str]], Tuple[str, List, List]]:
        """
        Performs any necessary transformations before tokenization.

        Args:
            artists (`List[str]`):
                One artist name per entry of `self.version`. "v3" entries are lower-cased; other entries are
                normalized with `_normalize` and suffixed with ".v2". Modified in place.
            genres (`List[str]`):
                One genre string per entry of `self.version`, replaced in place by a list of genre names ("v3":
                lower-cased; otherwise split on "_", normalized and suffixed with ".v2").
            lyrics (`str`):
                The lyrics to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Accepted for interface compatibility; not used here.

        Returns:
            `(artists, genres, lyrics)` where `lyrics` is the tuple `(cleaned_lyrics, [], [])`.
        """
        for idx, version in enumerate(self.version):
            if version == "v3":
                artists[idx] = artists[idx].lower()
                genres[idx] = [genres[idx].lower()]
            else:
                artists[idx] = self._normalize(artists[idx]) + ".v2"
                # The split on "_" handles the full dictionary with combined genres.
                genres[idx] = [self._normalize(genre) + ".v2" for genre in genres[idx].split("_")]

        if self.version[0] == "v2":
            # v2 uses a fixed character vocabulary (id 0 is reserved for "<unk>").
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
            vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
            self.vocab = {character: index + 1 for index, character in enumerate(vocab)}
            self.vocab["<unk>"] = 0
            self.n_vocab = len(vocab) + 1
            self.lyrics_encoder = self.vocab
            self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
            self.lyrics_decoder[0] = ""
        else:
            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")

        lyrics = self._run_strip_accents(lyrics)
        lyrics = lyrics.replace("\\", "\n")
        # Remove out-of-vocabulary characters; the two empty lists are placeholders for the other versions.
        lyrics = self.out_of_vocab.sub("", lyrics), [], []
        return artists, genres, lyrics

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # 使用 unicodedata 库规范化文本中的 Unicode 字符,去除重音符号
        text = unicodedata.normalize("NFD", text)
        output = []
        # 遍历文本中的每个字符
        for char in text:
            # 获取字符的 Unicode 分类
            cat = unicodedata.category(char)
            # 如果字符的分类为 "Mn"(非重音符号),跳过该字符
            if cat == "Mn":
            # 将符合条件的字符添加到输出列表中
        # 将输出列表中的字符连接成字符串并返回
        return "".join(output)
    # 定义一个方法,用于规范化输入的文本。这个过程适用于音乐流派和艺术家名称。

    def _normalize(self, text: str) -> str:
        Normalizes the input text. This process is for the genres and the artist

            text (`str`):
                Artist or Genre string to normalize
        # 定义可接受的字符集,包括小写字母、大写字母、数字和点号
        accepted = (
            [chr(i) for i in range(ord("a"), ord("z") + 1)]
            + [chr(i) for i in range(ord("A"), ord("Z") + 1)]
            + [chr(i) for i in range(ord("0"), ord("9") + 1)]
            + ["."]
        accepted = frozenset(accepted)  # 将字符集转换为不可变集合以提高性能
        pattern = re.compile(r"_+")  # 编译用于匹配多个下划线的正则表达式模式
        # 将文本转换为小写,并替换不在接受字符集中的字符为下划线
        text = "".join([c if c in accepted else "_" for c in text.lower()])
        text = pattern.sub("_", text).strip("_")  # 将多个连续的下划线替换为单个下划线,并去除首尾的下划线
        return text  # 返回规范化后的文本字符串

    # Join lyric tokens back into a single string.
    def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
        """Concatenate the lyric tokens, separating consecutive tokens with a single space."""
        separator = " "
        return separator.join(lyrics)

    # Convert inputs to framework tensors, optionally adding a batch axis.
    def convert_to_tensors(
        self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
    ):
        """
        Convert the inner content to tensors.

        Args:
            inputs:
                The content to convert.
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                unset, no modification is done.
            prepend_batch_axis (`bool`, *optional*, defaults to `False`):
                Whether or not to add the batch dimension during the conversion.

        Returns:
            `inputs` converted to the requested tensor type (wrapped in a one-element list first when
            `prepend_batch_axis` is set).

        Raises:
            ImportError: if the requested framework is not installed.
            ValueError: if the conversion itself fails.
        """
        # As documented, an unset tensor type leaves the inputs untouched (`TensorType(None)` would raise).
        if tensor_type is None:
            return inputs

        # Convert to TensorType
        if not isinstance(tensor_type, TensorType):
            tensor_type = TensorType(tensor_type)

        # Get a function reference for the correct framework
        if tensor_type == TensorType.TENSORFLOW:
            if not is_tf_available():
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            import tensorflow as tf

            as_tensor = tf.constant
            is_tensor = tf.is_tensor
        elif tensor_type == TensorType.PYTORCH:
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch

            as_tensor = torch.tensor
            is_tensor = torch.is_tensor
        elif tensor_type == TensorType.JAX:
            if not is_flax_available():
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            import jax.numpy as jnp  # noqa: F811

            as_tensor = jnp.array
            is_tensor = _is_jax
        else:
            # Fall back to NumPy arrays.
            as_tensor = np.asarray
            is_tensor = _is_numpy

        # Do the tensor conversion in batch
        try:
            if prepend_batch_axis:
                inputs = [inputs]

            if not is_tensor(inputs):
                inputs = as_tensor(inputs)
        except Exception as err:
            # Typically ragged (unequal-length) inputs that cannot form a single tensor.
            raise ValueError(
                "Unable to create tensor, you should probably activate truncation and/or padding "
                "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
            ) from err

        return inputs
    def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
        """Convert the raw string to a list of token ids

        Args:
            artist (`str`):
                Name of the artist.
            genres (`str`):
                List of genres that will be mixed to condition the audio
            lyrics (`str`, *optional*, defaults to `""`):
                Lyrics used to condition the generation
            return_tensors (`str`, *optional*, defaults to `"pt"`):
                Passed as `tensor_type` to `convert_to_tensors` for every per-version sequence.

        Returns:
            [`BatchEncoding`] with `input_ids` (one converted sequence per entry of `self.version`) and
            `attention_masks` (a list of `-INFINITY` values, as long as the last entry of the full lyric tokens).
        """
        # Fixed three-zero prefix placed before the artist id of every sequence.
        prefix_ids = [0, 0, 0]
        # The conditioning inputs are repeated once per entry of `self.version`.
        num_versions = len(self.version)
        artist = [artist] * num_versions
        genres = [genres] * num_versions

        artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
        artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)

        attention_masks = [-INFINITY] * len(full_tokens[-1])
        # One sequence per version: prefix + artist id + genre ids + lyric ids, converted with a batch axis of 1.
        input_ids = [
            self.convert_to_tensors(
                [prefix_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
            )
            for i in range(num_versions)
        ]
        return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Saves the tokenizer's artists, genres and lyrics encoders as JSON files in `save_directory`.

        Args:
            save_directory (`str`):
                Path to an existing directory where the files are saved. The directory is NOT created: if it is not
                a directory, an error is logged and nothing is written.
            filename_prefix (`Optional[str]`, *optional*):
                A prefix to add to the names of the files saved by the tokenizer.

        Returns:
            `Tuple[str]`: The paths of the saved artists, genres and lyrics files (in that order), or `None` when
            `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            # Stop here instead of attempting to write into a non-directory.
            return

        prefix = filename_prefix + "-" if filename_prefix else ""
        saved_files = []
        for file_key, encoder in (
            ("artists_file", self.artists_encoder),
            ("genres_file", self.genres_encoder),
            ("lyrics_file", self.lyrics_encoder),
        ):
            path = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES[file_key])
            with open(path, "w", encoding="utf-8") as f:
                f.write(json.dumps(encoder, ensure_ascii=False))
            saved_files.append(path)

        return tuple(saved_files)
    def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
        """
        Converts an index (integer) in a token (str) using the vocab.

        Args:
            artists_index (`int`):
                Index of the artist in its corresponding dictionary.
            genres_index (`Union[List[int], int]`):
               Index of the genre in its corresponding dictionary. Can be a single index or a list of indices.
            lyric_index (`List[int]`):
                List of character indices, each corresponding to a character.

        Returns:
            artist (`Optional[str]`):
                Decoded artist name corresponding to artists_index.
            genres (`List[Optional[str]]`):
                List of decoded genre names corresponding to genres_index.
            lyrics (`List[Optional[str]]`):
                List of decoded characters corresponding to lyric_index.

        Unknown indices decode to `None` because every lookup uses `dict.get`.
        """
        artist = self.artists_decoder.get(artists_index)
        # Accept a bare genre index, as documented, not only an iterable of indices.
        if isinstance(genres_index, int):
            genres_index = [genres_index]
        genres = [self.genres_decoder.get(genre) for genre in genres_index]
        lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
        return artist, genres, lyrics


# 版权声明和许可信息
# 该模块受 Apache License, Version 2.0 许可,详情请访问 http://www.apache.org/licenses/LICENSE-2.0

# 导入类型检查模块
from typing import TYPE_CHECKING

# 导入自定义异常和模块惰性加载工具函数
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available

# Import structure handed to `_LazyModule` below: submodule name -> names it exports.
# NOTE(review): this copy is damaged — the contents and closing `]` of the "configuration_jukebox" entry and the
# closing `}` of the dict are missing, so the statement does not parse as shown.
_import_structure = {
    "configuration_jukebox": [
    "tokenization_jukebox": ["JukeboxTokenizer"],

# Torch-only entries: `OptionalDependencyNotAvailable` is raised when Torch is not installed.
# NOTE(review): the enclosing `try:` header is missing from this copy.
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # NOTE(review): as shown, this assignment sits under `except`, i.e. it would run when Torch is NOT available.
    # The original comment says it should run only when Torch IS available, so an `except: pass` / `else:` pair
    # appears to have been lost here — confirm against upstream. The list contents are missing as well.
    _import_structure["modeling_jukebox"] = [

# Static type-checking branch (presumably `if TYPE_CHECKING:`, whose header line is missing): import names eagerly.
    # Configuration and tokenizer classes.
    from .configuration_jukebox import (
    from .tokenization_jukebox import JukeboxTokenizer

    # Torch-only modeling classes; the same availability check as above.
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # NOTE(review): same damage as above — the modeling import should only happen when Torch is available
        # (an `else:` branch); the `try:`/`else:` headers and the imported names are missing in this copy.
        from .modeling_jukebox import (

# Runtime branch (presumably the `else:` of the type-checking check; header line missing).
    import sys

    # Replace this module in `sys.modules` with a `_LazyModule` built from the module name, its file path,
    # `_import_structure` and the module spec.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)


# coding=utf-8
# 设置文件编码为UTF-8,确保支持多语言字符集

# 版权声明及许可证信息
# 版权所有 2023 Microsoft Research 和 HuggingFace Inc. 团队。保留所有权利。
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证的规定,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则按“原样”分发软件
# 没有任何明示或暗示的保证或条件。
# 有关详细信息,请参阅许可证。

""" KOSMOS-2 模型配置"""

import os
from typing import Union

# 导入预训练配置类
from ...configuration_utils import PretrainedConfig
# 导入日志记录工具
from ...utils import logging

# Module-level logger.
logger = logging.get_logger(__name__)

# Map from pretrained checkpoint name to the download link of its configuration file.
# NOTE(review): this copy is damaged — the assignment line, the URL value and the closing brackets are missing,
# so the fragment below does not parse as shown.
    "microsoft/kosmos-2-patch14-224": (
    # See all KOSMOS-2 models on the model hub (link missing from this copy).

class Kosmos2TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Kosmos2TextModel`]. It is used to instantiate a
    KOSMOS-2 text decoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the text decoder of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 65037):
            Vocabulary size of the text model.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Maximum number of position embeddings.
        embed_dim (`int`, *optional*, defaults to 2048):
            Embedding (hidden) dimension; also readable as `hidden_size`.
        layers (`int`, *optional*, defaults to 24):
            Number of layers; also readable as `num_hidden_layers`.
        ffn_dim (`int`, *optional*, defaults to 8192):
            Dimensionality of the feed-forward layers.
        attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads; also readable as `num_attention_heads`.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function.
        dropout (`float`, *optional*, defaults to 0.1):
            General dropout probability.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability of the attention modules.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied after the activation function.
        layerdrop (`float`, *optional*, defaults to 0.0):
            Layer-level dropout (LayerDrop) probability.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        init_std (`float`, *optional*, defaults to 0.02):
            Standard deviation used for weight initialization.
        scale_embedding (`bool`, *optional*, defaults to `True`):
            Whether to scale the embeddings.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions.
        pad_token_id (`int`, *optional*, defaults to 1):
            Token id used for padding.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning-of-sequence token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id.
    """

    model_type = "kosmos_2_text_model"
    # Cached key/value states are not compared/used as outputs at inference time.
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names (keys) resolved to this config's own attribute names (values).
    attribute_map = {
        "num_attention_heads": "attention_heads",
        "hidden_size": "embed_dim",
        "num_hidden_layers": "layers",
    }

    def __init__(
        self,
        vocab_size=65037,
        max_position_embeddings=2048,
        embed_dim=2048,
        layers=24,
        ffn_dim=8192,
        attention_heads=32,
        activation_function="gelu",
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        layerdrop=0.0,
        layer_norm_eps=1e-5,
        init_std=0.02,
        scale_embedding=True,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        # Special-token ids and any remaining kwargs are handled by the `PretrainedConfig` base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.embed_dim = embed_dim
        self.layers = layers
        self.ffn_dim = ffn_dim
        self.attention_heads = attention_heads
        self.activation_function = activation_function
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.init_std = init_std
        self.scale_embedding = scale_embedding
        self.use_cache = use_cache

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the text config, extracting `text_config` when the checkpoint holds a full `Kosmos2Config`."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the text config dict if we are loading from Kosmos2Config
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["text_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
# Configuration for the KOSMOS-2 vision encoder (`Kosmos2VisionModel`); see the class docstring.
class Kosmos2VisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Kosmos2VisionModel`]. It is used to instantiate a
    KOSMOS-2 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1; used internally for initialization).
    """

    model_type = "kosmos_2_vision_model"

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        # Remaining kwargs are handled by the `PretrainedConfig` base class.
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the vision config, extracting `vision_config` when the checkpoint holds a full `Kosmos2Config`."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from Kosmos2Config
        if config_dict.get("model_type") == "kosmos-2":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
class Kosmos2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Kosmos2Model`]. It is used to instantiate a
    KOSMOS-2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the KOSMOS-2
    [microsoft/kosmos-2-patch14-224](https://huggingface.co/microsoft/kosmos-2-patch14-224) architecture.

    Args:
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Kosmos2TextConfig`].
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`Kosmos2VisionConfig`].
        latent_query_num (`int`, *optional*, defaults to 64):
            The number of latent query tokens that represent the image features used in the text decoder component.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import Kosmos2Config, Kosmos2Model

    >>> # Initializing a Kosmos-2 kosmos-2-patch14-224 style configuration
    >>> configuration = Kosmos2Config()

    >>> # Initializing a model (with random weights) from the kosmos-2-patch14-224 style configuration
    >>> model = Kosmos2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "kosmos-2"
    # This config is composed of the text and vision sub-configs.
    is_composition = True

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        latent_query_num=64,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Missing sub-configs fall back to their defaults.
        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `Kosmos2TextConfig` with default values.")

        if vision_config is None:
            vision_config = {}
            logger.info("`vision_config` is `None`. Initializing the `Kosmos2VisionConfig` with default values.")

        self.text_config = Kosmos2TextConfig(**text_config)
        self.vision_config = Kosmos2VisionConfig(**vision_config)

        self.latent_query_num = latent_query_num


import argparse  # 导入命令行参数解析模块

from fairseq.checkpoint_utils import load_checkpoint_to_cpu  # 从fairseq库中导入加载checkpoint到CPU的函数

from transformers import Kosmos2Config, Kosmos2ForConditionalGeneration  # 从transformers库中导入Kosmos2Config和Kosmos2ForConditionalGeneration类

    # Substring renames applied by `rename_key`: original KOSMOS-2 (fairseq) parameter-name fragment -> Transformers
    # equivalent. Order matters: `rename_key` applies every matching pair in turn, so specific prefixes
    # (e.g. "gpt_model.decoder.output_projection") must come before general ones ("gpt_model.decoder").
    # NOTE(review): the `KEYS_TO_MODIFY_MAPPING = {` header and the closing `}` are missing from this copy.
    "gpt_model.decoder.output_projection": "text_model.lm_head",
    "gpt_model.decoder": "text_model.model",
    "img_connector": "image_to_text_projection",
    "img_model.visual.class_embedding": "vision_model.model.embeddings.class_embedding",
    "img_model.visual.positional_embedding": "vision_model.model.embeddings.position_embedding.weight",
    "img_model.visual.conv1": "vision_model.model.embeddings.patch_embedding",
    "img_model.visual": "vision_model.model",
    "ln_pre": "pre_layrnorm",  # (sic) presumably matches the target attribute name — do not "fix" without checking
    "ln_post": "post_layernorm",
    "transformer.resblocks": "encoder.layers",
    "ts_attn": "self_attn",
    "ln_1": "layer_norm1",
    "ln_2": "layer_norm2",
    "c_fc": "fc1",
    "c_proj": "fc2",

    # Keys dropped during conversion (`convert_kosmos2_checkpoint_to_pytorch` skips keys in `KEYS_TO_IGNORE`).
    # NOTE(review): the `KEYS_TO_IGNORE` header, its entries and the closing bracket are missing from this copy;
    # the two comments below describe entries that are not shown.
    # A buffer that the original code only uses to send weights to the desired device.
    # A weight that is never used in the forward pass of the original KOSMOS-2 code.

def rename_key(key, mapping=None):
    """
    Translate an original KOSMOS-2 parameter name into the Transformers naming scheme.

    Every `old -> new` pair of `mapping` whose `old` substring occurs in the (progressively rewritten) key is
    replaced, in the mapping's iteration order. More specific patterns must therefore precede more general ones.

    Args:
        key (`str`): The original parameter name.
        mapping (`dict`, *optional*): Substring replacements to apply. Defaults to `KEYS_TO_MODIFY_MAPPING`.

    Returns:
        `str`: The renamed key (unchanged if no pattern matches).
    """
    if mapping is None:
        mapping = KEYS_TO_MODIFY_MAPPING
    for key_to_modify, new_key in mapping.items():
        if key_to_modify in key:
            key = key.replace(key_to_modify, new_key)
    return key

def convert_kosmos2_checkpoint_to_pytorch(checkpoint_path, pytorch_dump_folder_path):
    """
    Convert an original KOSMOS-2 checkpoint into a `Kosmos2ForConditionalGeneration` model and save it.

    Args:
        checkpoint_path: Path of the original checkpoint, loaded on CPU with `load_checkpoint_to_cpu`.
        pytorch_dump_folder_path: Output directory passed to `save_pretrained`.
    """
    original_weights = load_checkpoint_to_cpu(checkpoint_path)["model"]

    config = Kosmos2Config()
    # Set to match the results given by the original demo.
    config.text_config.no_repeat_ngram_size = 3
    model = Kosmos2ForConditionalGeneration(config)

    # Rename every kept key; keys listed in KEYS_TO_IGNORE are dropped.
    renamed_weights = {
        rename_key(name): tensor for name, tensor in original_weights.items() if name not in KEYS_TO_IGNORE
    }

    # strict=True makes any missing or unexpected key an error.
    model.load_state_dict(renamed_weights, strict=True)
    model.save_pretrained(pytorch_dump_folder_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--kosmos2_checkpoint_path", default=None, type=str, required=True, help="Path to the official PyTorch dump."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_kosmos2_checkpoint_to_pytorch(args.kosmos2_checkpoint_path, args.pytorch_dump_folder_path)


# 设置文件编码为 UTF-8
# 版权声明,声明版权归 Microsoft Research 和 HuggingFace Inc. 团队所有
# 根据 Apache License, Version 2.0 许可,除非符合许可要求,否则不得使用此文件
# 您可以在以下网址获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0
# 除非适用法律要求或书面同意,否则本软件按"原样"分发,不附带任何明示或暗示的担保或条件
# 请参阅许可证以了解特定语言的权限和限制

""" PyTorch KOSMOS-2 model."""

# 导入必要的库和模块
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

# 导入与激活函数相关的模块
from ...activations import ACT2FN
# 导入不同类型的模型输出
from ...modeling_outputs import (
# 导入预训练模型的基类
from ...modeling_utils import PreTrainedModel
# 导入实用工具函数
from ...utils import (
# 导入 KOSMOS-2 的配置类
from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig

# Module-level logger.
logger = logging.get_logger(__name__)

# Configuration class referenced by the generated documentation.
_CONFIG_FOR_DOC = Kosmos2Config

# List of pretrained model archives.
# NOTE(review): the assignment line, its entries and the closing `]` are missing from this copy.
    # See all KOSMOS-2 models on the model hub (link missing from this copy).

# Expand a padding mask into an additive attention mask.
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Args:
        mask (`torch.Tensor`): 2-D mask where 1 marks positions to attend to and 0 marks masked positions.
        dtype (`torch.dtype`): dtype of the returned mask.
        tgt_len (`int`, *optional*): Target length; defaults to the source length.

    Returns:
        `torch.Tensor` of shape `[bsz, 1, tgt_len, src_len]`: 0 where `mask` is 1 and `torch.finfo(dtype).min`
        where `mask` is 0.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # 1 -> 0 (keep), 0 -> 1 (masked); masked entries are then replaced by the most negative value of `dtype`.
    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

# Build the additive causal (unidirectional) self-attention mask.
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make the causal mask used for unidirectional (decoder) self-attention.

    Args:
        input_ids_shape (`torch.Size`): `(bsz, tgt_len)`.
        dtype (`torch.dtype`): dtype of the returned mask.
        device (`torch.device`): Device on which the mask is created.
        past_key_values_length (`int`, *optional*, defaults to 0):
            Number of cached positions; they are prepended as always-visible (0) columns.

    Returns:
        `torch.Tensor` of shape `[bsz, 1, tgt_len, tgt_len + past_key_values_length]`: 0 where attention is allowed
        (key position <= query position), `torch.finfo(dtype).min` elsewhere.
    """
    bsz, tgt_len = input_ids_shape
    # Build directly in `dtype` so the fill value is representable (e.g. float64's minimum does not fit in the
    # default float32).
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    # Zero the lower triangle including the diagonal: column j is visible from row i when j <= i.
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids (`torch.Tensor`): Token ids; positions are counted along dim 1.
        padding_idx (`int`): Id of the padding token; padded positions receive this value as their position id.
        past_key_values_length (`int`, *optional*, defaults to 0):
            Offset added to the positions of non-padding tokens (for incremental decoding).

    Returns:
        `torch.Tensor`: int64 position ids with the same shape as `input_ids`.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()  # 1 for real tokens, 0 for padding
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

        config ([`Kosmos2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.

        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.



@dataclass
class Kosmos2ModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed key/value states, of shape `(config.n_layers, 2, batch_size, num_heads, sequence_length,
            embed_size_per_head)`, that can be used to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden states of each layer (plus the embedding output if the model has an embedding layer), each of
            shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Self-attention weights of each layer, each of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Hidden states of the image embeddings.
        projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention weights given by `Kosmos2ImageToTextProjection`.
        vision_model_output (`BaseModelOutputWithPooling`, *optional*):
            The output of the [`Kosmos2VisionModel`].
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_embeds: Optional[torch.FloatTensor] = None
    projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        # Nested model outputs are converted with their own `to_tuple()` rather than returned as objects.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Model output class for `Kosmos2ForConditionalGeneration`.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
            the weighted average in the self-attention heads.
        vision_model_output (`BaseModelOutputWithPooling`, *optional*):
            The output of the [`Kosmos2VisionModel`].
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_embeds: Optional[torch.FloatTensor] = None
    projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        # Nested model outputs are converted with their own `to_tuple()` rather than returned as objects.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2
class Kosmos2VisionEmbeddings(nn.Module):
    """
    Embed images as a learned class embedding followed by one embedding per non-overlapping patch, plus learned
    absolute position embeddings.
    """

    def __init__(self, config: Kosmos2VisionConfig):
        # Must run before any parameter/submodule is assigned on an nn.Module.
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # kernel_size == stride == patch_size: every patch is embedded independently (no overlap).
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the class embedding
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        # Cast the input to the conv weight dtype (e.g. for half-precision weights).
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # (batch, embed_dim, grid_h, grid_w)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)  # (batch, num_patches, embed_dim)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings

# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Kosmos2Vision
class Kosmos2VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: "
                f"{self.num_heads})."
            )
        # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Self-attention over `hidden_states` of shape `(batch, seq_len, embed_dim)`.

        Args:
            hidden_states: input sequence, `(batch, seq_len, embed_dim)`.
            attention_mask: optional additive mask, `(batch, 1, seq_len, seq_len)`.
            causal_attention_mask: optional additive mask, `(batch, 1, seq_len, seq_len)`, applied first.
            output_attentions: when True, also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` is `(batch, seq_len, embed_dim)`; `attn_weights` is
            `(batch, num_heads, seq_len, seq_len)` when `output_attentions` is True, else `None`.

        Raises:
            ValueError: if a mask or an intermediate tensor has an unexpected shape.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()

        # Project and fold heads into the batch axis: (bsz * num_heads, seq_len, head_dim).
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Apply the causal mask first, then the padding mask; both are additive.
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # Reshape twice and reuse the result so the returned weights stay in the autograd graph.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (bsz * num_heads, tgt_len, head_dim) -> (bsz, tgt_len, embed_dim)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
class Kosmos2VisionMLP(nn.Module):
    """Feed-forward block `fc2(act(fc1(x)))`: hidden_size -> intermediate_size -> hidden_size."""

    def __init__(self, config):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Kosmos2Vision
class Kosmos2VisionEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer layer: self-attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Kosmos2VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Kosmos2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): mask indicating the causal nature of attention
            output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is True.
        """
        # Self-attention sub-block.
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # MLP sub-block.
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

# Adapted from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Kosmos2Vision
class Kosmos2VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Kosmos2VisionEncoderLayer`].

    Args:
        config: Kosmos2VisionConfig
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, "BaseModelOutput"]:
        """
        Run every encoder layer in order.

        Args:
            inputs_embeds: input sequence, `(batch, seq_len, hidden_size)`.
            attention_mask / causal_attention_mask: optional additive masks passed to every layer.
            output_attentions / output_hidden_states / return_dict: default to the matching config values.

        Returns:
            A `BaseModelOutput`, or a tuple of its non-`None` fields when `return_dict` is False. When requested,
            `hidden_states` holds the input of every layer plus the final output.
        """
        # Local import keeps this block self-contained.
        from ...modeling_outputs import BaseModelOutput

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # `_gradient_checkpointing_func` is expected to be installed when gradient checkpointing is
                # enabled on the owning model — TODO confirm for the Transformers version in use.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
# Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward`
class Kosmos2VisionTransformer(nn.Module):
    """Vision tower: patch embeddings -> pre-LayerNorm -> encoder -> post-LayerNorm on the class token."""

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Kosmos2VisionEmbeddings(config)
        # Attribute name `pre_layrnorm` (sic) is kept as-is; renaming it would change state-dict keys.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = Kosmos2VisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Unset flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool by taking position 0 (the class token prepended by the embeddings).
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

# Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids`
class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__
    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        # Two extra table rows are reserved on top of `num_positions`.
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights
    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """(Re)build the `weights` buffer with `num_embeddings` rows, keeping the old buffer's dtype and device."""
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        # Non-persistent: the table is recomputed, never loaded from or saved to a state dict.
        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".

        Returns:
            A `(num_embeddings, embedding_dim)` tensor in the default dtype: the first half of each row holds sines,
            the second half cosines; an odd `embedding_dim` gets a trailing zero column, and the `padding_idx` row is
            zeroed.
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        """
        Look up positional embeddings for a `(batch, seq_len)` input.

        Position ids come from `position_ids` when given; otherwise they are derived from `input_ids` (padding
        stays padding) or, failing that, generated sequentially for `inputs_embeds`. The table grows on demand.

        Returns:
            A detached `(batch, seq_len, embedding_dim)` tensor.
        """
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
            if position_ids is None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                ).to(input_ids.device)
        else:
            bsz, seq_len = inputs_embeds.size()[:-1]
            if position_ids is None:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds
    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        # Ids start right after `padding_idx`, shifted by the number of cached positions.
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
class KosmosTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`.
    def __init__(
        self,
        config,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        add_inner_attn_layernorm: bool = False,
        bias: bool = True,
    ):
        """
        Args:
            config: model config; only `config.layer_norm_eps` is read (for the optional inner LayerNorm).
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of attention heads.
            dropout: dropout probability applied to the attention probabilities.
            is_decoder: when True, `forward` returns the current `(key, value)` states as the cache.
            add_inner_attn_layernorm: apply a LayerNorm to the merged heads before `out_proj`.
            bias: whether the q/k/v/out projections use a bias.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Optional LayerNorm over the merged head outputs (the difference from BartAttention).
        self.inner_attn_ln = None
        if add_inner_attn_layernorm:
            self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def _shape(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim)
        # (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Cross-attention when `encoder_hidden_states` is given, self-attention otherwise.

        Args:
            hidden_states: query source, `(batch, seq_len, embed_dim)`.
            encoder_hidden_states: optional key/value source for cross-attention.
            past_key_value: optional cached `(key_states, value_states)`, each `(batch, num_heads, len, head_dim)`.
            attention_mask: optional additive mask of shape `(batch, 1, seq_len, src_len)`.
            layer_head_mask: optional per-head multiplier with `num_heads` elements.
            output_attentions: accepted for interface compatibility; the weights are always returned.

        Returns:
            `(attn_output, attn_weights, past_key_value)`. `past_key_value` is the current `(key, value)` pair when
            `is_decoder` is True, otherwise whatever was passed in.

        Raises:
            ValueError: if `attention_mask` has an unexpected shape.
        """
        is_cross_attention = encoder_hidden_states is not None
        batch_size, seq_length = hidden_states.shape[:2]

        # Keys/values come from the encoder for cross-attention, from the input itself otherwise.
        current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        # Reuse cached cross-attention k/v only if their length matches the given encoder states
        # (this keeps prefix tuning working).
        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        else:
            key_states = self._shape(self.k_proj(current_states))
            value_states = self._shape(self.v_proj(current_states))
            if past_key_value is not None and not is_cross_attention:
                # Self-attention with a cache: append the new positions along the sequence axis.
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        query_states = self._shape(self.q_proj(hidden_states) * self.scaling)
        attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2))

        if self.is_decoder:
            # Hand the full key/value states back so the next step can reuse them.
            past_key_value = (key_states, value_states)

        src_len = key_states.size(2)

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, seq_length, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, seq_length, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            # attn_weights is (B, H, T, S); reshape the per-head mask to (1, H, 1, 1) so it scales whole heads.
            # A bare (H,) tensor would broadcast against the source axis instead.
            attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1)

        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # (B, H, T, S) @ (B, H, S, D) -> (B, H, T, D) -> (B, T, H * D)
        context_states = torch.matmul(attn_weights, value_states)
        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)

        if self.inner_attn_ln is not None:
            context_states = self.inner_attn_ln(context_states)

        attn_output = self.out_proj(context_states)

        return attn_output, attn_weights, past_key_value


class Kosmos2TextFFN(nn.Module):
    """Feed-forward block: fc1 -> activation -> dropout -> LayerNorm -> fc2 -> dropout."""

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()

        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim)

        # Normalises the intermediate (ffn_dim-wide) activations.
        self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states
# Decoder block of the Kosmos-2 text model: pre-LayerNorm self-attention, optional cross-attention and FFN,
# each wrapped in a residual connection.
class Kosmos2TextBlock(nn.Module):
    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.embed_dim = config.embed_dim

        self.self_attn = KosmosTextAttention(
            config,
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            add_inner_attn_layernorm=True,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        # Cross-attention layers only exist when the config asks for them.
        if config.add_cross_attention:
            self.encoder_attn = KosmosTextAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                add_inner_attn_layernorm=False,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.ffn = Kosmos2TextFFN(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run one decoder block.

        `past_key_value` holds the self-attention cache in entries 0-1 and, when cross-attention is used, the
        cross-attention cache in the last two entries.

        Returns:
            `(hidden_states,)`, followed by `(self_attn_weights, cross_attn_weights)` if `output_attentions`, followed
            by the updated cache tuple if `use_cache`.

        Raises:
            ValueError: if `encoder_hidden_states` is given but the block has no cross-attention layers.
        """
        residual = hidden_states

        # Self Attention: the cached self-attention key/values are the first two cache entries.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        hidden_states = self.self_attn_layer_norm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            if not hasattr(self, "encoder_attn"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            residual = hidden_states

            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # The cross-attention cache is stored in the last two cache entries.
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

            # Append the cross-attention cache after the self-attention cache.
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states

        hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class Kosmos2TextTransformer(nn.Module):
    """
    Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].

    Args:
        config: Kosmos2TextConfig

    NOTE(review): no `forward` method is visible for this class here, yet `Kosmos2TextModel.forward` calls
    `self.model(...)` on it — confirm `forward` is defined before relying on this class.
    """

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop

        # Token embeddings are scaled by sqrt(embed_dim) only when `config.scale_embedding` is set.
        self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)

        self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
            num_positions=config.max_position_embeddings,
            embedding_dim=config.embed_dim,
            padding_idx=config.pad_token_id,
        )

        self.layers = nn.ModuleList([Kosmos2TextBlock(config) for _ in range(config.layers)])
        self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)

        self.gradient_checkpointing = False

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine a causal mask (for multi-token inputs) with the expanded padding mask into one additive mask."""
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward_embedding(
        self,
        input_ids,
        inputs_embeds: torch.Tensor = None,
        image_embeds: torch.Tensor = None,
        img_input_mask: torch.Tensor = None,
        past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        """
        Build decoder input states: token embeddings with image embeddings spliced in, scaled, plus positions.

        `inputs_embeds`, if given, should not already be multiplied by `embed_scale`.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if image_embeds is not None:
            # Overwrite the positions flagged by `img_input_mask` with the image embeddings. This is an in-place
            # write, so a caller-supplied `inputs_embeds` tensor is modified.
            inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
                -1, image_embeds.size(-1)
            )

        inputs_embeds = inputs_embeds * self.embed_scale

        positions = self.embed_positions(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            position_ids=position_ids,
        )
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states
class Kosmos2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Kosmos2Config
    supports_gradient_checkpointing = True
    # Module classes listed as not-to-be-split (Transformers `_no_split_modules` convention).
    _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]

class Kosmos2VisionModel(Kosmos2PreTrainedModel):
    """Kosmos-2 vision tower exposed as a `PreTrainedModel`; wraps a `Kosmos2VisionTransformer` in `self.model`."""

    config_class = Kosmos2VisionConfig
    main_input_name = "pixel_values"

    # Adapted from `CLIPVisionModel.__init__` (CLIP -> Kosmos2 renames, `self.vision_model` -> `self.model`).
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__(config)
        self.model = Kosmos2VisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    # Adapted from `CLIPVisionModel.get_input_embeddings`.
    def get_input_embeddings(self) -> nn.Module:
        return self.model.embeddings.patch_embedding

    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Kosmos2VisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode `pixel_values`; every argument is forwarded unchanged to `self.model`.

        Returns:

        """
        # The "Returns:" line above is a required placeholder that `replace_return_docstrings` fills in.
        return self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

class Kosmos2TextModel(Kosmos2PreTrainedModel):
    """Kosmos-2 text decoder exposed as a `PreTrainedModel`; wraps a `Kosmos2TextTransformer` in `self.model`."""

    config_class = Kosmos2TextConfig

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)
        self.model = Kosmos2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Run the text decoder; every argument is forwarded unchanged, by keyword, to `self.model`.

        Returns:

        """
        # The "Returns:" line above is a required placeholder that `replace_return_docstrings` fills in.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
@add_start_docstrings(
    """
    The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel):
    # Config class used by the `PreTrainedModel` loading machinery.
    config_class = Kosmos2TextConfig
    # `lm_head.weight` is tied to the input embeddings (see the class description).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Kosmos2TextConfig):
        # Must run before any submodule is assigned: it sets up the nn.Module / PreTrainedModel state
        # (including `self.config`, which `forward` reads).
        super().__init__(config)

        self.model = Kosmos2TextTransformer(config)
        # Maps hidden states of size `embed_dim` to vocabulary logits; no bias.
        self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding layer of the underlying text transformer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer of the underlying text transformer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            # No key/value cache is used when a loss over `labels` is computed.
            use_cache = False

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Move labels to the logits' device to support model parallelism.
            labels = labels.to(lm_logits.device)
            # Shift so that the logits at position t are scored against the token at position t + 1.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten batch and time; CrossEntropyLoss ignores label -100 by default.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            # Tuple output: logits followed by the remaining transformer outputs, with the loss prepended if computed.
            output = (lm_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        image_embeds=None,
        image_embeds_position_mask=None,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **model_kwargs,
    ):
        """
        Build the keyword arguments for one decoding step of `generate()`.

        With a cache (`past_key_values`), only the last token is fed and the image inputs are dropped. Without one,
        `image_embeds_position_mask` is right-padded with `False` to the current length of `input_ids`.
        """
        input_shape = input_ids.shape
        # If the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        position_ids = None

        # Cut `input_ids` down to the last token if `past_key_values` is used.
        if past_key_values is not None:
            # Position ids are computed over the full sequence, then only the last column is kept.
            position_ids = create_position_ids_from_input_ids(
                input_ids,
                padding_idx=self.config.pad_token_id,
                past_key_values_length=0,
            )[:, -1:]

            input_ids = input_ids[:, -1:]
            # The image information is already encoded into the past keys/values.
            image_embeds = None
            image_embeds_position_mask = None
        elif image_embeds_position_mask is not None:
            # Append `False` to `image_embeds_position_mask` (because `input_ids` grows during generation).
            batch_size, seq_len = input_ids.size()
            mask_len = image_embeds_position_mask.size()[-1]
            image_embeds_position_mask = torch.cat(
                (
                    image_embeds_position_mask,
                    torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
                ),
                dim=1,
            )

        return {
            "input_ids": input_ids,
            "image_embeds": image_embeds,
            "image_embeds_position_mask": image_embeds_position_mask,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": use_cache,
        }

    @staticmethod
    # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 so it follows the beams selected by `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                # `beam_idx` is moved to each tensor's device so this also works with model parallelism.
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
class Kosmos2ImageToTextProjection(nn.Module):
    """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""

    def __init__(self, config: Kosmos2Config):
        # Required before assigning submodules/parameters on an nn.Module.
        super().__init__()
        # Maps vision hidden states (`vision_config.hidden_size`) to the text embedding size (`text_config.embed_dim`).
        self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
        # Learnable latent queries of shape (latent_query_num, embed_dim), shared across the batch.
        self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))

        # Attention from the latent queries over [projected image features; latent queries].
        self.x_attn = KosmosTextAttention(
            config.text_config,
            config.text_config.embed_dim,
            config.text_config.attention_heads,
            dropout=config.text_config.attention_dropout,
            is_decoder=False,
            add_inner_attn_layernorm=False,
        )

    def forward(self, features):
        """
        Project vision features into text-space image features.

        Args:
            features: vision encoder output; assumed `(batch, num_patches, vision_hidden_size)` — TODO confirm.

        Returns:
            `(hidden_states, attn_weights)` from `self.x_attn`. `latent_query` is used as the queries, so
            `hidden_states` is presumably `(batch, latent_query_num, embed_dim)`.
        """
        hidden_states = self.dense(features)

        # shape = [batch, latent_query_num, h_dim]: the shared queries expanded over the batch dimension.
        latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
        # Keys/values: the projected image features followed by the latent queries themselves.
        key_value_states = torch.cat([hidden_states, latent_query], dim=1)

        hidden_states, attn_weights, _ = self.x_attn(
            hidden_states=latent_query,
            encoder_hidden_states=key_value_states,
            past_key_value=None,
            attention_mask=None,
            output_attentions=None,
        )

        return hidden_states, attn_weights

# KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2Model(Kosmos2PreTrainedModel):
    config_class = Kosmos2Config
    # Images are the primary input of this model.
    main_input_name = "pixel_values"

    def __init__(self, config: Kosmos2Config):
        # Must run before any submodule is assigned.
        super().__init__(config)

        # Text transformer, vision encoder, and the projection that turns vision features into text-space embeddings.
        self.text_model = Kosmos2TextModel(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding layer of the text model."""
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer of the text model."""
        self.text_model.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Kosmos2ModelOutput]:
        r"""
        Encode the image (unless precomputed `image_embeds` are passed), project it into the text embedding space, and
        run the text model over the text inputs together with those image embeddings.

        Raises `ValueError` when neither `pixel_values` nor `image_embeds` is given.

        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_model_output = None
        projection_attentions = None
        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

            vision_model_output = self.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Pass the whole `last_hidden_state` through `post_layernorm` instead of using only `pooled_output`.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # L2-normalize the features along the last dimension.
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Append the image-side outputs, dropping any that were not computed.
            outputs = outputs + (image_embeds, projection_attentions, vision_model_output)
            return tuple(output for output in outputs if output is not None)

        return Kosmos2ModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_embeds=image_embeds,
            projection_attentions=projection_attentions,
            vision_model_output=vision_model_output,
        )
@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
    language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel):
    config_class = Kosmos2Config
    # Images are the primary input of this model.
    main_input_name = "pixel_values"
    # The LM head of the wrapped text model is tied to its input embeddings.
    _tied_weights_keys = ["text_model.lm_head.weight"]

    def __init__(self, config: Kosmos2Config):
        # Must run before any submodule is assigned.
        super().__init__(config)

        # Causal text model (with LM head), vision encoder, and vision-to-text projection.
        self.text_model = Kosmos2TextForCausalLM(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding layer of the text model."""
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding layer of the text model."""
        self.text_model.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the language-modeling head of the text model."""
        return self.text_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head of the text model."""
        self.text_model.set_output_embeddings(new_embeddings)

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Raises `ValueError` when neither `pixel_values` nor `image_embeds` is given.

        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_model_output = None
        projection_attentions = None
        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

            vision_model_output = self.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Pass the whole `last_hidden_state` through `post_layernorm` instead of using only `pooled_output`.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # L2-normalize the features along the last dimension.
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        lm_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Append the image-side outputs, dropping any that were not computed.
            outputs = lm_outputs + (image_embeds, projection_attentions, vision_model_output)
            return tuple(output for output in outputs if output is not None)

        return Kosmos2ForConditionalGenerationModelOutput(
            loss=lm_outputs.loss,
            logits=lm_outputs.logits,
            past_key_values=lm_outputs.past_key_values,
            hidden_states=lm_outputs.hidden_states,
            attentions=lm_outputs.attentions,
            image_embeds=image_embeds,
            projection_attentions=projection_attentions,
            vision_model_output=vision_model_output,
        )

    # Generation never needs gradients; this also avoids building an autograd graph for the image encoding below.
    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on an image.

        An `inputs` keyword (as used by `GenerationMixin`) is accepted as an alias for `pixel_values`; passing both
        raises `ValueError`. If `image_embeds` is not given, it is computed from `pixel_values` in the same way as in
        `forward`. All remaining keyword arguments are forwarded to `self.text_model.generate`.
        """
        # In order to allow the `inputs` argument (as in `GenerationMixin`).
        inputs = kwargs.pop("inputs", None)
        if pixel_values is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed. "
                "Make sure to either pass `inputs` or pixel_values=..."
            )
        if pixel_values is None and inputs is not None:
            pixel_values = inputs

        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

            vision_model_output = self.vision_model(pixel_values)
            # Pass the whole `last_hidden_state` through `post_layernorm` instead of using only `pooled_output`.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # L2-normalize the features along the last dimension.
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            # Project the vision features into the text embedding space; the attention weights are not needed here.
            image_embeds, _ = self.image_to_text_projection(image_embeds)

        output = self.text_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            **kwargs,
        )

        return output
