Transformers Source Code Analysis (69)

.\models\mamba\__init__.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TYPE_CHECKING lets the heavy imports below run only for static type checkers
from typing import TYPE_CHECKING

# Utility helpers and the optional-dependency exception
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)

# Lazy import structure of this module
_import_structure = {
    "configuration_mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig", "MambaOnnxConfig"],
}

# Raise OptionalDependencyNotAvailable when torch cannot be imported
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # torch is available, so register the modeling module as well
    _import_structure["modeling_mamba"] = [
        "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
        "MambaForCausalLM",
        "MambaModel",
        "MambaPreTrainedModel",
    ]

# In type-checking mode, import everything eagerly so static analyzers can see the symbols
if TYPE_CHECKING:
    # Configuration types
    from .configuration_mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig, MambaOnnxConfig

    # Check torch availability again; skip the modeling imports if it is missing
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Modeling types
        from .modeling_mamba import (
            MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
            MambaForCausalLM,
            MambaModel,
            MambaPreTrainedModel,
        )

# At runtime, replace this module with a _LazyModule that resolves imports on first access
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
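
To see the lazy loading in action, a minimal sketch (assuming `torch` is installed):

```python
import transformers.models.mamba as mamba_module

# sys.modules now holds a _LazyModule; the first attribute access triggers the real import
config = mamba_module.MambaConfig()          # imports configuration_mamba on demand
model_cls = mamba_module.MambaForCausalLM    # imports modeling_mamba on demand (needs torch)
print(type(config).__name__, model_cls.__name__)
```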

.\models\marian\configuration_marian.py

# coding=utf-8
# Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Marian model configuration
"""
# OrderedDict for the ordered input/output axis mappings used by the ONNX configs
from collections import OrderedDict
# Typing helpers
from typing import Any, Mapping, Optional

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import TensorType, is_torch_available, logging

# Module-level logger
logger = logging.get_logger(__name__)

# Map model names to their config file URLs
MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "Helsinki-NLP/opus-mt-en-de": "https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/config.json",
    # See all Marian models at https://huggingface.co/models?filter=marian
}

class MarianConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MarianModel`]. It is used to instantiate an
    Marian model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Marian
    [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Examples:

    ```
    >>> from transformers import MarianModel, MarianConfig

    >>> # Initializing a Marian Helsinki-NLP/opus-mt-en-de style configuration
    >>> configuration = MarianConfig()

    >>> # Initializing a model from the Helsinki-NLP/opus-mt-en-de style configuration
    >>> model = MarianModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    # Model type identifier
    model_type = "marian"
    # Keys ignored at inference time
    keys_to_ignore_at_inference = ["past_key_values"]
    # Alias num_attention_heads -> encoder_attention_heads and hidden_size -> d_model
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
    def __init__(
        self,
        vocab_size=58101,
        decoder_vocab_size=None,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        decoder_start_token_id=58100,
        scale_embedding=False,
        pad_token_id=58100,
        eos_token_id=0,
        forced_eos_token_id=0,
        share_encoder_decoder_embeddings=True,
        **kwargs,
    ):
        # Store the model's hyperparameters
        self.vocab_size = vocab_size
        self.decoder_vocab_size = decoder_vocab_size or vocab_size  # default to the encoder vocab size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model  # model (hidden) dimension
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function  # defaults to GELU
        self.init_std = init_std  # stddev for weight initialization
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers  # equals the number of encoder layers
        self.scale_embedding = scale_embedding  # if True, scale embeddings by sqrt(d_model)
        self.share_encoder_decoder_embeddings = share_encoder_decoder_embeddings
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
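
Since these are all plain keyword arguments, a smaller-than-default configuration is just a matter of overriding them (a sketch; the sizes are arbitrary):

```python
from transformers import MarianConfig

small_config = MarianConfig(
    vocab_size=32000,
    d_model=256,
    encoder_layers=3,
    decoder_layers=3,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
)
# attribute_map in action: the generic names alias the Marian-specific ones
assert small_config.num_attention_heads == 4  # -> encoder_attention_heads
assert small_config.hidden_size == 256        # -> d_model
```
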
class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.inputs
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Configure the input axes according to the task
        if self.task in ["default", "seq2seq-lm"]:
            # Default / seq2seq-lm: encoder inputs plus decoder inputs
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )

            if self.use_past:
                # With past key values, the decoder receives a single new token per step
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

            if self.use_past:
                # Add the past_key_values.* entries
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            # causal-lm task (still a TODO in the upstream code): only the basic inputs for now
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                # One past key/value pair per encoder layer
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            # Other tasks: full encoder and decoder inputs
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    @property
    # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        # Configure the output axes according to the task
        if self.task in ["default", "seq2seq-lm"]:
            # Default / seq2seq-lm: use the seq2seq superclass outputs
            common_outputs = super().outputs
        else:
            # Other tasks: bypass OnnxConfigWithPast and use the plain OnnxConfig outputs
            common_outputs = super(OnnxConfigWithPast, self).outputs
            if self.use_past:
                # One present key/value pair per encoder layer
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs
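
A quick way to see how `use_past` reshapes these axis maps (a sketch, assuming a transformers version that still ships the ONNX config classes):

```python
from transformers import MarianConfig
from transformers.models.marian import MarianOnnxConfig

cfg = MarianConfig()
print(list(MarianOnnxConfig(cfg, task="seq2seq-lm").inputs))
# ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask']

with_past = MarianOnnxConfig.with_past(cfg, task="seq2seq-lm")
print([k for k in with_past.inputs if k.startswith("past_key_values")][:2])
# roughly: ['past_key_values.0.decoder.key', 'past_key_values.0.decoder.value']
```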
    # Generate dummy inputs for the "default" and "seq2seq-lm" tasks
    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Dummy encoder inputs
        encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Dummy decoder inputs: a single step when using past, otherwise the full sequence length
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        # Prefix the decoder input names with "decoder_"
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        # Merge encoder and decoder inputs
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            # Generating dummy past key values requires PyTorch
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            # Batch size and encoder sequence length
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            # Decoder input sequence length
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            # Attention head counts for encoder and decoder
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            # Per-layer key/value shapes for the encoder and decoder
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length + 3
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            # Extend the decoder attention mask with ones so it also covers the past keys
            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            common_inputs["past_key_values"] = []
            # Build the past key/value tensors from the configured layer counts
            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            # One (decoder key, decoder value, encoder key, encoder value) tuple per shared layer
            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
            # TODO: test this.
            # Fill the remaining layers on whichever side has more of them
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
        # Return the merged inputs
        return common_inputs
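
To make the shape bookkeeping concrete, here is what one dummy `past_key_values` entry looks like for hypothetical sizes (batch 2, 16 heads, encoder length 8, one decoder step, head dim 64):

```python
import torch

batch, heads, enc_len, dec_len, head_dim = 2, 16, 8, 1, 64
encoder_shape = (batch, heads, enc_len, head_dim)
decoder_shape = (batch, heads, dec_len + 3, head_dim)  # decoder_past_length = decoder_seq_length + 3

# one tuple per layer: (decoder key, decoder value, encoder key, encoder value)
layer_past = (
    torch.zeros(decoder_shape),
    torch.zeros(decoder_shape),
    torch.zeros(encoder_shape),
    torch.zeros(encoder_shape),
)
```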
    # Generate dummy inputs for causal language modeling
    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Shared encoder-style inputs
        common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # When using past key values
        if self.use_past:
            # Generating dummy past key values requires PyTorch
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            # Batch size and sequence length of the input ids
            batch, seqlen = common_inputs["input_ids"].shape
            # Use a past length different from the input length
            past_key_values_length = seqlen + 2
            # Layer and attention-head counts
            num_encoder_layers, _ = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads
            # Shape of each past key/value tensor
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            # Extend the attention mask to cover the past key values, matching its dtype
            mask_dtype = common_inputs["attention_mask"].dtype
            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )
            # Initialize past_key_values as zero tensors, one pair per layer
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
            ]
        # Return the generated dummy inputs
        return common_inputs

    # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering;
    # renamed because Marian has no sequence-classification or question-answering head
    def _generate_dummy_inputs_for_encoder_and_decoder(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # If an axis is dynamic (-1), use fixed dimensions (2 samples, 8 tokens) to avoid ONNX optimizations
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )
        # Tokenize a dummy input of the computed batch size and sequence length
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        return dict(tokenizer(dummy_input, return_tensors=framework))

    # generate_dummy_inputs dispatches to the task-specific helpers above
    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Dispatch according to the task
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        else:
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        return common_inputs
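
End to end, the dispatch can be exercised like this (a sketch; the checkpoint name is only an example):

```python
from transformers import MarianConfig, MarianTokenizer
from transformers.models.marian import MarianOnnxConfig
from transformers.utils import TensorType

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
onnx_config = MarianOnnxConfig(MarianConfig(), task="seq2seq-lm")
dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print({name: tuple(t.shape) for name, t in dummy.items()})
# e.g. {'input_ids': (2, 8), 'attention_mask': (2, 8), 'decoder_input_ids': (2, 8), ...}
```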

    # Flatten past key values into the ONNX input/output dicts
    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        # Default / seq2seq-lm: use the seq2seq flattening
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        else:
            # Otherwise bypass OnnxSeq2SeqConfigWithPast and use the OnnxConfigWithPast flattening
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )

    # Absolute tolerance used when validating the exported model
    @property
    def atol_for_validation(self) -> float:
        return 1e-4
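
Putting the class to work, a hedged export sketch (checkpoint and output path are illustrative):

```python
from pathlib import Path
from transformers import MarianMTModel, MarianTokenizer
from transformers.models.marian import MarianOnnxConfig
from transformers.onnx import export

model_id = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_id)
model = MarianMTModel.from_pretrained(model_id)
onnx_config = MarianOnnxConfig(model.config, task="seq2seq-lm")

# validation of the exported graph compares outputs using atol_for_validation (1e-4)
export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("marian.onnx"))
```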

.\models\marian\convert_marian_tatoeba_to_pytorch.py

# Standard library imports
import argparse
import datetime
import json
import os
import re
from pathlib import Path
from typing import Tuple

import yaml  # YAML parsing
from tqdm import tqdm  # progress bars

# Reuse helpers from the OPUS-MT conversion script
from transformers.models.marian.convert_marian_to_pytorch import (
    FRONT_MATTER_TEMPLATE,  # front-matter template for model cards
    convert,  # converts a Marian checkpoint to PyTorch
    convert_opus_name_to_hf_name,  # maps OPUS model names to HF model names
    download_and_unzip,  # downloads and unpacks an archive
    get_system_metadata,  # collects system metadata (git SHAs, hostname, time)
)

# Default repository name and model directory
DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")

# Language-code reference data: URLs and local paths
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv"
ISO_PATH = "lang_code_data/iso-639-3.csv"
LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"
# Base URL for downloading Tatoeba models
TATOEBA_MODELS_URL = "https://object.pouta.csc.fi/Tatoeba-MT-models"


class TatoebaConverter:
    """
    Convert Tatoeba-Challenge models to huggingface format.

    Steps:

        1. Convert numpy state dict to hf format (same code as OPUS-MT-Train conversion).
        2. Rename opus model to huggingface format. This means replace each alpha3 code with an alpha2 code if a unique
           one exists. e.g. aav-eng -> aav-en, heb-eng -> he-en
        3. Select the best model for a particular pair, parse the yml for it and write a model card. By default the
           best model is the one listed first in released-model-results, but it's also possible to specify the most
           recent one.
    """

    def __init__(self, save_dir="marian_converted"):
        # The Tatoeba-Challenge repo must be cloned next to this script
        assert Path(DEFAULT_REPO).exists(), "need git clone git@github.com:Helsinki-NLP/Tatoeba-Challenge.git"

        # Download the language-code reference files
        self.download_lang_info()

        # Load the released-model results
        self.model_results = json.load(open("Tatoeba-Challenge/models/released-model-results.json"))

        # Build the alpha3 -> alpha2 mapping from the ISO 639-3 table ...
        self.alpha3_to_alpha2 = {}
        for line in open(ISO_PATH):
            parts = line.split("\t")
            if len(parts[0]) == 3 and len(parts[3]) == 2:
                self.alpha3_to_alpha2[parts[0]] = parts[3]

        # ... and from the language-codes CSV
        for line in open(LANG_CODE_PATH):
            parts = line.split(",")
            if len(parts[0]) == 3 and len(parts[1]) == 2:
                self.alpha3_to_alpha2[parts[0]] = parts[1]

        # Output directory for model cards
        self.model_card_dir = Path(save_dir)

        # Map group tags to their display names, taken from GROUP_MEMBERS
        self.tag2name = {}
        for key, value in GROUP_MEMBERS.items():
            self.tag2name[key] = value[0]
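
For intuition, the resulting mapping can be queried directly (a sketch that assumes the Tatoeba-Challenge repo has been cloned, as the assertion above requires):

```python
converter = TatoebaConverter(save_dir="marian_converted")
print(converter.alpha3_to_alpha2["heb"])  # "he"
print(converter.alpha3_to_alpha2["deu"])  # "de"
```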
    # Convert the given Tatoeba IDs; with dry_run=True only print what would be written
    def convert_models(self, tatoeba_ids, dry_run=False):
        # Parse the metadata for every Tatoeba ID
        models_to_convert = [self.parse_metadata(x) for x in tatoeba_ids]
        # Checkpoints are downloaded to "marian_ckpt"
        save_dir = Path("marian_ckpt")
        # Converted models and cards go to the model-card directory
        dest_dir = Path(self.model_card_dir)
        dest_dir.mkdir(exist_ok=True)
        # Convert each model, with a progress bar
        for model in tqdm(models_to_convert):  # k, prepro, download, test_set_url in tqdm(model_list):
            # Only SentencePiece-preprocessed models are supported
            if "SentencePiece" not in model["pre-processing"]:
                print(f"Skipping {model['release']} because it doesn't appear to use SentencePiece")
                continue
            # Download and unzip the checkpoint if it is not cached locally
            if not os.path.exists(save_dir / model["_name"]):
                download_and_unzip(f"{TATOEBA_MODELS_URL}/{model['release']}", save_dir / model["_name"])
            # Convert the Marian checkpoint to PyTorch under an HF-style name
            opus_language_groups_to_hf = convert_opus_name_to_hf_name
            pair_name = opus_language_groups_to_hf(model["_name"])
            convert(save_dir / model["_name"], dest_dir / f"opus-mt-{pair_name}")
            # Write the model card (or just print it when dry_run=True)
            self.write_model_card(model, dry_run=dry_run)

    # Expand a group name to the two-letter codes of its members
    def expand_group_to_two_letter_codes(self, grp_name):
        return [self.alpha3_to_alpha2.get(x, x) for x in GROUP_MEMBERS[grp_name][1]]

    # Decide whether a code/name pair denotes a language group
    def is_group(self, code, name):
        return "languages" in name or len(GROUP_MEMBERS.get(code, [])) > 1

    # Get the list of language tags for a code/name pair
    def get_tags(self, code, name):
        if len(code) == 2:
            # Two-letter codes must not denote groups
            assert "languages" not in name, f"{code}: {name}"
            return [code]
        elif self.is_group(code, name):
            # For groups, return the members' two-letter codes plus the group code itself
            group = self.expand_group_to_two_letter_codes(code)
            group.append(code)
            return group
        else:  # zho-> zh
            # Three-letter code for a single language
            print(f"Three letter monolingual code: {code}")
            return [code]

    # Resolve source and target language codes into tag lists
    def resolve_lang_code(self, src, tgt) -> Tuple[str, str]:
        src_tags = self.get_tags(src, self.tag2name[src])
        tgt_tags = self.get_tags(tgt, self.tag2name[tgt])
        return src_tags, tgt_tags
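
Continuing the sketch above, tag expansion behaves roughly like this (set order varies):

```python
converter.get_tags("de", "German")
# ['de']  -- two-letter codes pass through unchanged

converter.get_tags("gmq", "North Germanic languages")
# e.g. ['da', 'nb', 'sv', 'is', 'nn', 'fo', 'nob_Hebr', 'non_Latn', 'gmq']
# member codes mapped to two letters where a mapping exists, plus the group code itself
```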

    # Extract model-type info (data size, backtranslation, tuning) from a model name
    @staticmethod
    def model_type_info_from_model_name(name):
        info = {"_has_backtranslated_data": False}
        if "1m" in name:
            info["_data_per_pair"] = str(1e6)
        if "2m" in name:
            info["_data_per_pair"] = str(2e6)
        if "4m" in name:
            info["_data_per_pair"] = str(4e6)
        if "+bt" in name:
            info["_has_backtranslated_data"] = True
        if "tuned4" in name:
            info["_tuned"] = re.search(r"tuned4[^-]+", name).group()
        return info
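
Concretely, the name parser extracts (release names here are hypothetical):

```python
TatoebaConverter.model_type_info_from_model_name("opus+bt-2021-04-10")
# {'_has_backtranslated_data': True}

TatoebaConverter.model_type_info_from_model_name("opus-2m-2020-08-01")
# {'_has_backtranslated_data': False, '_data_per_pair': '2000000.0'}
```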
    def write_model_card(self, model_dict, dry_run=False) -> str:
        # ... earlier lines of this method (deriving lang_tags, a2_src_tags / a2_tgt_tags,
        # and the individual card sections used below) are omitted here ...
        content = (
            f"""
* model: {model_dict['modeltype']}
* source language code{src_multilingual*'s'}: {', '.join(a2_src_tags)}
* target language code{tgt_multilingual*'s'}: {', '.join(a2_tgt_tags)}
* dataset: opus {backtranslated_data}
* release date: {model_dict['release-date']}
* pre-processing: {model_dict['pre-processing']}
"""
            + multilingual_data
            + tuned
            + download
            + langtoken
            + datainfo
            + testset
            + testscores
            + scorestable
        )
        # Assemble the model card body: model type, language codes, dataset info, etc.

        content = FRONT_MATTER_TEMPLATE.format(lang_tags) + extra_markdown + content
        # Prepend the front-matter template and the extra markdown to the card body

        items = "\n".join([f"* {k}: {v}" for k, v in metadata.items()])
        # Render the metadata dict as "* key: value" list items

        sec3 = "\n### System Info: \n" + items
        # System-info section: heading plus the metadata items

        content += sec3
        # Append the system-info section to the card

        if dry_run:
            # In dry-run mode, print the card and metadata instead of writing files
            print("CONTENT:")
            print(content)
            print("METADATA:")
            print(metadata)
            return

        sub_dir = self.model_card_dir / model_dict["_hf_model_id"]
        sub_dir.mkdir(exist_ok=True)
        # Create the per-model card directory if needed

        dest = sub_dir / "README.md"
        dest.open("w").write(content)
        # Write the card to README.md

        for k, v in metadata.items():
            if isinstance(v, datetime.date):
                metadata[k] = datetime.datetime.strftime(v, "%Y-%m-%d")
        # Serialize dates as "%Y-%m-%d" strings

        with open(sub_dir / "metadata.json", "w", encoding="utf-8") as writeobj:
            json.dump(metadata, writeobj)
        # Dump the metadata to metadata.json

    def download_lang_info(self):
        # Make sure the directory for the language-code files exists
        Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True)

        import wget  # used to fetch the reference CSV files

        # Download the ISO 639-3 table if it is missing
        if not os.path.exists(ISO_PATH):
            wget.download(ISO_URL, ISO_PATH)

        # Download the language-codes CSV if it is missing
        if not os.path.exists(LANG_CODE_PATH):
            wget.download(LANG_CODE_URL, LANG_CODE_PATH)
    # Parse a model's metadata given its name, the repo path, and a selection method
    def parse_metadata(self, model_name, repo_path=DEFAULT_MODEL_DIR, method="best"):
        # Path of the model inside the repository
        p = Path(repo_path) / model_name

        # Extract a file name (without extension) from a URL
        def url_to_name(url):
            return url.split("/")[-1].split(".")[0]

        # If the model has no entry in the released results, fall back to the newest release
        if model_name not in self.model_results:
            method = "newest"

        if method == "best":
            # File names of the listed releases, best first
            results = [url_to_name(model["download"]) for model in self.model_results[model_name]]
            # Keep only the .yml files that correspond to a listed release
            ymls = [f for f in os.listdir(p) if f.endswith(".yml") and f[:-4] in results]
            # Sort them by their rank in the results list
            ymls.sort(key=lambda x: results.index(x[:-4]))
            # Load the metadata of the best-ranked release
            metadata = yaml.safe_load(open(p / ymls[0]))
            # Add the model-type info parsed from the file name
            metadata.update(self.model_type_info_from_model_name(ymls[0][:-4]))
        elif method == "newest":
            # All .yml files for this model
            ymls = [f for f in os.listdir(p) if f.endswith(".yml")]
            # Sort by the date embedded in the file name
            ymls.sort(
                key=lambda x: datetime.datetime.strptime(re.search(r"\d\d\d\d-\d\d?-\d\d?", x).group(), "%Y-%m-%d")
            )
            # Load the metadata of the newest release
            metadata = yaml.safe_load(open(p / ymls[-1]))
            # Add the model-type info parsed from the file name
            metadata.update(self.model_type_info_from_model_name(ymls[-1][:-4]))
        else:
            raise NotImplementedError(f"Don't know argument method='{method}' to parse_metadata()")

        # Record the model name in the metadata
        metadata["_name"] = model_name
        return metadata
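
With the converter from the earlier sketch, parsing looks like this (the model directory name is illustrative and must exist in the cloned repo):

```python
metadata = converter.parse_metadata("eng-heb", method="newest")
print(metadata["_name"], metadata.get("release-date"))
```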
GROUP_MEMBERS = {
    # three-letter code -> (group/language name, {members...})
    # if this language is on the target side, the members can be used as target language codes.
    # if it is on the source side, they are supported natively without special codes.
    "aav": ("Austro-Asiatic languages", {"hoc", "hoc_Latn", "kha", "khm", "khm_Latn", "mnw", "vie", "vie_Hani"}),
    "afa": (
        "Afro-Asiatic languages",
        {
            "acm",
            "afb",
            "amh",
            "apc",
            "ara",
            "arq",
            "ary",
            "arz",
            "hau_Latn",
            "heb",
            "kab",
            "mlt",
            "rif_Latn",
            "shy_Latn",
            "som",
            "thv",
            "tir",
        },
    ),
    "afr": ("Afrikaans", {"afr"}),
    "alv": (
        "Atlantic-Congo languages",
        {
            "ewe",
            "fuc",
            "fuv",
            "ibo",
            "kin",
            "lin",
            "lug",
            "nya",
            "run",
            "sag",
            "sna",
            "swh",
            "toi_Latn",
            "tso",
            "umb",
            "wol",
            "xho",
            "yor",
            "zul",
        },
    ),
    "ara": ("Arabic", {"afb", "apc", "apc_Latn", "ara", "ara_Latn", "arq", "arq_Latn", "arz"}),
    "art": (
        "Artificial languages",
        {
            "afh_Latn",
            "avk_Latn",
            "dws_Latn",
            "epo",
            "ido",
            "ido_Latn",
            "ile_Latn",
            "ina_Latn",
            "jbo",
            "jbo_Cyrl",
            "jbo_Latn",
            "ldn_Latn",
            "lfn_Cyrl",
            "lfn_Latn",
            "nov_Latn",
            "qya",
            "qya_Latn",
            "sjn_Latn",
            "tlh_Latn",
            "tzl",
            "tzl_Latn",
            "vol_Latn",
        },
    ),
    "aze": ("Azerbaijani", {"aze_Latn"}),
    "bat": ("Baltic languages", {"lit", "lav", "prg_Latn", "ltg", "sgs"}),
    "bel": ("Belarusian", {"bel", "bel_Latn"}),
    "ben": ("Bengali", {"ben"}),
    "bnt": (
        "Bantu languages",
        {"kin", "lin", "lug", "nya", "run", "sna", "swh", "toi_Latn", "tso", "umb", "xho", "zul"},
    ),
    "bul": ("Bulgarian", {"bul", "bul_Latn"}),
    "cat": ("Catalan", {"cat"}),
    "cau": ("Caucasian languages", {"abk", "kat", "che", "ady"}),
    "ccs": ("South Caucasian languages", {"kat"}),
    "ceb": ("Cebuano", {"ceb"}),
    "cel": ("Celtic languages", {"gla", "gle", "bre", "cor", "glv", "cym"}),
    "ces": ("Czech", {"ces"}),
    "cpf": ("Creoles and pidgins, French‑based", {"gcf_Latn", "hat", "mfe"}),
    "cpp": (
        "Creoles and pidgins, Portuguese-based",
        {"zsm_Latn", "ind", "pap", "min", "tmw_Latn", "max_Latn", "zlm_Latn"},
    ),
    "cus": ("Cushitic languages", {"som"}),
    "dan": ("Danish", {"dan"}),
    "deu": ("German", {"deu"}),
    "dra": ("Dravidian languages", {"tam", "kan", "mal", "tel"}),
    "ell": ("Modern Greek (1453-)", {"ell"}),
    "eng": ("English", {"eng"}),
    "epo": ("Esperanto", {"epo"}),
    "est": ("Estonian", {"est"}),
    "euq": ("Basque (family)", {"eus"}),
    "eus": ("Basque", {"eus"}),
    "fin": ("Finnish", {"fin"}),
    "fiu": (  # 定义键为"fiu"的元组,包含语言家族名称和语言代码集合
        "Finno-Ugrian languages",
        {
            "est",  # 爱沙尼亚语代码
            "fin",  # 芬兰语代码
            "fkv_Latn",  # 科瓦林语的拉丁字母代码
            "hun",  # 匈牙利语代码
            "izh",  # 苏里奥语代码
            "kpv",  # 科米语代码
            "krl",  # 卡累利阿语代码
            "liv_Latn",  # 利沃尼亚语的拉丁字母代码
            "mdf",  # 莫克沙语代码
            "mhr",  # 马里语代码
            "myv",  # 厄尔茨亚语代码
            "sma",  # 南萨米语代码
            "sme",  # 北萨米语代码
            "udm",  # 乌德穆尔特语代码
            "vep",  # 维普尔语代码
            "vro",  # 维兰语代码
        },
    ),
    "fra": ("French", {"fra"}),  # 定义键为"fra"的元组,包含语言名称和单一的语言代码集合
    "gem": (  # 定义键为"gem"的元组,包含语言家族名称和语言代码集合
        "Germanic languages",
        {
            "afr",  # 南非荷兰语代码
            "ang_Latn",  # 古英语的拉丁字母代码
            "dan",  # 丹麦语代码
            "deu",  # 德语代码
            "eng",  # 英语代码
            "enm_Latn",  # 中古英语的拉丁字母代码
            "fao",  # 法罗语代码
            "frr",  # 北弗里西语代码
            "fry",  # 弗里西语代码
            "gos",  # 弗兰克-萨克逊语代码
            "got_Goth",  # 哥特语代码
            "gsw",  # 瑞士德语代码
            "isl",  # 冰岛语代码
            "ksh",  # 科隆语代码
            "ltz",  # 卢森堡语代码
            "nds",  # 下地德语代码
            "nld",  # 荷兰语代码
            "nno",  # 新挪威语代码
            "nob",  # 书面挪威语代码
            "nob_Hebr",  # 书面挪威语的希伯来字母代码
            "non_Latn",  # 古挪威语的非拉丁字母代码
            "pdc",  # 宾夕法尼亚德语代码
            "sco",  # 苏格兰语代码
            "stq",  # 萨特弗里斯兰语代码
            "swe",  # 瑞典语代码
            "swg",  # 沃特兰弗兰克语代码
            "yid",  # 意第绪语代码
        },
    ),
    "gle": ("Irish", {"gle"}),  # 定义键为"gle"的元组,包含语言名称和单一的语言代码集合
    "glg": ("Galician", {"glg"}),  # 定义键为"glg"的元组,包含语言名称和单一的语言代码集合
    "gmq": (  # 定义键为"gmq"的元组,包含语言家族名称和语言代码集合
        "North Germanic languages",
        {
            "dan",  # 丹麦语代码
            "nob",  # 书面挪威语代码
            "nob_Hebr",  # 书面挪威语的希伯来字母代码
            "swe",  # 瑞典语代码
            "isl",  # 冰岛语代码
            "nno",  # 新挪威语代码
            "non_Latn",  # 古挪威语的非拉丁字母代码
            "fao",  # 法罗语代码
        },
    ),
    "gmw": (  # 定义键为"gmw"的元组,包含语言家族名称和语言代码集合
        "West Germanic languages",
        {
            "afr",  # 南非荷兰语代码
            "ang_Latn",  # 古英语的拉丁字母代码
            "deu",  # 德语代码
            "eng",  # 英语代码
            "enm_Latn",  # 中古英语的拉丁字母代码
            "frr",  # 北弗里西语代码
            "fry",  # 弗里西语代码
            "gos",  # 弗兰克-萨克逊语代码
            "gsw",  # 瑞士德语代码
            "ksh",  # 科隆语代码
            "ltz",  # 卢森堡语代码
            "nds",  # 下地德语代码
            "nld",  # 荷兰语代码
            "pdc",  # 宾夕法尼亚德语代码
            "sco",  # 苏格兰语代码
            "stq",  # 萨特弗里斯兰语代码
            "swg",  # 沃特兰弗兰克语代码
            "yid",  # 意第绪语代码
        },
    ),
    "grk": ("Greek languages", {"grc_Grek", "ell"}),  # 定义键为"grk"的元组,包含语言族名称和语言代码集合
    "hbs": ("Serbo-Croatian", {"hrv", "srp_Cyrl", "bos_Latn", "srp_Latn"}),  # 定义键为"hbs"的元组,包含语言名称和语言代码集合
    "heb": ("Hebrew", {"heb"}),  # 定义键为"heb"的元组,包含语言名称和单一的语言代码集合
    "hin": ("Hindi", {"hin"}),  # 定义键为"hin"的元组,包
    "inc": (
        "Indic languages",  # "inc" 键对应的值是一个元组,包含了 "Indic languages" 和一个集合
        {
            "asm",          # 集合中包含 "asm",代表阿萨姆语
            "awa",          # 集合中包含 "awa",代表阿瓦德语
            "ben",          # 集合中包含 "ben",代表孟加拉语
            "bho",          # 集合中包含 "bho",代表博杰普尔语
            "gom",          # 集合中包含 "gom",代表孔卡尼语
            "guj",          # 集合中包含 "guj",代表古吉拉特语
            "hif_Latn",     # 集合中包含 "hif_Latn",代表斐济印地语(拉丁字母)
            "hin",          # 集合中包含 "hin",代表印地语
            "mai",          # 集合中包含 "mai",代表迈蒂利语
            "mar",          # 集合中包含 "mar",代表马拉地语
            "npi",          # 集合中包含 "npi",代表尼泊尔文
            "ori",          # 集合中包含 "ori",代表奥里亚语
            "pan_Guru",     # 集合中包含 "pan_Guru",代表旁遮普语(古鲁穆基字母)
            "pnb",          # 集合中包含 "pnb",代表西旁遮普语
            "rom",          # 集合中包含 "rom",代表罗姆语
            "san_Deva",     # 集合中包含 "san_Deva",代表梵语(天城文)
            "sin",          # 集合中包含 "sin",代表僧伽罗语
            "snd_Arab",     # 集合中包含 "snd_Arab",代表信德语(阿拉伯字母)
            "urd",          # 集合中包含 "urd",代表乌尔都语
        },
    ),
    "ine": (
        "Indo-European languages",  # 定义键值对 "ine",表示印欧语系语言,值为元组
        {
            "afr", "afr_Arab", "aln", "ang_Latn", "arg", "asm", "ast", "awa", "bel",  # 定义一个包含多个字符串的集合,表示不同印欧语系语言的标识符
            "bel_Latn", "ben", "bho", "bjn", "bos_Latn", "bre", "bul", "bul_Latn", "cat",
            "ces", "cor", "cos", "csb_Latn", "cym", "dan", "deu", "dsb", "egl", "ell",
            "eng", "enm_Latn", "ext", "fao", "fra", "frm_Latn", "frr", "fry", "gcf_Latn",
            "gla", "gle", "glg", "glv", "gom", "gos", "got_Goth", "grc_Grek", "gsw",
            "guj", "hat", "hif_Latn", "hin", "hrv", "hsb", "hye", "hye_Latn", "ind",
            "isl", "ita", "jdt_Cyrl", "ksh", "kur_Arab", "kur_Latn", "lad", "lad_Latn",
            "lat_Grek", "lat_Latn", "lav", "lij", "lit", "lld_Latn", "lmo", "ltg", "ltz",
            "mai", "mar", "max_Latn", "mfe", "min", "mkd", "mwl", "nds", "nld", "nno",
            "nob", "nob_Hebr", "non_Latn", "npi", "oci", "ori", "orv_Cyrl", "oss",
            "pan_Guru", "pap", "pcd", "pdc", "pes", "pes_Latn", "pes_Thaa", "pms",
            "pnb", "pol", "por", "prg_Latn", "pus", "roh", "rom", "ron", "rue", "rus",
            "rus_Latn", "san_Deva", "scn", "sco", "sgs", "sin", "slv", "snd_Arab",
            "spa", "sqi", "srd", "srp_Cyrl", "srp_Latn", "stq", "swe", "swg", "tgk_Cyrl",
            "tly_Latn", "tmw_Latn", "ukr", "urd", "vec", "wln", "yid", "zlm_Latn",
            "zsm_Latn", "zza"
        },  # 这些字符串代表各种不同印欧语系语言的标识符
    ),
    "isl": ("Icelandic", {"isl"}),  # 定义键值对 "isl",表示冰岛语,值为包含字符串 "isl" 的集合
    "ita": ("Italian", {"ita"}),  # 定义键值对 "ita",表示意大利语,值为包含字符串 "ita" 的集合
    "itc": (
        "Italic languages",  # 键 'itc',代表意大利语族的语言
        {  # 值是一个集合,包含多个字符串,代表具体的语言码
            "arg",  # 阿拉贡语
            "ast",  # 阿斯图里亚斯语
            "bjn",  # 班亚尔语
            "cat",  # 加泰罗尼亚语
            "cos",  # 科西嘉语
            "egl",  # 艾米利安语
            "ext",  # 埃斯特雷马杜拉语
            "fra",  # 法语
            "frm_Latn",  # 中古法语(拉丁字母版)
            "gcf_Latn",  # 古典法罗语(拉丁字母版)
            "glg",  # 加利西亚语
            "hat",  # 海地克里奥尔语
            "ind",  # 印度尼西亚语
            "ita",  # 意大利语
            "lad",  # 罗马尼亚吉普赛语
            "lad_Latn",  # 罗马尼亚吉普赛语(拉丁字母版)
            "lat_Grek",  # 拉丁语(希腊字母版)
            "lat_Latn",  # 拉丁语(拉丁字母版)
            "lij",  # 利古里亚语
            "lld_Latn",  # 皮德蒙特语(拉丁字母版)
            "lmo",  # 伦巴第语
            "max_Latn",  # 马萨伊语(拉丁字母版)
            "mfe",  # 毛里求斯克里奥尔语
            "min",  # 明边语
            "mwl",  # 米兰德语
            "oci",  # 奥克语
            "pap",  # 比道语
            "pcd",  # 皮卡第语
            "pms",  # 皮埃蒙特语
            "por",  # 葡萄牙语
            "roh",  # 罗曼什语
            "ron",  # 罗马尼亚语
            "scn",  # 西西里语
            "spa",  # 西班牙语
            "srd",  # 萨丁语
            "tmw_Latn",  # 提姆西语(拉丁字母版)
            "vec",  # 威尼斯语
            "wln",  # 瓦隆语
            "zlm_Latn",  # 马来语(拉丁字母版)
            "zsm_Latn",  # 马来语(拉丁字母版)
        },
    ),
    "jpn": (  # 键 'jpn',代表日语
        "Japanese",  # 日语的全称
        {  # 值是一个集合,包含多个字符串,代表具体的日语方言或使用不同字母表的形式
            "jpn",  # 日语
            "jpn_Bopo",  # 日语(注音符号版)
            "jpn_Hang",  # 日语(朝鲜字母版)
            "jpn_Hani",  # 日语(汉字版)
            "jpn_Hira",  # 日语(平假名版)
            "jpn_Kana",  # 日语(假名版)
            "jpn_Latn",  # 日语(拉丁字母版)
            "jpn_Yiii",  # 日语(纳西字母版)
        },
    ),
    "jpx": (  # 键 'jpx',代表日语的家族
        "Japanese (family)",  # 日语的家族名
        {"jpn"},  # 包含日语
    ),
    "kat": (  # 键 'kat',代表格鲁吉亚语
        "Georgian",  # 格鲁吉亚语的全称
        {"kat"},  # 包含格鲁吉亚语
    ),
    "kor": (  # 键 'kor',代表韩语
        "Korean",  # 韩语的全称
        {  # 值是一个集合,包含多个字符串,代表具体的韩语方言或使用不同字母表的形式
            "kor_Hani",  # 韩语(汉字版)
            "kor_Hang",  # 韩语(朝鲜字母版)
            "kor_Latn",  # 韩语(拉丁字母版)
            "kor",  # 韩语
        },
    ),
    "lav": (  # 键 'lav',代表拉脱维亚语
        "Latvian",  # 拉脱维亚语的全称
        {"lav"},  # 包含拉脱维亚语
    ),
    "lit": (  # 键 'lit',代表立陶宛语
        "Lithuanian",  # 立陶宛语的全称
        {"lit"},  # 包含立陶宛语
    ),
    "mkd": (  # 键 'mkd',代表马其顿语
        "Macedonian",  # 马其顿语的全称
        {"mkd"},  # 包含马其顿语
    ),
    "mkh": (  # 键 'mkh',代表蒙高—湄语族
        "Mon-Khmer languages",  # 蒙高—湄语族的全称
        {  # 值是一个集合,包含多个字符串,代表具体的蒙高—湄语族语言或使用不同字母表的形式
            "vie_Hani",  # 越南语(汉字版)
            "mnw",  # 孟语
            "vie",  # 越南语
            "kha",  # 卡西语
            "khm_Latn",  # 高棉语(拉丁字母版)
            "khm",  # 高棉语
        },
    ),
    "msa": (  # 键 'msa',代表马来语(宏语言)
        "Malay (macrolanguage)",  # 马来语(宏语言)的全称
        {  # 值是一个集合,包含多个字符串,代表具体的马来语及其变体
            "zsm_Latn",  # 马来语(马来文拉丁字母版)
            "ind",  # 印度尼西亚语
            "max_Latn",  # 马德佩勒马语(拉丁字母版)
            "zlm_Latn",  # 马来语(马来亚文拉丁字母版)
            "min",  # 明边语
        },
    ),
    "nic": (  # 键 'nic',代表尼日尔—科尔多凡语族
        "Niger-Kordofanian languages",  # 尼日尔—科尔多凡语族的全称
        {  # 值是一个集合,包含多个字符串,代表具体的尼日尔—科尔多凡语族语言
            "bam_Latn",  # 班巴拉语(拉丁字母版)
            "ewe",  # 埃维语
            "fuc",  # 富拉语
            "fuv",  # 富拉语
            "ibo",  # 伊博语
            "kin",  # 卢安达语
    "roa": (
        "Romance languages",
        {  # 这是一个集合,包含多种罗曼语系的语言代码
            "arg",  # 阿拉贡语
            "ast",  # 阿斯图里亚斯语
            "cat",  # 加泰罗尼亚语
            "cos",  # 科西嘉语
            "egl",  # 埃米利亚-罗马涅语
            "ext",  # 埃斯特雷马杜拉语
            "fra",  # 法语
            "frm_Latn",  # 中古法语(拉丁文书写)
            "gcf_Latn",  # 海地克里奥尔法语(拉丁文书写)
            "glg",  # 加利西亚语
            "hat",  # 海地克里奥尔语
            "ind",  # 印尼语
            "ita",  # 意大利语
            "lad",  # 犹太西班牙语
            "lad_Latn",  # 犹太西班牙语(拉丁文书写)
            "lij",  # 利古里亚语
            "lld_Latn",  # 皮德蒙特语(拉丁文书写)
            "lmo",  # 里米尼语
            "max_Latn",  # 里诺罗曼语(拉丁文书写)
            "mfe",  # 毛里求斯克里奥尔语
            "min",  # 明亚克语
            "mwl",  # 米兰达语
            "oci",  # 奥克语
            "pap",  # 帕皮亚门托语
            "pms",  # 皮埃蒙特语
            "por",  # 葡萄牙语
            "roh",  # 罗曼什语
            "ron",  # 罗马尼亚语
            "scn",  # 西西里语
            "spa",  # 西班牙语
            "tmw_Latn",  # 特米纳语(拉丁文书写)
            "vec",  # 威尼斯语
            "wln",  # 瓦隆语
            "zlm_Latn",  # 马来语(拉丁文书写)
            "zsm_Latn",  # 马来语(新加坡拉丁文书写)
        },
    ),
    "ron": ("Romanian", {"ron"}),  # 罗马尼亚语
    "run": ("Rundi", {"run"}),  # 鲁恩迪语
    "rus": ("Russian", {"rus"}),  # 俄语
    "sal": ("Salishan languages", {"shs_Latn"}),  # 沙利什语系
    "sem": (
        "Semitic languages",
        {  # 这是一个集合,包含多种闪米特语系的语言代码
            "acm",  # 中阿拉伯语
            "afb",  # 南布尔语
            "amh",  # 阿姆哈拉语
            "apc",  # 联合阿拉伯语
            "ara",  # 阿拉伯语
            "arq",  # 阿尔及利亚阿拉伯语
            "ary",  # 摩洛哥阿拉伯语
            "arz",  # 埃及阿拉伯语
            "heb",  # 希伯来语
            "mlt",  # 马耳他语
            "tir",  # 提格利尼亚语
        },
    ),
    "sla": (
        "Slavic languages",
        {  # 这是一个集合,包含多种斯拉夫语系的语言代码
            "bel",  # 白俄罗斯语
            "bel_Latn",  # 白俄罗斯语(拉丁文书写)
            "bos_Latn",  # 波斯尼亚语(拉丁文书写)
            "bul",  # 保加利亚语
            "bul_Latn",  # 保加利亚语(拉丁文书写)
            "ces",  # 捷克语
            "csb_Latn",  # 卡舒比亚语(拉丁文书写)
            "dsb",  # 下索布语
            "hrv",  # 克罗地亚语
            "hsb",  # 上索布语
            "mkd",  # 马其顿语
            "orv_Cyrl",  # 古教会斯拉夫语(西里尔文书写)
            "pol",  # 波兰语
            "rue",  # 卢森尼亚语
            "rus",  # 俄语
            "slv",  # 斯洛文尼亚语
            "srp_Cyrl",  # 塞尔维亚语(西里尔文书写)
            "srp_Latn",  # 塞尔维亚语(拉丁文书写)
            "ukr",  # 乌克兰语
        },
    ),
    "slv": ("Slovenian", {"slv"}),  # 斯洛文尼亚语
    "spa": ("Spanish", {"spa"}),  # 西班牙语
    "swe": ("Swedish", {"swe"}),  # 瑞典语
    "taw": ("Tai", {"lao", "tha"}),  # 泰语系
    "tgl": ("Tagalog", {"tgl_Latn"}),  # 菲律宾语
    "tha": ("Thai", {"tha"}),  # 泰语
    "trk": (
        "Turkic languages",
        {  # 这是一个集合,包含多种突厥语系的语言代码
            "aze_Latn",  # 阿塞拜疆语(拉丁文书写)
            "bak",  # 巴什基尔语
            "chv",  # 楚瓦什语
            "crh",  # 克里米亚土耳其语
            "crh_Latn",  # 克里米亚土耳其语(拉丁文书写)
            "kaz_Cyrl",  # 哈萨克语(西里尔文书写)
            "kaz_Latn",  # 哈萨克语(拉丁文书写)
            "kir_Cyrl",  # 柯尔克孜语(西里尔文书写)
            "kjh",  # 喀尔巴阡罗姆语
            "kum",  # 库梅克语
            "ota_Arab",  # 奥斯曼土耳其语(阿拉伯文书写)
            "ota_Latn",  # 奥斯曼土耳其语(拉丁文书写)
            "sah",  # 萨哈语
            "tat",  # 塔塔尔语
            "tat_Arab",  # 塔塔尔语(阿拉伯文书写)
            "tat_Latn",  # 塔塔尔语(拉丁文书写)
            "tuk",  # 土库曼语
            "tuk_Latn",  # 土库曼语(拉丁文书写)
            "tur",  # 土耳其语
            "tyv",  # 图瓦语
            "uig_Arab",  # 维吾尔语(阿拉伯文书写)
            "uig_Cyrl",  # 维吾尔语(西里尔文书写)
            "uzb_Cyrl",
    "zho": (
        "Chinese",
        {  # 定义一个包含多个元素的集合,表示中文相关的语言代码
            "cjy_Hans",  # 简体中文
            "cjy_Hant",  # 繁体中文
            "cmn",       # 普通话(中文)
            "cmn_Bopo",  # 普通话拼音
            "cmn_Hang",  # 普通话汉字
            "cmn_Hani",  # 普通话汉字
            "cmn_Hans",  # 普通话简体字
            "cmn_Hant",  # 普通话繁体字
            "cmn_Hira",  # 普通话平假名
            "cmn_Kana",  # 普通话假名
            "cmn_Latn",  # 普通话拉丁字母
            "cmn_Yiii",  # 普通话伊语
            "gan",       # 赣语
            "hak_Hani",  # 客家话汉字
            "lzh",       # 文言文
            "lzh_Bopo",  # 文言文拼音
            "lzh_Hang",  # 文言文汉字
            "lzh_Hani",  # 文言文汉字
            "lzh_Hans",  # 文言文简体字
            "lzh_Hira",  # 文言文平假名
            "lzh_Kana",  # 文言文假名
            "lzh_Yiii",  # 文言文伊语
            "nan",       # 台湾闽南语
            "nan_Hani",  # 台湾闽南语汉字
            "wuu",       # 吴语
            "wuu_Bopo",  # 吴语拼音
            "wuu_Hani",  # 吴语汉字
            "wuu_Latn",  # 吴语拉丁字母
            "yue",       # 粤语
            "yue_Bopo",  # 粤语拼音
            "yue_Hang",  # 粤语汉字
            "yue_Hani",  # 粤语汉字
            "yue_Hans",  # 粤语简体字
            "yue_Hant",  # 粤语繁体字
            "yue_Hira",  # 粤语平假名
            "yue_Kana",  # 粤语假名
            "zho",       # 中文
            "zho_Hans",  # 中文简体字
            "zho_Hant",  # 中文繁体字
        },
    ),
    "zle": (
        "East Slavic languages",
        {  # 定义一个包含多个元素的集合,表示东斯拉夫语族的语言代码
            "bel",       # 白俄罗斯语
            "orv_Cyrl",  # 古教会斯拉夫语(西里尔字母)
            "bel_Latn",  # 白俄罗斯语拉丁字母
            "rus",       # 俄语
            "ukr",       # 乌克兰语
            "rue",       # 卢森堡文
        },
    ),
    "zls": (
        "South Slavic languages",
        {  # 定义一个包含多个元素的集合,表示南斯拉夫语族的语言代码
            "bos_Latn",  # 波斯尼亚语拉丁字母
            "bul",       # 保加利亚语
            "bul_Latn",  # 保加利亚语拉丁字母
            "hrv",       # 克罗地亚语
            "mkd",       # 马其顿语
            "slv",       # 斯洛文尼亚语
            "srp_Cyrl",  # 塞尔维亚语(西里尔字母)
            "srp_Latn",  # 塞尔维亚语拉丁字母
        },
    ),
    "zlw": (
        "West Slavic languages",
        {  # 定义一个包含多个元素的集合,表示西斯拉夫语族的语言代码
            "csb_Latn",  # 卡舒比语拉丁字母
            "dsb",       # 下索布语
            "hsb",       # 上索布语
            "pol",       # 波兰语
            "ces",       # 捷克语
        },
    ),
}

# l2front_matter: format a list of language codes as YAML front-matter list items
def l2front_matter(langs):
    return "".join(f"- {l}\n" for l in langs)

# dedup: remove duplicates (and falsy entries) from a list while preserving order
def dedup(lst):
    """Preserves order"""
    new_lst = []
    for item in lst:
        if not item or item in new_lst:
            continue
        else:
            new_lst.append(item)
    return new_lst
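
Both helpers are small enough to show by example:

```python
>>> l2front_matter(["en", "de"])
'- en\n- de\n'
>>> dedup(["en", "", "de", "en"])
['en', 'de']
```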

# Command-line entry point
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # -m/--models is required and may be passed several values
    parser.add_argument(
        "-m", "--models", action="append", help="<Required> Set flag", required=True, nargs="+", dest="models"
    )
    # -save_dir/--save_dir: where to write the converted models (default "marian_converted")
    parser.add_argument("-save_dir", "--save_dir", default="marian_converted", help="where to save converted models")
    args = parser.parse_args()
    # Convert the first group of models given on the command line
    resolver = TatoebaConverter(save_dir=args.save_dir)
    resolver.convert_models(args.models[0])

.\models\marian\convert_marian_to_pytorch.py

# Standard library and third-party imports
import argparse
import json
import os
import socket
import time
import warnings
from pathlib import Path
from typing import Dict, List, Union
from zipfile import ZipFile

import numpy as np
import torch
from huggingface_hub.hf_api import list_models  # to enumerate hub models
from torch import nn
from tqdm import tqdm

from transformers import MarianConfig, MarianMTModel, MarianTokenizer


def remove_suffix(text: str, suffix: str):
    # Strip suffix from text if present
    if text.endswith(suffix):
        return text[: -len(suffix)]
    return text  # unchanged if the suffix does not match


def remove_prefix(text: str, prefix: str):
    # Strip prefix from text if present
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text  # unchanged if the prefix does not match
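
For example:

```python
>>> remove_prefix("Helsinki-NLP/opus-mt-en-de", "Helsinki-NLP/")
'opus-mt-en-de'
>>> remove_suffix("opus-mt-en-de.npz", ".npz")
'opus-mt-en-de'
```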


def convert_encoder_layer(opus_dict, layer_prefix: str, converter: dict):
    # Convert the keys under layer_prefix in the OPUS dict into a PyTorch state dict
    sd = {}
    for k in opus_dict:
        if not k.startswith(layer_prefix):
            continue
        stripped = remove_prefix(k, layer_prefix)
        v = opus_dict[k].T  # besides the embeddings, everything must be transposed
        sd[converter[stripped]] = torch.tensor(v).squeeze()
    return sd


def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decoder=False):
    # Load the encoder or decoder layers from the OPUS state dict into the module list
    for i, layer in enumerate(layer_lst):
        layer_tag = f"decoder_l{i + 1}_" if is_decoder else f"encoder_l{i + 1}_"
        sd = convert_encoder_layer(opus_state, layer_tag, converter)
        layer.load_state_dict(sd, strict=False)


def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
    """Find models that accept the given source language and emit the given target language."""
    prefix = "Helsinki-NLP/opus-mt-"
    model_list = list_models()  # enumerate models on the hub
    model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
    src_and_targ = [
        remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
    ]  # skip models whose names contain "+"
    matching = [f"{prefix}{a}-{b}" for (a, b) in src_and_targ if src_lang in a and tgt_lang in b]
    return matching


def add_emb_entries(wemb, final_bias, n_special_tokens=1):
    # Append zero embeddings (and matching bias entries) for the special tokens
    vsize, d_model = wemb.shape
    embs_to_add = np.zeros((n_special_tokens, d_model))
    new_embs = np.concatenate([wemb, embs_to_add])
    bias_to_add = np.zeros((n_special_tokens, 1))
    new_bias = np.concatenate((final_bias, bias_to_add), axis=1)
    return new_embs, new_bias


def _cast_yaml_str(v):
    bool_dct = {"true": True, "false": False}
    # Non-strings pass through unchanged
    if not isinstance(v, str):
        return v
    # Map "true"/"false" strings to booleans
    elif v in bool_dct:
        return bool_dct[v]
    # Try to parse the string as an integer ...
    try:
        return int(v)
    # ... otherwise return it unchanged
    except (TypeError, ValueError):
        return v
# Cast every value of a raw Marian config dict to its native type
def cast_marian_config(raw_cfg: Dict[str, str]) -> Dict:
    return {k: _cast_yaml_str(v) for k, v in raw_cfg.items()}

# Key under which Marian stores its config in the state dict
CONFIG_KEY = "special:model.yml"

# Load and cast the Marian config embedded in an OPUS state dict
def load_config_from_state_dict(opus_dict):
    import yaml

    # The config is stored as an array of character codes; join it back into a string
    cfg_str = "".join([chr(x) for x in opus_dict[CONFIG_KEY]])
    # Parse the YAML (dropping the trailing byte) with the BaseLoader
    yaml_cfg = yaml.load(cfg_str[:-1], Loader=yaml.BaseLoader)
    # Cast the string values to native types
    return cast_marian_config(yaml_cfg)
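
Because Marian serializes its config as raw YAML, everything arrives as text; `cast_marian_config` recovers the native types (the values here are illustrative):

```python
raw = {"dim-emb": "512", "tied-embeddings-all": "true", "type": "transformer"}
print(cast_marian_config(raw))
# {'dim-emb': 512, 'tied-embeddings-all': True, 'type': 'transformer'}
```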

# Locate the model file in dest_dir; fail unless there is exactly one .npz file
def find_model_file(dest_dir):  # this one better
    model_files = list(Path(dest_dir).glob("*.npz"))
    if len(model_files) != 1:
        raise ValueError(f"Found more than one model file: {model_files}")
    model_file = model_files[0]
    return model_file

# Languages in the ROMANCE group
ROM_GROUP = (
    "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
    "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
    "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
)

# Language groups: (member string, group name) pairs
GROUPS = [
    ("cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh", "ZH"),
    (ROM_GROUP, "ROMANCE"),
    ("de+nl+fy+af+da+fo+is+no+nb+nn+sv", "NORTH_EU"),
    ("da+fo+is+no+nb+nn+sv", "SCANDINAVIA"),
    ("se+sma+smj+smn+sms", "SAMI"),
    ("nb_NO+nb+nn_NO+nn+nog+no_nb+no", "NORWAY"),
    ("ga+cy+br+gd+kw+gv", "CELTIC"),  # https://en.wikipedia.org/wiki/Insular_Celtic_languages
]

# Map group model names to OPUS model names
GROUP_TO_OPUS_NAME = {
    "opus-mt-ZH-de": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-de",
    "opus-mt-ZH-fi": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-fi",
    "opus-mt-ZH-sv": "cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh-sv",
    "opus-mt-SCANDINAVIA-SCANDINAVIA": "da+fo+is+no+nb+nn+sv-da+fo+is+no+nb+nn+sv",
    "opus-mt-NORTH_EU-NORTH_EU": "de+nl+fy+af+da+fo+is+no+nb+nn+sv-de+nl+fy+af+da+fo+is+no+nb+nn+sv",
    "opus-mt-de-ZH": "de-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-en_el_es_fi-en_el_es_fi": "en+el+es+fi-en+el+es+fi",
    "opus-mt-en-ROMANCE": (
        "en-fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
        "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
        "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la"
    ),
    "opus-mt-en-CELTIC": "en-ga+cy+br+gd+kw+gv",
    "opus-mt-es-NORWAY": "es-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    "opus-mt-fi_nb_no_nn_ru_sv_en-SAMI": "fi+nb+no+nn+ru+sv+en-se+sma+smj+smn+sms",
    "opus-mt-fi-ZH": "fi-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    "opus-mt-fi-NORWAY": "fi-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    "opus-mt-ROMANCE-en": (
        "fr+fr_BE+fr_CA+fr_FR+wa+frp+oc+ca+rm+lld+fur+lij+lmo+es+es_AR+es_CL+es_CO+es_CR+es_DO"
        "+es_EC+es_ES+es_GT+es_HN+es_MX+es_NI+es_PA+es_PE+es_PR+es_SV+es_UY+es_VE+pt+pt_br+pt_BR"
        "+pt_PT+gl+lad+an+mwl+it+it_IT+co+nap+scn+vec+sc+ro+la-en"
    ),
    "opus-mt-CELTIC-en": "ga+cy+br+gd+kw+gv-en",
    # 为键 "opus-mt-CELTIC-en" 添加值 "ga+cy+br+gd+kw+gv-en"
    "opus-mt-sv-ZH": "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh",
    # 为键 "opus-mt-sv-ZH" 添加值 "sv-cmn+cn+yue+ze_zh+zh_cn+zh_CN+zh_HK+zh_tw+zh_TW+zh_yue+zhs+zht+zh"
    "opus-mt-sv-NORWAY": "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no",
    # 为键 "opus-mt-sv-NORWAY" 添加值 "sv-nb_NO+nb+nn_NO+nn+nog+no_nb+no"
}
# URL of the OPUS-MT-train models directory on GitHub
OPUS_GITHUB_URL = "https://github.com/Helsinki-NLP/OPUS-MT-train/blob/master/models/"
# Organization name
ORG_NAME = "Helsinki-NLP/"


def convert_opus_name_to_hf_name(x):
    """将 OPUS-MT-Train 名称转换为 Hugging Face 模型名称(已弃用)"""
    # 根据 GROUPS 中的替换规则,将 x 中的子字符串替换为对应的组名称
    for substr, grp_name in GROUPS:
        x = x.replace(substr, grp_name)
    return x.replace("+", "_")


def convert_hf_name_to_opus_name(hf_model_name):
    """
    根据假设,假设在不在 GROUP_TO_OPUS_NAME 中的模型中没有像 pt_br 这样的语言代码。
    将 Hugging Face 模型名称转换为 OPUS-MT-Train 名称
    """
    # 去除模型名称中的 ORG_NAME 前缀
    hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
    if hf_model_name in GROUP_TO_OPUS_NAME:
        opus_w_prefix = GROUP_TO_OPUS_NAME[hf_model_name]
    else:
        opus_w_prefix = hf_model_name.replace("_", "+")
    return remove_prefix(opus_w_prefix, "opus-mt-")
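
Both directions at a glance (the group strings come from GROUPS and GROUP_TO_OPUS_NAME):

```python
>>> convert_opus_name_to_hf_name("de+nl+fy+af+da+fo+is+no+nb+nn+sv-de")
'NORTH_EU-de'
>>> convert_hf_name_to_opus_name("Helsinki-NLP/opus-mt-en-CELTIC")
'en-ga+cy+br+gd+kw+gv'
```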


def get_system_metadata(repo_root):
    import git

    # System metadata: Helsinki git SHA, transformers git SHA, hostname, and timestamp
    return {
        "helsinki_git_sha": git.Repo(path=repo_root, search_parent_directories=True).head.object.hexsha,
        "transformers_git_sha": git.Repo(path=".", search_parent_directories=True).head.object.hexsha,
        "port_machine": socket.gethostname(),
        "port_time": time.strftime("%Y-%m-%d-%H:%M"),
    }


# docstyle-ignore
# Front-matter template for generated model cards
FRONT_MATTER_TEMPLATE = """---
language:
{}
tags:
- translation

license: apache-2.0
---
"""
# Default repository name
DEFAULT_REPO = "Tatoeba-Challenge"
# Default model directory path
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")


def write_model_card(
    hf_model_name: str,
    repo_root=DEFAULT_REPO,
    save_dir=Path("marian_converted"),
    dry_run=False,
    extra_metadata={},
) -> str:
    """
    复制最新模型的 readme 部分来自 OPUS,并添加元数据。上传命令: aws s3 sync model_card_dir
    s3://models.huggingface.co/bert/Helsinki-NLP/ --dryrun
    """
    import pandas as pd

    # 去除模型名称中的 ORG_NAME 前缀
    hf_model_name = remove_prefix(hf_model_name, ORG_NAME)
    # 将 Hugging Face 模型名称转换为 OPUS-MT-Train 名称
    opus_name: str = convert_hf_name_to_opus_name(hf_model_name)
    if repo_root not in ("OPUS-MT-train", "Tatoeba-Challenge"):
        raise ValueError(f"Repos root is {repo_root}. Expected either OPUS-MT-train or Tatoeba-Challenge")
    # 构建 OPUS readme 文件路径
    opus_readme_path = Path(repo_root).joinpath("models", opus_name, "README.md")
    if not (opus_readme_path.exists()):
        raise ValueError(f"Readme file {opus_readme_path} not found")

    # 分离 OPUS 名称中的源语言和目标语言
    opus_src, opus_tgt = [x.split("+") for x in opus_name.split("-")]

    # 构建 OPUS README 在 GitHub 上的 URL
    readme_url = f"https://github.com/Helsinki-NLP/{repo_root}/tree/master/models/{opus_name}/README.md"

    s, t = ",".join(opus_src), ",".join(opus_tgt)
    # 构建元数据字典
    metadata = {
        "hf_name": hf_model_name,
        "source_languages": s,
        "target_languages": t,
        "opus_readme_url": readme_url,
        "original_repo": repo_root,
        "tags": ["translation"],
    }
    metadata.update(extra_metadata)
    # 添加系统元数据到元数据字典中
    metadata.update(get_system_metadata(repo_root))

    # Assemble the extra markdown: model name, source group, target group, and the OPUS readme link
    # (src_name / tgt_name / src_alpha2 are expected to be supplied via extra_metadata)
    extra_markdown = (
        f"### {hf_model_name}\n\n* source group: {metadata['src_name']} \n* target group: "
        f"{metadata['tgt_name']} \n*  OPUS readme: [{opus_name}]({readme_url})\n"
    )

    # Read the OPUS readme and keep only the lowest level-1 header, i.e. the most recent model
    content = opus_readme_path.open().read()
    content = content.split("\n# ")[-1]  # Get the lowest level 1 header in the README -- the most recent model.

    # Re-split on "*", dropping the first two chunks; the print is a leftover debug statement
    splat = content.split("*")[2:]
    print(splat[3])
    content = "*".join(splat)

    # Final content: front matter, the extra markdown, then the processed readme body
    content = (
        FRONT_MATTER_TEMPLATE.format(metadata["src_alpha2"])
        + extra_markdown
        + "\n* "
        + content.replace("download", "download original weights")
    )

    # Append the metadata key/value pairs as a "System Info" section
    items = "\n\n".join([f"- {k}: {v}" for k, v in metadata.items()])
    sec3 = "\n### System Info: \n" + items
    content += sec3

    # In dry-run mode, return without writing anything to disk
    if dry_run:
        return content, metadata

    # Write README.md and metadata.json into a per-model sub-directory
    sub_dir = save_dir / f"opus-mt-{hf_model_name}"
    sub_dir.mkdir(exist_ok=True)
    dest = sub_dir / "README.md"
    dest.open("w").write(content)
    pd.Series(metadata).to_json(sub_dir / "metadata.json")

    return content, metadata


# 创建注册表函数,用于处理特定路径下的模型注册
def make_registry(repo_path="Opus-MT-train/models"):
    # 检查指定路径下的 README.md 文件是否存在,如果不存在则抛出数值错误
    if not (Path(repo_path) / "fr-en" / "README.md").exists():
        raise ValueError(
            f"repo_path:{repo_path} does not exist: "
            "You must run: git clone git@github.com:Helsinki-NLP/Opus-MT-train.git before calling."
        )
    # 初始化结果字典
    results = {}
    # 遍历指定路径下的所有子目录和文件
    for p in Path(repo_path).iterdir():
        # 统计当前路径名称中 "-" 的数量
        n_dash = p.name.count("-")
        # 如果没有 "-",则跳过当前路径
        if n_dash == 0:
            continue
        else:
            # 读取当前路径下的 README.md 文件的所有行
            lns = list(open(p / "README.md").readlines())
            # 使用解析函数处理 README.md 的内容,并存入结果字典
            results[p.name] = _parse_readme(lns)
    # 返回结果列表,包含每个模型的关键信息
    return [(k, v["pre-processing"], v["download"], v["download"][:-4] + ".test.txt") for k, v in results.items()]


# 批量转换所有 SentencePiece 模型
def convert_all_sentencepiece_models(model_list=None, repo_path=None, dest_dir=Path("marian_converted")):
    """Requires 300GB"""
    # 设置保存目录和目标目录
    save_dir = Path("marian_ckpt")
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(exist_ok=True)
    # 初始化保存路径列表
    save_paths = []
    # 如果未指定模型列表,则调用 make_registry 函数获取模型列表
    if model_list is None:
        model_list: list = make_registry(repo_path=repo_path)
    # 遍历模型列表
    for k, prepro, download, test_set_url in tqdm(model_list):
        # 如果预处理中不包含 "SentencePiece",则跳过当前模型
        if "SentencePiece" not in prepro:  # dont convert BPE models.
            continue
        # 如果保存目录中不存在当前模型的文件夹,则下载并解压缩模型文件
        if not os.path.exists(save_dir / k):
            download_and_unzip(download, save_dir / k)
        # 将 Opus 模型名转换为 Hugging Face 模型名
        pair_name = convert_opus_name_to_hf_name(k)
        # 执行模型转换操作
        convert(save_dir / k, dest_dir / f"opus-mt-{pair_name}")

        # 将转换后的模型保存路径加入保存路径列表
        save_paths.append(dest_dir / f"opus-mt-{pair_name}")
    # 返回所有转换后模型的保存路径列表
    return save_paths


# 自定义列表映射函数,对输入列表中的每个元素应用给定函数
def lmap(f, x) -> List:
    return list(map(f, x))


# 下载测试集并返回源语言、金标准和模型输出列表
def fetch_test_set(test_set_url):
    import wget

    # 使用 wget 下载测试集文件到本地
    fname = wget.download(test_set_url, "opus_test.txt")
    # 读取下载的文件的所有行
    lns = Path(fname).open().readlines()
    # 提取源语言、金标准和模型输出的列表,并进行字符串修剪
    src = lmap(str.strip, lns[::4])
    gold = lmap(str.strip, lns[1::4])
    mar_model = lmap(str.strip, lns[2::4])
    # 检查三个列表的长度是否相等,如果不相等则抛出数值错误
    if not (len(gold) == len(mar_model) == len(src)):
        raise ValueError(f"Gold, marian and source lengths {len(gold)}, {len(mar_model)}, {len(src)} mismatched")
    # 删除下载的测试集文件
    os.remove(fname)
    # 返回源语言列表、模型输出列表和金标准列表
    return src, mar_model, gold
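
# Sketch (made-up lines) of the 4-line record layout fetch_test_set assumes;
# the stride-4 slices pick one field from each record:
#
#     lns = ["src1", "gold1", "hyp1", "", "src2", "gold2", "hyp2", ""]
#     lns[::4]   # -> ["src1", "src2"]    source sentences
#     lns[1::4]  # -> ["gold1", "gold2"]  reference translations
#     lns[2::4]  # -> ["hyp1", "hyp2"]    Marian outputs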


# Convert every checkpoint directory found under the given path
def convert_whole_dir(path=Path("marian_ckpt/")):
    # Path has no .ls() method; iterdir() enumerates the checkpoint sub-directories
    for subdir in tqdm(list(path.iterdir())):
        dest_dir = Path(f"marian_converted/{subdir.name}")
        # Skip sub-directories that already contain a converted model
        if (dest_dir / "pytorch_model.bin").exists():
            continue
        convert(subdir, dest_dir)


# 解析 README.md 文件内容,获取 Opus 模型的链接和元数据
def _parse_readme(lns):
    """Get link and metadata from opus model card equivalent."""
    # 初始化子结果字典
    subres = {}
    # 遍历所有行
    for ln in [x.strip() for x in lns]:
        # 如果行不以 "*" 开头,则跳过当前行
        if not ln.startswith("*"):
            continue
        # 去掉首部的 "*" 符号
        ln = ln[1:].strip()

        # 遍历关键词列表,识别关键词并提取对应的值
        for k in ["download", "dataset", "models", "model", "pre-processing"]:
            if ln.startswith(k):
                break
        else:
            continue
        # 根据关键词类型处理对应的值
        if k in ["dataset", "model", "pre-processing"]:
            splat = ln.split(":")
            _, v = splat
            subres[k] = v
        elif k == "download":
            v = ln.split("(")[-1][:-1]
            subres[k] = v
    # 返回子结果字典,包含从 README.md 中提取的所有信息
    return subres
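
# Illustrative sketch (made-up README lines) of what _parse_readme extracts:
#
#     lns = [
#         "* dataset: opus",
#         "* pre-processing: normalization + SentencePiece",
#         "* download: [opus.zip](https://example.com/fr-en/opus.zip)",
#     ]
#     _parse_readme(lns)
#     # -> {"dataset": " opus",
#     #     "pre-processing": " normalization + SentencePiece",
#     #     "download": "https://example.com/fr-en/opus.zip"}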


# 保存分词器配置到指定目录
def save_tokenizer_config(dest_dir: Path, separate_vocabs=False):
    # 将目标目录的名称按照 "-" 分割成列表
    dname = dest_dir.name.split("-")
    # 构建包含目标语言、源语言和是否分开词汇表的字典
    dct = {"target_lang": dname[-1], "source_lang": "-".join(dname[:-1]), "separate_vocabs": separate_vocabs}
    # 将字典保存为 JSON 文件,文件名为 "tokenizer_config.json",保存在目标目录中
    save_json(dct, dest_dir / "tokenizer_config.json")
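
# Illustrative sketch: for a checkpoint directory named "fr-en" (the usual layout;
# an assumption here), dname == ["fr", "en"], so the file written is
#
#     {"target_lang": "en", "source_lang": "fr", "separate_vocabs": false}
#
# For multi-part names, everything before the last "-" becomes source_lang.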
# 向词汇表中添加特殊标记,如果需要分开处理词汇表,则加载源和目标词汇表并分别处理
def add_special_tokens_to_vocab(model_dir: Path, separate_vocab=False) -> None:
    if separate_vocab:
        # 加载源语言词汇表并转换为整数键值对
        vocab = load_yaml(find_src_vocab_file(model_dir))
        vocab = {k: int(v) for k, v in vocab.items()}
        # 向词汇表中添加特殊标记"<pad>",返回添加的标记数目
        num_added = add_to_vocab_(vocab, ["<pad>"])
        # 将更新后的词汇表保存为 JSON 文件
        save_json(vocab, model_dir / "vocab.json")

        # 加载目标语言词汇表并转换为整数键值对
        vocab = load_yaml(find_tgt_vocab_file(model_dir))
        vocab = {k: int(v) for k, v in vocab.items()}
        # 向词汇表中添加特殊标记"<pad>",返回添加的标记数目
        num_added = add_to_vocab_(vocab, ["<pad>"])
        # 将更新后的目标语言词汇表保存为 JSON 文件
        save_json(vocab, model_dir / "target_vocab.json")
        # 保存分词器配置
        save_tokenizer_config(model_dir, separate_vocabs=separate_vocab)
    else:
        # 加载统一词汇表并转换为整数键值对
        vocab = load_yaml(find_vocab_file(model_dir))
        vocab = {k: int(v) for k, v in vocab.items()}
        # 向词汇表中添加特殊标记"<pad>",返回添加的标记数目
        num_added = add_to_vocab_(vocab, ["<pad>"])
        # 打印添加的标记数目
        print(f"added {num_added} tokens to vocab")
        # 将更新后的词汇表保存为 JSON 文件
        save_json(vocab, model_dir / "vocab.json")
        # 保存分词器配置
        save_tokenizer_config(model_dir)



# 检查两个键对应的值是否相等,若不相等则抛出 ValueError 异常
def check_equal(marian_cfg, k1, k2):
    v1, v2 = marian_cfg[k1], marian_cfg[k2]
    if v1 != v2:
        raise ValueError(f"hparams {k1},{k2} differ: {v1} != {v2}")



# 检索指定目录下的第一个以 "*vocab.yml" 结尾的文件并返回其路径
def find_vocab_file(model_dir):
    return list(model_dir.glob("*vocab.yml"))[0]



# 检索指定目录下的第一个以 "*src.vocab.yml" 结尾的文件并返回其路径
def find_src_vocab_file(model_dir):
    return list(model_dir.glob("*src.vocab.yml"))[0]



# 检索指定目录下的第一个以 "*trg.vocab.yml" 结尾的文件并返回其路径
def find_tgt_vocab_file(model_dir):
    return list(model_dir.glob("*trg.vocab.yml"))[0]



# 向词汇表中添加特殊标记,根据词汇表中最大的值确定起始位置
def add_to_vocab_(vocab: Dict[str, int], special_tokens: List[str]):
    start = max(vocab.values()) + 1  # 确定新添加标记的起始位置
    added = 0  # 初始化添加的标记数目
    for tok in special_tokens:
        if tok in vocab:
            continue
        vocab[tok] = start + added  # 将特殊标记添加到词汇表中
        added += 1  # 更新添加的标记数目
    return added  # 返回添加的标记数目
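
# Illustrative sketch of add_to_vocab_ (ids continue after the current maximum):
#
#     vocab = {"▁hello": 0, "▁world": 1}
#     add_to_vocab_(vocab, ["<pad>"])   # returns 1
#     # vocab is now {"▁hello": 0, "▁world": 1, "<pad>": 2}
#
# Tokens that already exist are skipped, so a second call adds nothing.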



# 检查 marian_cfg 中指定的配置项是否符合预期设置
def check_marian_cfg_assumptions(marian_cfg):
    assumed_settings = {
        "layer-normalization": False,
        "right-left": False,
        "transformer-ffn-depth": 2,
        "transformer-aan-depth": 2,
        "transformer-no-projection": False,
        "transformer-postprocess-emb": "d",
        "transformer-postprocess": "dan",  # Dropout, add, normalize
        "transformer-preprocess": "",
        "type": "transformer",
        "ulr-dim-emb": 0,
        "dec-cell-base-depth": 2,
        "dec-cell-high-depth": 1,
        "transformer-aan-nogate": False,
    }
    for k, v in assumed_settings.items():
        actual = marian_cfg[k]
        if actual != v:
            raise ValueError(f"Unexpected config value for {k} expected {v} got {actual}")



# BART 模型的配置映射,将不同的层权重映射到对应的键
BIAS_KEY = "decoder_ff_logit_out_b"
BART_CONVERTER = {  # 用于每个编码器和解码器层
    "self_Wq": "self_attn.q_proj.weight",
    "self_Wk": "self_attn.k_proj.weight",
    "self_Wv": "self_attn.v_proj.weight",
    "self_Wo": "self_attn.out_proj.weight",
    "self_bq": "self_attn.q_proj.bias",
    "self_bk": "self_attn.k_proj.bias",
    "self_bv": "self_attn.v_proj.bias",
    "self_bo": "self_attn.out_proj.bias",
    "self_Wo_ln_scale": "self_attn_layer_norm.weight",
    "self_Wo_ln_bias": "self_attn_layer_norm.bias",
    "ffn_W1": "fc1.weight",
    "ffn_b1": "fc1.bias",
    # Weights and biases of the second fully connected layer
    "ffn_W2": "fc2.weight",
    "ffn_b2": "fc2.bias",

    # Scale and bias of the final layer norm
    "ffn_ffn_ln_scale": "final_layer_norm.weight",
    "ffn_ffn_ln_bias": "final_layer_norm.bias",

    # Weights and biases of the decoder cross-attention
    "context_Wk": "encoder_attn.k_proj.weight",
    "context_Wo": "encoder_attn.out_proj.weight",
    "context_Wq": "encoder_attn.q_proj.weight",
    "context_Wv": "encoder_attn.v_proj.weight",
    "context_bk": "encoder_attn.k_proj.bias",
    "context_bo": "encoder_attn.out_proj.bias",
    "context_bq": "encoder_attn.q_proj.bias",
    "context_bv": "encoder_attn.v_proj.bias",

    # Scale and bias of the cross-attention layer norm
    "context_Wo_ln_scale": "encoder_attn_layer_norm.weight",
    "context_Wo_ln_bias": "encoder_attn_layer_norm.bias",
}

class OpusState:
    # 检查层条目的有效性,初始化编码器和解码器的第一层的键列表
    def _check_layer_entries(self):
        self.encoder_l1 = self.sub_keys("encoder_l1")  # 获取编码器第一层的键列表
        self.decoder_l1 = self.sub_keys("decoder_l1")  # 获取解码器第一层的键列表
        self.decoder_l2 = self.sub_keys("decoder_l2")  # 获取解码器第二层的键列表
        # 检查编码器第一层键的数量是否为16,如果不是则发出警告
        if len(self.encoder_l1) != 16:
            warnings.warn(f"Expected 16 keys for each encoder layer, got {len(self.encoder_l1)}")
        # 检查解码器第一层键的数量是否为26,如果不是则发出警告
        if len(self.decoder_l1) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l1)}")
        # Check that the second decoder layer also has 26 keys; warn otherwise
        if len(self.decoder_l2) != 26:
            warnings.warn(f"Expected 26 keys for each decoder layer, got {len(self.decoder_l2)}")

    @property
    # 获取额外的键列表,排除特定的键
    def extra_keys(self):
        extra = []
        # 遍历状态键列表,排除特定的键,生成额外的键列表
        for k in self.state_keys:
            if (
                k.startswith("encoder_l")
                or k.startswith("decoder_l")
                or k in [CONFIG_KEY, "Wemb", "encoder_Wemb", "decoder_Wemb", "Wpos", "decoder_ff_logit_out_b"]
            ):
                continue
            else:
                extra.append(k)
        return extra

    # 获取给定层前缀的子键列表
    def sub_keys(self, layer_prefix):
        return [remove_prefix(k, layer_prefix) for k in self.state_dict if k.startswith(layer_prefix)]

    # 加载分词器,根据源目录加载Marian分词器
    def load_tokenizer(self):
        add_special_tokens_to_vocab(self.source_dir, not self.share_encoder_decoder_embeddings)  # 将特殊标记添加到词汇表中
        return MarianTokenizer.from_pretrained(str(self.source_dir))  # 返回从预训练模型加载的Marian分词器
    # 加载 MarianMTModel 模型的方法,返回一个 MarianMTModel 对象
    def load_marian_model(self) -> MarianMTModel:
        # 获取状态字典和 HF 配置
        state_dict, cfg = self.state_dict, self.hf_config

        # 如果配置中 static_position_embeddings 不为 True,则抛出数值错误异常
        if not cfg.static_position_embeddings:
            raise ValueError("config.static_position_embeddings should be True")

        # 根据配置创建 MarianMTModel 模型对象
        model = MarianMTModel(cfg)

        # 如果配置中包含 "hidden_size" 键,抛出数值错误异常
        if "hidden_size" in cfg.to_dict():
            raise ValueError("hidden_size is in config")

        # 加载编码器层的状态字典到模型中,使用 BART_CONVERTER 转换
        load_layers_(
            model.model.encoder.layers,
            state_dict,
            BART_CONVERTER,
        )

        # 加载解码器层的状态字典到模型中,使用 BART_CONVERTER 转换,并指定为解码器层
        load_layers_(model.model.decoder.layers, state_dict, BART_CONVERTER, is_decoder=True)

        # 处理与层无关的张量
        if self.cfg["tied-embeddings-src"]:
            # 如果源语言嵌入被绑定,创建源语言嵌入张量和偏置张量,并分配给模型共享的权重
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            model.model.shared.weight = wemb_tensor
            model.model.encoder.embed_tokens = model.model.decoder.embed_tokens = model.model.shared
        else:
            # 如果未绑定源语言嵌入,创建源语言嵌入张量,并分配给编码器的嵌入权重
            wemb_tensor = nn.Parameter(torch.FloatTensor(self.wemb))
            model.model.encoder.embed_tokens.weight = wemb_tensor

            # 创建解码器嵌入张量、偏置张量,并分配给解码器的嵌入权重和最终偏置
            decoder_wemb_tensor = nn.Parameter(torch.FloatTensor(self.dec_wemb))
            bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
            model.model.decoder.embed_tokens.weight = decoder_wemb_tensor

        # 将最终偏置张量分配给模型的最终对数偏置
        model.final_logits_bias = bias_tensor

        # 如果状态字典中存在 "Wpos" 键,打印警告信息
        if "Wpos" in state_dict:
            print("Unexpected: got Wpos")
            # 创建 Wpos 张量并分配给编码器和解码器的位置嵌入权重
            wpos_tensor = torch.tensor(state_dict["Wpos"])
            model.model.encoder.embed_positions.weight = wpos_tensor
            model.model.decoder.embed_positions.weight = wpos_tensor

        # 如果配置中启用了嵌入归一化
        if cfg.normalize_embedding:
            # 如果状态字典中缺少 "encoder_emb_ln_scale_pre" 键,抛出数值错误异常
            if "encoder_emb_ln_scale_pre" not in state_dict:
                raise ValueError("encoder_emb_ln_scale_pre is not in state dictionary")
            # 抛出未实现错误,需要转换 layernorm_embedding
            raise NotImplementedError("Need to convert layernorm_embedding")

        # 如果存在额外的键,抛出数值错误异常
        if self.extra_keys:
            raise ValueError(f"Failed to convert {self.extra_keys}")

        # 如果模型的输入嵌入的填充索引与 self.pad_token_id 不匹配,抛出数值错误异常
        if model.get_input_embeddings().padding_idx != self.pad_token_id:
            raise ValueError(
                f"Padding tokens {model.get_input_embeddings().padding_idx} and {self.pad_token_id} mismatched"
            )

        # 返回加载完成的 MarianMTModel 模型对象
        return model
    """
    Tatoeba conversion instructions in scripts/tatoeba/README.md
    """
    # 导入 argparse 模块,用于处理命令行参数
    parser = argparse.ArgumentParser()
    # 必需参数
    parser.add_argument("--src", type=str, help="path to marian model sub dir", default="en-de")
    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model.")
    # 解析命令行参数
    args = parser.parse_args()

    # 将源目录路径转换为 Path 对象
    source_dir = Path(args.src)
    # 如果源目录不存在,则抛出 ValueError 异常
    if not source_dir.exists():
        raise ValueError(f"Source directory {source_dir} not found")
    # 将目标目录转换为路径字符串,默认情况下是在源目录名前加上 'converted-'
    dest_dir = f"converted-{source_dir.name}" if args.dest is None else args.dest
    # 调用 convert 函数,进行模型转换
    convert(source_dir, dest_dir)

.\models\marian\modeling_flax_marian.py

# coding=utf-8
# 版权所有 2021 年 The Marian Team 作者和 Google Flax Team 作者以及 HuggingFace Inc. 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证的要求,否则您不能使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则不得以“原样”分发软件,
# 没有任何明示或暗示的保证或条件。请查阅许可证了解具体语言。
""" Flax Marian model."""

import math
import random
from functools import partial
from typing import Callable, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from jax.random import PRNGKey

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_marian import MarianConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
_CONFIG_FOR_DOC = "MarianConfig"


MARIAN_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
"""
    Parameters:
        config ([`MarianConfig`]): Model configuration class with all the parameters of the model.
            初始化模型配置类,包含所有模型参数。
            通过配置文件初始化不会加载与模型相关的权重,仅加载配置。
            参考 [`~FlaxPreTrainedModel.from_pretrained`] 方法以加载模型权重。

        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            计算的数据类型。可以是 `jax.numpy.float32`、`jax.numpy.float16`(在GPU上)和 `jax.numpy.bfloat16`(在TPU上)之一。

            可用于在GPU或TPU上启用混合精度训练或半精度推断。
            如果指定,所有计算将使用给定的 `dtype` 执行。

            **注意,这仅指定计算的数据类型,不影响模型参数的数据类型。**

            如果要更改模型参数的数据类型,请参阅 [`~FlaxPreTrainedModel.to_fp16`] 和 [`~FlaxPreTrainedModel.to_bf16`]。
"""

MARIAN_INPUTS_DOCSTRING = r"""
"""


MARIAN_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            输入序列标记的索引,用于词汇表中的标记。默认情况下会忽略填充。

            可以使用 [`AutoTokenizer`] 获得这些索引。参见 [`PreTrainedTokenizer.encode`] 和 [`PreTrainedTokenizer.__call__`] 了解详情。

            [什么是输入 ID?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *可选*):
            遮罩,用于避免在填充的标记索引上进行注意力计算。遮罩值选在 `[0, 1]` 范围内:

            - 1 表示**不遮罩**的标记,
            - 0 表示**遮罩**的标记。

            [什么是注意力遮罩?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *可选*):
            输入序列标记在位置嵌入中的位置索引。选择范围是 `[0, config.max_position_embeddings - 1]`。
        output_attentions (`bool`, *可选*):
            是否返回所有注意力层的注意力张量。更多细节请参见返回的张量中的 `attentions` 字段。
        output_hidden_states (`bool`, *可选*):
            是否返回所有层的隐藏状态。更多细节请参见返回的张量中的 `hidden_states` 字段。
        return_dict (`bool`, *可选*):
            是否返回一个 [`~utils.ModelOutput`] 而不是简单的元组。
"""

MARIAN_DECODE_INPUTS_DOCSTRING = r"""
"""


def create_sinusoidal_positions(n_pos, dim):
    """
    Create sinusoidal position encodings.

    Args:
        n_pos (int): number of positions.
        dim (int): encoding dimension.

    Returns:
        jnp.ndarray: array of sinusoidal position encodings.
    """
    position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
    sentinel = dim // 2 + dim % 2
    out = np.zeros_like(position_enc)
    out[:, 0:sentinel] = np.sin(position_enc[:, 0::2])
    out[:, sentinel:] = np.cos(position_enc[:, 1::2])

    return jnp.array(out)
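
# Illustrative sketch: the table has shape (n_pos, dim); sine terms fill the first
# half of each row, cosine terms the second half.
#
#     table = create_sinusoidal_positions(n_pos=4, dim=6)
#     table.shape  # -> (4, 6)
#     table[0]     # -> [0., 0., 0., 1., 1., 1.] (sin(0)=0, cos(0)=1)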


# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
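
# Illustrative sketch of shift_tokens_right with pad_token_id=0 and
# decoder_start_token_id=2; -100 label positions become the pad token:
#
#     ids = jnp.array([[5, -100, 7, 8]])
#     shift_tokens_right(ids, pad_token_id=0, decoder_start_token_id=2)
#     # -> [[2, 5, 0, 7]]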


# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->Marian
class FlaxMarianAttention(nn.Module):
    """
    Marian 模型的注意力机制模块。
    """
    config: MarianConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型,默认为 jnp.float32

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads  # 计算每个头部的维度
        if self.head_dim * self.num_heads != self.embed_dim:  # 检查 embed_dim 是否能被 num_heads 整除
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # 创建一个部分应用了 nn.Dense 的函数,用于创建全连接层
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # 初始化查询、键、值、输出的全连接层
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        # 初始化 dropout 层
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            # 如果需要因果注意力,创建因果 mask
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        # 将隐藏状态张量按照头部数目和头部维度进行重塑
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # 将分离的头部重新合并成原来的形状
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # 检测是否初始化缓存数据
        is_initialized = self.has_variable("cache", "cached_key")
        # 获取缓存的键(key)并初始化为零数组,其形状和数据类型与输入的键(key)相同
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        # 获取缓存的值(value)并初始化为零数组,其形状和数据类型与输入的值(value)相同
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # 获取缓存索引(index),如果不存在则初始化为零
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # 获取批次维度的数量和最大长度、注意力头数、每个头部的深度
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # 更新键(key)和值(value)缓存,使用新的一维空间切片
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # 更新缓存中的键(key)和值(value)
            cached_key.value = key
            cached_value.value = value
            # 更新缓存索引,增加已更新的缓存向量数
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # 对于缓存的解码器自注意力,创建因果掩码:我们的单个查询位置只能关注已生成和缓存的键位置,而不是剩余的零元素
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # 合并因果掩码和输入的注意力掩码
            attention_mask = combine_masks(pad_mask, attention_mask)
        # 返回更新后的键(key)、值(value)和注意力掩码
        return key, value, attention_mask
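
# Shape sketch (illustrative) for _split_heads/_merge_heads with embed_dim=8, num_heads=2:
#
#     x = jnp.zeros((1, 3, 8))                        # (batch, seq, embed_dim)
#     split = x.reshape(x.shape[:2] + (2, 4))         # -> (1, 3, 2, 4): 2 heads of dim 4
#     merged = split.reshape(split.shape[:2] + (8,))  # -> (1, 3, 8)
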
# 从 transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayer 复制代码并将 Bart->Marian 替换
class FlaxMarianEncoderLayer(nn.Module):
    # Marian 模型配置
    config: MarianConfig
    # 计算的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 设置层的初始化操作
    def setup(self) -> None:
        # 设置嵌入维度为模型配置中的 d_model
        self.embed_dim = self.config.d_model
        # 定义自注意力层
        self.self_attn = FlaxMarianAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # 自注意力层后的 Layer Normalization
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # Dropout 层
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # 激活函数
        self.activation_fn = ACT2FN[self.config.activation_function]
        # 激活函数后的 Dropout 层
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        # 第一个全连接层,使用 jax 的正态分布初始化
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # 第二个全连接层,输出维度为 embed_dim,同样使用正态分布初始化
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # 最终的 Layer Normalization
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # 实现类的调用方法,对输入的隐藏状态进行编码处理
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        # 保存残差连接
        residual = hidden_states
        # 应用自注意力机制,得到新的隐藏状态和注意力权重
        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)

        # 应用 Dropout
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 残差连接和新隐藏状态相加
        hidden_states = residual + hidden_states
        # 应用自注意力层后的 Layer Normalization
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # 保存残差连接
        residual = hidden_states
        # 应用激活函数和第一个全连接层
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 应用激活函数后的 Dropout
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # 应用第二个全连接层
        hidden_states = self.fc2(hidden_states)
        # 应用 Dropout
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 残差连接和新隐藏状态相加
        hidden_states = residual + hidden_states
        # 应用最终的 Layer Normalization
        hidden_states = self.final_layer_norm(hidden_states)

        # 输出为一个元组,包含最终的隐藏状态
        outputs = (hidden_states,)

        # 如果需要输出注意力权重,加入到输出元组中
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# 从 transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection 复制代码并将 Bart->Marian 替换
class FlaxMarianEncoderLayerCollection(nn.Module):
    # Marian 模型配置
    config: MarianConfig
    # 计算的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    # 设置层的初始化操作
    def setup(self):
        # 创建编码层的集合,每个编码层使用 FlaxMarianEncoderLayer 创建
        self.layers = [
            FlaxMarianEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.encoder_layers)
        ]
        # 编码层的 dropout 率
        self.layerdrop = self.config.encoder_layerdrop
    # 定义一个特殊方法 __call__,使得对象可以被调用
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # 如果需要输出注意力权重,则初始化空的元组用于存储所有注意力权重
        all_attentions = () if output_attentions else None
        # 如果需要输出隐藏状态,则初始化空的元组用于存储所有隐藏状态
        all_hidden_states = () if output_hidden_states else None

        # 遍历所有的编码器层
        for encoder_layer in self.layers:
            # 如果需要输出隐藏状态,则将当前隐藏状态添加到 all_hidden_states 中
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # 添加 LayerDrop 功能,参见论文 https://arxiv.org/abs/1909.11556 的描述
            dropout_probability = random.uniform(0, 1)
            # 如果非确定性且随机数小于层级丢弃率,则跳过当前层
            if not deterministic and (dropout_probability < self.layerdrop):
                layer_outputs = (None, None)  # 跳过层的输出
            else:
                # 否则,调用当前编码器层的前向传播函数
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
            # 更新隐藏状态为当前层的输出的第一个元素
            hidden_states = layer_outputs[0]
            # 如果需要输出注意力权重,则将当前层的注意力权重添加到 all_attentions 中
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # 如果需要输出隐藏状态,则将最后一个隐藏状态添加到 all_hidden_states 中
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 将最终的输出整合为一个元组,根据 return_dict 决定返回类型
        outputs = (hidden_states, all_hidden_states, all_attentions)

        # 如果不需要以字典形式返回,则返回一个元组,去除其中为 None 的部分
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 否则,以 FlaxBaseModelOutput 对象形式返回所有输出
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
# 从 transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayer 复制代码,并将 Bart 更改为 Marian
class FlaxMarianDecoderLayer(nn.Module):
    # 使用 MarianConfig 类型的配置参数 config
    config: MarianConfig
    # 默认数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 设置方法,初始化层的各项参数
    def setup(self) -> None:
        # 获取嵌入维度,等于配置中的 d_model
        self.embed_dim = self.config.d_model
        # 定义自注意力层,使用 FlaxMarianAttention 类
        self.self_attn = FlaxMarianAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        # 定义 dropout 层,用于 self-attention 和全连接层之间
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # 激活函数,根据配置中的激活函数选择对应的函数
        self.activation_fn = ACT2FN[self.config.activation_function]
        # 激活函数的 dropout 层
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        # 自注意力层的 LayerNorm 层
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # 定义编码器注意力层,使用 FlaxMarianAttention 类
        self.encoder_attn = FlaxMarianAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # 编码器注意力层的 LayerNorm 层
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # 第一个全连接层,输入维度为 decoder_ffn_dim,输出维度与嵌入维度相同
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # 第二个全连接层,输入维度与嵌入维度相同,输出维度也与嵌入维度相同
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # 最终输出的 LayerNorm 层
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # 对象调用方法,定义层的前向传播逻辑
    def __call__(
        self,
        hidden_states: jnp.ndarray,  # 输入的隐藏状态
        attention_mask: jnp.ndarray,  # 注意力掩码
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # 编码器的隐藏状态(可选)
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # 编码器的注意力掩码(可选)
        init_cache: bool = False,  # 是否初始化缓存(默认为 False)
        output_attentions: bool = True,  # 是否输出注意力权重(默认为 True)
        deterministic: bool = True,  # 是否确定性计算(默认为 True)
    ) -> Tuple[jnp.ndarray]:
        residual = hidden_states

        # Self Attention
        # 使用自注意力机制处理隐藏状态,返回处理后的隐藏状态和注意力权重
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        # 应用 dropout 层,用于防止过拟合
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 添加残差连接
        hidden_states = residual + hidden_states
        # 对处理后的隐藏状态进行层归一化
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_weights = None
        # 如果有编码器隐藏状态,执行交叉注意力机制
        if encoder_hidden_states is not None:
            residual = hidden_states

            # 使用编码器注意力机制处理隐藏状态,返回处理后的隐藏状态和注意力权重
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            # 应用 dropout 层,用于防止过拟合
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            # 添加残差连接
            hidden_states = residual + hidden_states
            # 对处理后的隐藏状态进行层归一化
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 应用激活函数和全连接层 fc1
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 应用 dropout 层,用于防止过拟合
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # 应用全连接层 fc2
        hidden_states = self.fc2(hidden_states)
        # 应用 dropout 层,用于防止过拟合
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # 添加残差连接
        hidden_states = residual + hidden_states
        # 对处理后的隐藏状态进行层归一化
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        # 如果需要输出注意力权重,将自注意力和交叉注意力的权重添加到输出中
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        # 返回最终输出
        return outputs
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection with Bart->Marian
# 定义一个名为FlaxMarianDecoderLayerCollection的类,作为Marian模型的解码器层集合

class FlaxMarianDecoderLayerCollection(nn.Module):
    config: MarianConfig
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型

    def setup(self):
        # 初始化解码器层列表,每个解码器层使用FlaxMarianDecoderLayer构造,数量由配置文件self.config.decoder_layers决定
        self.layers = [
            FlaxMarianDecoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.decoder_layers)
        ]
        # 设置LayerDrop的概率,从配置文件self.config.decoder_layerdrop中获取

        self.layerdrop = self.config.decoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # decoder layers
        # 如果需要输出隐藏状态,则初始化all_hidden_states为一个空元组,否则为None
        all_hidden_states = () if output_hidden_states else None
        # 如果需要输出注意力分布,则初始化all_self_attns为一个空元组,否则为None
        all_self_attns = () if output_attentions else None
        # 如果需要输出交叉注意力分布,并且encoder_hidden_states不为None,则初始化all_cross_attentions为一个空元组,否则为None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # 遍历每个解码器层进行处理
        for decoder_layer in self.layers:
            if output_hidden_states:
                # 如果需要输出隐藏状态,将当前的hidden_states加入all_hidden_states中
                all_hidden_states += (hidden_states,)

            # LayerDrop (see https://arxiv.org/abs/1909.11556 for details)
            # 生成一个0到1之间的随机数,作为Dropout的概率
            dropout_probability = random.uniform(0, 1)
            # 如果不是确定性的计算,并且随机数小于self.layerdrop,则不执行当前解码器层的计算
            if not deterministic and (dropout_probability < self.layerdrop):
                layer_outputs = (None, None, None)
            else:
                # 否则,执行当前解码器层的计算,传入相应的参数
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            # 更新hidden_states为当前解码器层的输出中的第一个元素
            hidden_states = layer_outputs[0]
            if output_attentions:
                # 如果需要输出注意力分布,将当前解码器层的注意力分布加入all_self_attns中
                all_self_attns += (layer_outputs[1],)

                # 如果encoder_hidden_states不为None,则将当前解码器层的交叉注意力分布加入all_cross_attentions中
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # 将最后一个解码器层的隐藏状态加入all_hidden_states中
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 汇总所有的输出信息到outputs列表中
        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        # 如果不需要以字典形式返回结果,则返回outputs中不为None的元素构成的元组
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # 否则,以FlaxBaseModelOutputWithPastAndCrossAttentions对象的形式返回结果
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


# 定义一个名为FlaxMarianEncoder的类,作为Marian模型的编码器
class FlaxMarianEncoder(nn.Module):
    config: MarianConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # 计算的数据类型
    # 初始化模型的设置,包括dropout层和embedding相关的参数设置
    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        # 设置embedding的维度
        embed_dim = self.config.d_model
        # 设置最大的位置编码长度
        self.max_source_positions = self.config.max_position_embeddings
        # 如果设置了scale_embedding标志位,则对embedding进行缩放
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # 创建sinusoidal位置编码矩阵
        self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
        # 初始化encoder层集合
        self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)

    # 模型的调用方法,输入参数和返回类型可选
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # 获取输入的形状信息
        input_shape = input_ids.shape
        # 重新整形输入id
        input_ids = input_ids.reshape(-1, input_shape[-1])

        # 对输入id进行embedding并缩放
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # 根据位置id从预先创建的位置编码中取出对应的位置信息
        positions = jnp.take(self.embed_positions, position_ids, axis=0)
        # 明确地将位置信息的数据类型转换为和输入embedding相同的数据类型
        positions = positions.astype(inputs_embeds.dtype)

        # 将embedding和位置信息相加得到最终的隐藏状态表示
        hidden_states = inputs_embeds + positions
        # 应用dropout层到隐藏状态
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        # 调用模型的encoder层进行前向传播
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 如果不要求返回字典形式的输出,则直接返回模型的outputs对象
        if not return_dict:
            return outputs

        # 返回以FlaxBaseModelOutput对象封装的输出结果,包括最终的隐藏状态、所有隐藏状态以及注意力分布
        return FlaxBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class FlaxMarianDecoder(nn.Module):
    config: MarianConfig  # 类型注解,指定config属性为MarianConfig类型
    embed_tokens: nn.Embed  # 类型注解,指定embed_tokens属性为nn.Embed类型
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型,默认为jnp.float32

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)  # 初始化dropout层,使用config中的dropout率

        embed_dim = self.config.d_model  # 获取config中的d_model作为嵌入维度
        self.max_target_positions = self.config.max_position_embeddings  # 设置最大目标位置为config中的max_position_embeddings
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0  # 根据scale_embedding标志设置嵌入缩放因子

        self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)  # 创建正弦位置编码
        self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)  # 初始化解码器层集合

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        input_shape = input_ids.shape  # 获取输入张量的形状
        input_ids = input_ids.reshape(-1, input_shape[-1])  # 将输入张量重新形状为二维张量

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale  # 使用嵌入令牌和缩放因子对输入进行嵌入

        # 嵌入位置信息
        positions = jnp.take(self.embed_positions, position_ids, axis=0)
        # 明确地将位置转换为与inputs_embeds相同的数据类型,因为self.embed_positions未注册为参数
        positions = positions.astype(inputs_embeds.dtype)

        hidden_states = inputs_embeds + positions  # 将嵌入的输入和位置编码相加

        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)  # 应用dropout层

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # 将hidden_states传递给解码器层进行处理

        if not return_dict:
            return outputs

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )  # 如果return_dict为True,则返回带有注意力信息的输出


class FlaxMarianModule(nn.Module):
    config: MarianConfig  # 类型注解,指定config属性为MarianConfig类型
    dtype: jnp.dtype = jnp.float32  # 计算中使用的数据类型,默认为jnp.float32

    def setup(self):
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )  # 初始化共享的嵌入层,使用config中的词汇大小和d_model,并使用正态分布初始化器初始化

        self.encoder = FlaxMarianEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)  # 初始化编码器
        self.decoder = FlaxMarianDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)  # 初始化解码器

    def _get_encoder_module(self):
        return self.encoder  # 返回编码器模块
    # 返回解码器模块对象
    def _get_decoder_module(self):
        return self.decoder

    # 实现调用操作,执行序列到序列模型的前向传播
    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # 使用编码器模型处理输入序列
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # 使用解码器模型处理目标序列
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # 如果不需要返回字典格式的输出,则将编码器和解码器的输出拼接并返回
        if not return_dict:
            return decoder_outputs + encoder_outputs

        # 返回序列到序列模型的输出对象,其中包含解码器和编码器的相关隐藏状态和注意力权重
        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
    # 使用 MarianConfig 作为配置类
    config_class = MarianConfig
    # 基础模型前缀为 "model"
    base_model_prefix: str = "model"
    # 模块类暂未定义
    module_class: nn.Module = None

    def __init__(
        self,
        config: MarianConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 创建模块实例,传入配置和其他参数
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类构造函数初始化模型
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量 input_ids
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # 设置 input_ids 的最后一个位置为 eos_token_id
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
        # 初始化 attention_mask 为全1的张量
        attention_mask = jnp.ones_like(input_ids)
        # 将 decoder_input_ids 初始化为 input_ids
        decoder_input_ids = input_ids
        # 将 decoder_attention_mask 初始化为全1的张量
        decoder_attention_mask = jnp.ones_like(input_ids)

        # 获取 input_ids 的形状信息
        batch_size, sequence_length = input_ids.shape
        # 生成 position_ids,广播形状为 (batch_size, sequence_length)
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        # 生成 decoder_position_ids,广播形状为 (batch_size, sequence_length)
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # 分割随机数生成器 rng,返回 params_rng 和 dropout_rng
        params_rng, dropout_rng = jax.random.split(rng)
        # 构建随机数字典 rngs,包含 params_rng 和 dropout_rng
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 使用模块的初始化方法初始化模型参数,返回随机生成的参数 random_params
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        # 如果传入了预定义的参数 params
        if params is not None:
            # 展平 random_params 和 params
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            # 处理缺失的键
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            # 清空缺失键集合
            self._missing_keys = set()
            # 冻结并返回 params
            return freeze(unflatten_dict(params))
        else:
            # 返回随机生成的参数 random_params
            return random_params
    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        """
        # 初始化用于检索缓存的输入变量
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        # 创建与decoder_input_ids相同形状的全1张量,用作解码器的注意力遮罩
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        # 使用广播方式生成位置编码,形状与decoder_input_ids相同
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            # 获取解码器模块
            decoder_module = module._get_decoder_module()
            # 调用解码器模块进行前向传播
            return decoder_module(decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs)

        # 使用给定的输入参数初始化模型的变量
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # 只需调用解码器以初始化缓存
        )
        # 返回解冻后的初始化变量中的缓存部分
        return unfreeze(init_variables["cache"])
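
    # Usage sketch (hypothetical names, not from the original file): init_cache
    # primes the decoder cache for fast auto-regressive decoding, e.g.
    #
    #     encoder_outputs = model.encode(**inputs)
    #     past_key_values = model.init_cache(batch_size=1, max_length=64, encoder_outputs=encoder_outputs)
    #
    # The returned cache is then fed back to `decode` as `past_key_values`.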



    @add_start_docstrings(MARIAN_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=MarianConfig)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```
        >>> from transformers import AutoTokenizer, FlaxMarianMTModel

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=64, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        # Fall back to the config defaults when the flags are not given explicitly
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # If no attention mask is provided, create one with all elements set to 1
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # If no position ids are provided, derive them from the input_ids shape
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # Forward pass through the encoder only
        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        # Apply the encoder module on the input data and return the outputs
        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    @add_start_docstrings(MARIAN_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=MarianConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Runs the decoder against cached encoder outputs, optionally reusing
        # `past_key_values` for fast auto-regressive decoding; body omitted in this excerpt
        ...
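    # A hedged sketch of a single decode step (reusing `model`, `tokenizer` and
    # `inputs` from the encode example above; `jnp` is jax.numpy):
    #
    #   >>> encoder_outputs = model.encode(**inputs)
    #   >>> start = jnp.full((1, 1), model.config.decoder_start_token_id, dtype="i4")
    #   >>> decoder_outputs = model.decode(start, encoder_outputs)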
    # __call__ makes the model instance directly callable like a function
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Whether to return attention weights; fall back to the config default
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        # Whether to return hidden states; fall back to the config default
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Whether to return a dict-style output; fall back to the config default
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Prepare the encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)  # default to an all-ones attention mask
        if position_ids is None:
            # derive position ids from the input shape when not provided
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Prepare the decoder inputs
        if decoder_input_ids is None:
            # create decoder inputs by shifting the input tokens to the right
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)  # default to an all-ones decoder mask
        if decoder_position_ids is None:
            # derive decoder position ids from the decoder input shape
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        # Run the full model forward pass via the inner module's `apply`
        return self.module.apply(
            {"params": params or self.params},  # model parameters
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,  # disable dropout outside of training
            rngs=rngs,
        )

@add_start_docstrings(
    "The bare Marian Model transformer outputting raw hidden-states without any specific head on top.",
    MARIAN_START_DOCSTRING,
)
class FlaxMarianModel(FlaxMarianPreTrainedModel):
    config: MarianConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    module_class = FlaxMarianModule


append_call_sample_docstring(FlaxMarianModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)


class FlaxMarianMTModule(nn.Module):
    config: MarianConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxMarianModule(config=self.config, dtype=self.dtype)  # the underlying Marian encoder-decoder
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )  # LM head projecting hidden states onto the vocabulary
        # learned bias added to the final logits
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

    def _get_encoder_module(self):
        return self.model.encoder  # expose the encoder module

    def _get_decoder_module(self):
        return self.model.decoder  # expose the decoder module

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )  # forward pass through the underlying Marian model

        hidden_states = outputs[0]  # the decoder's last hidden states

        if self.config.tie_word_embeddings:
            # reuse the shared token embedding matrix as the LM head kernel
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)  # project hidden states to vocabulary logits

        lm_logits += self.final_logits_bias.astype(self.dtype)  # add the final logits bias

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output  # tuple of logits plus any additional outputs

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )  # structured seq2seq LM output
# The MARIAN model with a language modeling head on top, usable for translation
@add_start_docstrings(
    "The MARIAN Model with a language modeling head. Can be used for translation.", MARIAN_START_DOCSTRING
)
class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
    module_class = FlaxMarianMTModule
    dtype: jnp.dtype = jnp.float32

    # Attach the decode docstring and set the return-value docs
    # (FlaxCausalLMOutputWithCrossAttentions, configured by MarianConfig)
    @add_start_docstrings(MARIAN_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=MarianConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # Decodes with the LM head applied on top of the decoder outputs;
        # the body is omitted in this excerpt
        ...

    def _adapt_logits_for_beam_search(self, logits):
        """This function enforces the padding token never to be generated."""
        logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf"))
        return logits

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initialize the cache via init_cache
        batch_size, seq_length = decoder_input_ids.shape
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length, but since the decoder uses a
        # causal mask, those positions are already masked. We can thus create a single
        # static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # carry the grown cache forward and advance the position ids by one step
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
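    # Roughly, the generation loop drives these two hooks as sketched below
    # (simplified greedy sketch, not the actual generate() implementation):
    #
    #   >>> kwargs = self.prepare_inputs_for_generation(decoder_input_ids, max_length, attention_mask=mask)
    #   >>> for _ in range(max_length):
    #   ...     outputs = self.decode(next_token, **kwargs)
    #   ...     kwargs = self.update_inputs_for_generation(outputs, kwargs)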
# Docstring for `FlaxMarianMTModel`, including a generation example.
FLAX_MARIAN_MT_DOCSTRING = """
    Returns:

    Example:

    ```
    >>> from transformers import AutoTokenizer, FlaxMarianMTModel

    >>> model = FlaxMarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
    >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

    >>> text = "My friends are cool but they eat too many carbs."
    >>> input_ids = tokenizer(text, max_length=64, return_tensors="jax").input_ids

    >>> sequences = model.generate(input_ids, max_length=64, num_beams=2).sequences

    >>> outputs = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    >>> # should give *Meine Freunde sind cool, aber sie essen zu viele Kohlenhydrate.*
    ```
"""

# Overwrite the `__call__` docstring of `FlaxMarianMTModel` with the combination of
# the shared `MARIAN_INPUTS_DOCSTRING` and the example above.
overwrite_call_docstring(
    FlaxMarianMTModel,
    MARIAN_INPUTS_DOCSTRING + FLAX_MARIAN_MT_DOCSTRING,
)

# Set the return-value docs of `FlaxMarianMTModel` to `FlaxSeq2SeqLMOutput`,
# configured by `_CONFIG_FOR_DOC`.
append_replace_return_docstrings(FlaxMarianMTModel, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)

.\models\marian\modeling_marian.py

# coding=utf-8
# Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""


import copy
import math
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_marian import MarianConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MarianConfig"
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"


MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "Helsinki-NLP/opus-mt-en-de",
    # See all Marian models at https://huggingface.co/models?filter=marian
]


# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
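# A quick sanity check of the shift (invented toy values; 58100 merely stands in
# for a real Marian pad id):
#
#   >>> labels = torch.tensor([[42, 7, -100]])
#   >>> shift_tokens_right(labels, pad_token_id=58100, decoder_start_token_id=58100)
#   tensor([[58100,    42,     7]])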


class MarianSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Identical to XLM's `create_sinusoidal_embeddings`, except that the features are not interleaved:
        the cos features sit in the second half of the vector, `[dim // 2:]`.
        """
        n_pos, dim = out.shape
        # position encoding matrix, built with numpy
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        # set requires_grad early to avoid an error in PyTorch 1.8+
        out.requires_grad = False
        # index that splits the vector into the sin and cos halves (odd dims get the extra slot)
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        # sin values fill the first half of the weight matrix
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        # cos values fill the second half
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        # detach the weights from the graph so no gradients are computed
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        bsz, seq_len = input_ids_shape[:2]
        # positions start after any cached past keys/values
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        # look up the (fixed) positional embeddings via the parent class
        return super().forward(positions)
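    # Layout check with a toy assumption (dim=4): sin/cos are split rather than
    # interleaved, so weight[1] = [sin(1), sin(0.01), cos(1), cos(0.01)]
    #                           ≈ [0.8415, 0.0100, 0.5403, 1.0000]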
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian
class MarianAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MarianConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim  # model/embedding dimension
        self.num_heads = num_heads  # number of attention heads
        self.dropout = dropout  # attention dropout probability
        self.head_dim = embed_dim // num_heads  # per-head dimension
        self.config = config  # the Marian configuration

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # scaling factor applied to the queries
        self.is_decoder = is_decoder  # whether this attention sits in the decoder
        self.is_causal = is_causal  # whether a causal mask is applied

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # key projection
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # value projection
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # query projection
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)  # output projection

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        # Forward pass implementing the scaled dot-product attention; body omitted here
        ...
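    # Shape bookkeeping for `_shape` with illustrative numbers (bsz=2, seq_len=5,
    # embed_dim=512, num_heads=8, head_dim=64):
    #
    #   (2, 5, 512) --view--> (2, 5, 8, 64) --transpose(1, 2)--> (2, 8, 5, 64)
    #
    # Queries are pre-multiplied by self.scaling = 64 ** -0.5 = 0.125 before the
    # dot product with the keys.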



# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN
class MarianEncoderLayer(nn.Module):
    def __init__(self, config: MarianConfig):
        super().__init__()
        self.embed_dim = config.d_model  # embedding dimension of the encoder layer

        self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](  # self-attention sub-layer
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)  # LayerNorm after self-attention
        self.dropout = config.dropout  # dropout probability
        self.activation_fn = ACT2FN[config.activation_function]  # activation function
        self.activation_dropout = config.activation_dropout  # dropout after the activation
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)  # first feed-forward layer
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)  # second feed-forward layer
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)  # final LayerNorm

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states  # keep the input for the residual connection
        hidden_states, attn_weights, _ = self.self_attn(  # self-attention block
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)  # dropout on the attention output
        hidden_states = residual + hidden_states  # residual connection
        hidden_states = self.self_attn_layer_norm(hidden_states)  # layer norm after the residual

        residual = hidden_states  # keep for the feed-forward residual
        hidden_states = self.activation_fn(self.fc1(hidden_states))  # first linear layer + activation
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)  # activation dropout
        hidden_states = self.fc2(hidden_states)  # second linear layer
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)  # dropout
        hidden_states = residual + hidden_states  # residual connection
        hidden_states = self.final_layer_norm(hidden_states)  # final layer norm

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)  # guard against fp16 overflow/NaN

        outputs = (hidden_states,)  # the hidden states are always returned

        if output_attentions:
            outputs += (attn_weights,)  # optionally also return the attention weights

        return outputs
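    # Note on the clamp above: torch.finfo(torch.float16).max is 65504, so fp16
    # activations are clipped to +/-64504 before they can overflow to inf.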
# Map the "eager" attention implementation string to the MarianAttention class
MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}


# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
class MarianDecoderLayer(nn.Module):
    def __init__(self, config: MarianConfig):
        super().__init__()
        # embedding dimension from the config
        self.embed_dim = config.d_model

        # self-attention, picking the implementation class from the config
        self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )

        # dropout probability
        self.dropout = config.dropout
        # activation function from the config
        self.activation_fn = ACT2FN[config.activation_function]
        # dropout applied after the activation
        self.activation_dropout = config.activation_dropout

        # LayerNorm after self-attention
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # cross-attention over the encoder outputs, same implementation choice
        self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        # LayerNorm after cross-attention
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # feed-forward block: embed_dim -> decoder_ffn_dim -> embed_dim
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)

        # final LayerNorm
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ):
        # The concrete forward logic is omitted in this excerpt; structurally this layer
        # chains self-attention, cross-attention and the feed-forward block
        pass


# Base class for all Marian models, handling weight initialization and pretrained loading
class MarianPreTrainedModel(PreTrainedModel):
    config_class = MarianConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
        # Pick the initialization scheme based on the module type
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            # normal initialization for linear layers
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                # zero-initialize the bias, if present
                module.bias.data.zero_()
        elif isinstance(module, MarianSinusoidalPositionalEmbedding):
            # sinusoidal positional embeddings are fixed, so leave them untouched
            pass
        elif isinstance(module, nn.Embedding):
            # normal initialization for embedding layers
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # zero out the padding vector, if a padding index is set
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        # Build a small dummy batch on the model's device
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),  # mask out the padding token
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,  # reuse the inputs as decoder inputs
        }
        return dummy_inputs
MARIAN_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MarianConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MARIAN_GENERATION_EXAMPLE = r"""
    Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. Available
    models are listed [here](https://huggingface.co/models?search=Helsinki-NLP).

    Examples:

    ```
    >>> from transformers import AutoTokenizer, MarianMTModel

    >>> src = "fr"  # source language
    >>> trg = "en"  # target language

    >>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
    >>> model = MarianMTModel.from_pretrained(model_name)
    >>> tokenizer = AutoTokenizer.from_pretrained(model_name)

    >>> sample_text = "où est l'arrêt de bus ?"
    >>> batch = tokenizer([sample_text], return_tensors="pt")

    >>> generated_ids = model.generate(**batch)
    >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    "Where's the bus stop?"
    ```
"""

MARIAN_INPUTS_DOCSTRING = r"""
"""


class MarianEncoder(MarianPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`MarianEncoderLayer`].

    Args:
        config: MarianConfig
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        embed_tokens (nn.Embedding): output embedding
            A PyTorch embedding layer representing the output embeddings of the model.
    """
    def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout  # dropout probability
        self.layerdrop = config.encoder_layerdrop  # LayerDrop probability for encoder layers

        embed_dim = config.d_model  # embedding dimension
        self.padding_idx = config.pad_token_id  # index of the padding token
        self.max_source_positions = config.max_position_embeddings  # maximum source sequence length
        # scale embeddings by sqrt(embed_dim) if configured, otherwise leave them unscaled
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # reuse the provided token embeddings or create a fresh embedding table
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        # fixed sinusoidal positional embeddings
        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings, embed_dim, self.padding_idx
        )

        # stack of config.encoder_layers encoder layers
        self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,  # input token ids
        attention_mask: Optional[torch.LongTensor] = None,  # attention mask
        head_mask: Optional[torch.Tensor] = None,  # mask for attention heads
        inputs_embeds: Optional[torch.FloatTensor] = None,  # optional precomputed input embeddings
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a dict-style output
    ):
        ...


class MarianDecoder(MarianPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MarianDecoderLayer`]

    Args:
        config: MarianConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout  # dropout probability
        self.layerdrop = config.decoder_layerdrop  # LayerDrop probability for decoder layers
        self.padding_idx = config.pad_token_id  # index of the padding token
        self.max_target_positions = config.max_position_embeddings  # maximum target sequence length
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0  # embedding scale factor

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens  # reuse the provided token embeddings
        else:
            # otherwise create a fresh decoder embedding table
            self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)

        # fixed sinusoidal positional embeddings
        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings, config.d_model, self.padding_idx
        )

        # stack of config.decoder_layers decoder layers
        self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Forward pass for the MarianDecoder module.

        Args:
            input_ids (torch.LongTensor): Input token IDs
            attention_mask (torch.Tensor): Attention mask for masking out padded tokens
            encoder_hidden_states (torch.FloatTensor): Hidden states from the encoder
            encoder_attention_mask (torch.LongTensor): Attention mask for encoder's hidden states
            head_mask (torch.Tensor): Mask for heads in the self-attention layers
            cross_attn_head_mask (torch.Tensor): Mask for heads in the cross-attention layers
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): Cached key-value pairs for fast decoding
            inputs_embeds (torch.FloatTensor): Optional tensor of embedded inputs
            use_cache (bool): Whether to use cached key-value pairs
            output_attentions (bool): Whether to output attentions
            output_hidden_states (bool): Whether to output hidden states
            return_dict (bool): Whether to return a dictionary as output

        Returns:
            Various outputs depending on the configuration (return_dict or not)
        """
        # Forward pass logic will be implemented here in subsequent code


class MarianModel(MarianPreTrainedModel):
    def __init__(self, config: MarianConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size

        # We always use self.shared for token embeddings to ensure compatibility with all marian models
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
        if self.config.share_encoder_decoder_embeddings:
            # If embeddings are shared between encoder and decoder, use the same instance
            encoder_embed_tokens = decoder_embed_tokens = self.shared
        else:
            # If embeddings are not shared, create separate instances for encoder and decoder
            # to ensure they are not tied.
            encoder_embed_tokens = copy.deepcopy(self.shared)
            decoder_embed_tokens = copy.deepcopy(self.shared)
            self.shared = None  # Reset self.shared to None for separate embeddings

        # Initialize encoder and decoder with respective embeddings
        self.encoder = MarianEncoder(config, encoder_embed_tokens)
        self.decoder = MarianDecoder(config, decoder_embed_tokens)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the shared embeddings if they are shared, otherwise returns encoder embeddings
        return self.get_encoder().get_input_embeddings()

    def set_input_embeddings(self, value):
        if self.config.share_encoder_decoder_embeddings:
            # If embeddings are shared, set the shared instance for both encoder and decoder
            self.shared = value
            self.encoder.embed_tokens = self.shared
            self.decoder.embed_tokens = self.shared
        else:
            # If embeddings are not shared, only set encoder embeddings
            self.encoder.embed_tokens = value

    def get_decoder_input_embeddings(self):
        if self.config.share_encoder_decoder_embeddings:
            # Error if decoder embeddings are accessed when they are shared with encoder
            raise ValueError(
                "`get_decoder_input_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
                "is `True`. Please use `get_input_embeddings` instead."
            )
        # Return decoder embeddings (should not be reached if embeddings are shared)
        return self.get_decoder().get_input_embeddings()

    def set_decoder_input_embeddings(self, value):
        if self.config.share_encoder_decoder_embeddings:
            # Error if trying to set decoder embeddings when they are shared with encoder
            raise ValueError(
                "`config.share_encoder_decoder_embeddings` is set to `True` meaning the decoder input embeddings "
                "are shared with the encoder. In order to set the decoder input embeddings, you should simply set "
                "the encoder input embeddings by calling `set_input_embeddings` with the appropriate embeddings."
            )
        # Set decoder embeddings (should not be reached if embeddings are shared)
        self.decoder.embed_tokens = value

    def get_encoder(self):
        # Returns the encoder instance
        return self.encoder

    def get_decoder(self):
        # Returns the decoder instance
        return self.decoder
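    # A hedged illustration of the two embedding layouts (constructed from a default
    # config, purely to show the identity check):
    #
    #   >>> model = MarianModel(MarianConfig(share_encoder_decoder_embeddings=True))
    #   >>> model.encoder.embed_tokens is model.decoder.embed_tokens
    #   True
    #
    # With `share_encoder_decoder_embeddings=False` the two tables are deep copies,
    # and `config.decoder_vocab_size` may differ from `config.vocab_size`.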
    # Attach the shared MARIAN_INPUTS_DOCSTRING to `forward` and set its return-value
    # docs to Seq2SeqModelOutput, configured by _CONFIG_FOR_DOC
    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,  # input token ids
        attention_mask: Optional[torch.Tensor] = None,  # marks which positions may be attended to
        decoder_input_ids: Optional[torch.LongTensor] = None,  # decoder input token ids
        decoder_attention_mask: Optional[torch.Tensor] = None,  # decoder attention mask
        head_mask: Optional[torch.Tensor] = None,  # mask for encoder attention heads
        decoder_head_mask: Optional[torch.Tensor] = None,  # mask for decoder attention heads
        cross_attn_head_mask: Optional[torch.Tensor] = None,  # mask for cross-attention heads
        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,  # precomputed encoder outputs
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,  # cached decoder key/value states
        inputs_embeds: Optional[torch.FloatTensor] = None,  # optional precomputed input embeddings
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,  # optional decoder input embeddings
        use_cache: Optional[bool] = None,  # whether to cache key/value states for faster decoding
        output_attentions: Optional[bool] = None,  # whether to return attention weights
        output_hidden_states: Optional[bool] = None,  # whether to return hidden states
        return_dict: Optional[bool] = None,  # whether to return a dict-style output
    ):
        ...


@add_start_docstrings(
    "The Marian Model with a language modeling head. Can be used for summarization.", MARIAN_START_DOCSTRING
)
class MarianMTModel(MarianPreTrainedModel):
    base_model_prefix = "model"
    # keys that may be missing when loading a checkpoint
    _keys_to_ignore_on_load_missing = [
        "final_logits_bias",
        "encoder.embed_positions.weight",
        "decoder.embed_positions.weight",
    ]
    # keys that are skipped when saving
    _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
    # weights that are tied together
    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: MarianConfig):
        super().__init__(config)
        self.model = MarianModel(config)

        # size of the target vocabulary, which drives the shapes of final_logits_bias and lm_head
        target_vocab_size = config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
        # zero-initialized bias added to the LM logits
        self.register_buffer("final_logits_bias", torch.zeros((1, target_vocab_size)))
        # LM head projecting d_model onto the target vocabulary, without bias
        self.lm_head = nn.Linear(config.d_model, target_vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
        # Delegate the actual resizing to the parent class
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # When encoder and decoder share embeddings, the logits bias must follow the new size
        if self.config.share_encoder_decoder_embeddings:
            self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None) -> nn.Embedding:
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
        self.set_input_embeddings(new_embeddings)

        # update config.decoder_vocab_size if embeddings are tied
        new_num_tokens = new_embeddings.weight.shape[0]
        if self.config.share_encoder_decoder_embeddings:
            self.config.decoder_vocab_size = new_num_tokens

        # if word embeddings are not tied, make sure that the lm head is resized as well
        if (
            self.config.share_encoder_decoder_embeddings
            and self.get_output_embeddings() is not None
            and not self.config.tie_word_embeddings
        ):
            old_lm_head = self.get_output_embeddings()
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            self.set_output_embeddings(new_lm_head)

        return self.get_input_embeddings()
    def resize_decoder_token_embeddings(self, new_num_tokens):
        # Not applicable when encoder and decoder embeddings are shared
        if self.config.share_encoder_decoder_embeddings:
            raise ValueError(
                "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
                "is `True`. Please use `resize_token_embeddings` instead."
            )

        # Resize the decoder's input embeddings to the new vocabulary size
        old_embeddings = self.model.get_decoder_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.model.set_decoder_input_embeddings(new_embeddings)

        # if word embeddings are not tied, make sure that the lm head is resized as well
        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            old_lm_head = self.get_output_embeddings()
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            self.set_output_embeddings(new_lm_head)

        model_embeds = self.model.get_decoder_input_embeddings()

        if new_num_tokens is None:
            return model_embeds

        # Update the decoder vocabulary size in the config
        self.config.decoder_vocab_size = new_num_tokens

        # Tie the weights again if needed
        self.tie_weights()

        # Resize the final logits bias accordingly
        self._resize_final_logits_bias(new_num_tokens)

        return model_embeds
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            # shrink: keep only the first new_num_tokens entries
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # grow: append zero-initialized entries for the new tokens
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        # register the resized bias as the model buffer
        self.register_buffer("final_logits_bias", new_bias)
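    # Numerical sketch (toy sizes, growing from 4 to 6 tokens):
    #
    #   old bias, shape (1, 4): [[b0, b1, b2, b3]]
    #   new bias, shape (1, 6): [[b0, b1, b2, b3, 0., 0.]]
    #
    # Shrinking instead keeps only the first `new_num_tokens` columns.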

    # Accessors for the output embeddings (the LM head)
    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Embedding):
        self.lm_head = new_embeddings
    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
        weights instead.
        """
        # Fetch the output embeddings (LM head)
        output_embeddings = self.get_output_embeddings()
        # Tie only if an output embedding exists and the config allows weight tying
        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
            # Input embeddings of the decoder (the shared table if embeddings are shared,
            # otherwise the decoder's own embed_tokens)
            word_embeddings = self.get_decoder().get_input_embeddings()
            # Share or clone the weights
            self._tie_or_clone_weights(output_embeddings, word_embeddings)

        # Optionally tie encoder and decoder weights for encoder-decoder models
        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
            # Descend into the base model if this instance wraps one
            if hasattr(self, self.base_model_prefix):
                self = getattr(self, self.base_model_prefix)
            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)

        # Let any module with its own `_tie_weights` hook run it
        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()

    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Seq2SeqLMOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            `Seq2SeqLMOutput`: A class representing the outputs of the Seq2Seq language model.

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # If labels are provided, adjust `use_cache` and prepare `decoder_input_ids` if necessary
        if labels is not None:
            # Issue a warning and set `use_cache` to False if `labels` are provided
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            
            # If `decoder_input_ids` and `decoder_inputs_embeds` are not provided, prepare `decoder_input_ids`
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Shift the labels to the right to align with decoder inputs
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Pass the inputs to the underlying model for computation
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        # Compute the logits from the language model head and add bias
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            # Compute the masked language modeling loss if `labels` are provided
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))

        # Prepare the output based on `return_dict` setting
        if not return_dict:
            # Return a tuple with logits and additional outputs if `return_dict` is False
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # Return a `Seq2SeqLMOutput` object with specified attributes if `return_dict` is True
        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
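    # A minimal training-style call (hedged sketch; `model` and `tokenizer` follow the
    # generation example earlier in this file, and `text_target` assumes a recent
    # tokenizer API):
    #
    #   >>> batch = tokenizer(["où est l'arrêt de bus ?"], text_target=["Where's the bus stop?"], return_tensors="pt")
    #   >>> out = model(**batch)  # labels trigger shift_tokens_right + CrossEntropyLoss
    #   >>> out.loss.backward()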
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
        **kwargs,
    ) -> Dict:
        """
        Prepare inputs for text generation.

        Args:
            decoder_input_ids: Input IDs for decoder.
            past_key_values: Tuple of past key and value tensors.
            attention_mask: Mask to avoid attention on padding tokens.
            head_mask: Mask to nullify selected heads of the attention modules.
            decoder_head_mask: Mask to nullify selected heads of the decoder self-attention modules.
            cross_attn_head_mask: Mask to nullify selected heads of the cross-attention modules.
            use_cache: Flag to control whether to use caching.
            encoder_outputs: Output tensors from the encoder.

        Returns:
            Dictionary containing prepared inputs.
        """

        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if decoder_input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = decoder_input_ids.shape[1] - 1

            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """
        Shift labels to the right to prep inputs for decoder.

        Args:
            labels: Tensor of labels.

        Returns:
            Tensor of shifted labels.
        """
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """
        Reorder past key and value tensors based on beam index.

        Args:
            past_key_values: Tuple of past key and value tensors.
            beam_idx: Tensor containing indices to reorder with.

        Returns:
            Reordered tuple of past key and value tensors.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
                + layer_past[2:],
            )
        return reordered_past
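    # Worked example: with 3 beams and beam_idx = tensor([2, 0, 0]), each cached
    # self-attention tensor of shape (num_beams, num_heads, seq_len, head_dim) is
    # gathered along dim 0, so new beam 0 continues old beam 2 and new beams 1-2
    # both continue old beam 0; cross-attention states (layer_past[2:]) are the
    # same for every beam and are kept as-is.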
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian
class MarianDecoderWrapper(MarianPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        self.decoder = MarianDecoder(config)

    def forward(self, *args, **kwargs):
        # Delegate the forward pass to the wrapped decoder
        return self.decoder(*args, **kwargs)


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
class MarianForCausalLM(MarianPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Work on a copy of the config, marked as a standalone decoder
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = MarianDecoderWrapper(config)

        # LM head projecting decoder hidden states onto the vocabulary, without bias
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Forward pass: accepts the inputs above and returns the model outputs
        ...

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # When the model is used as a decoder inside an encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            # Default to attending to every position: a mask of ones with the input's shape
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            # Length of the cached sequence, taken from the first layer's key states
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                # The input is longer than the cache, so drop the cached prefix
                remove_prefix_length = past_length
            else:
                # Default to the old behavior: keep only the final ID
                remove_prefix_length = input_ids.shape[1] - 1

            # Keep the input from remove_prefix_length onwards
            input_ids = input_ids[:, remove_prefix_length:]
        # On the first generation step, past_key_values is empty
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        # Reorder the cache: index-select each layer's past states along the batch axis using beam_idx
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

.\models\marian\modeling_tf_marian.py

# coding=utf-8
# Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 Marian model."""

from __future__ import annotations

import random
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
    TFBaseModelOutputWithPastAndCrossAttentions,
    TFSeq2SeqLMOutput,
    TFSeq2SeqModelOutput,
)

# Public API
from ...modeling_tf_utils import (
    TFCausalLanguageModelingLoss,
    TFPreTrainedModel,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
from ...utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_marian import MarianConfig

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
_CONFIG_FOR_DOC = "MarianConfig"

LARGE_NEGATIVE = -1e8

# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
    # Build a (batch_size, 1) tensor filled with decoder_start_token_id
    start_tokens = tf.fill(
        (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
    )
    # Shift input_ids one position to the right: prepend start_tokens to the first n-1 columns of input_ids
    shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
    # Replace possible -100 values in the labels with pad_token_id
    shifted_input_ids = tf.where(
        shifted_input_ids == -100,
        tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
        shifted_input_ids,
    )

    # Verify that the shifted IDs contain only non-negative values (any -100 was replaced above)
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is actually executed by wrapping the result in an identity op
    with tf.control_dependencies([assert_gte0]):
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
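
A toy call makes the shift concrete (a sketch, not library code; it assumes the `shift_tokens_right` defined above is in scope): the start token is prepended, the last label is dropped, and any shifted -100 becomes the pad token.

import tensorflow as tf

labels = tf.constant([[5, 6, -100, -100]])
shifted = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
print(shifted.numpy())  # [[2 5 6 0]] -- start token prepended, shifted -100 replaced by pad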
# Create the causal mask used during auto-regressive decoding
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    # Batch size
    bsz = input_ids_shape[0]
    # Target sequence length
    tgt_len = input_ids_shape[1]
    # Start from a (tgt_len, tgt_len) matrix filled with a large negative value
    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    # Range over the sequence length
    mask_cond = tf.range(shape_list(mask)[-1])

    # Zero out the lower triangle (including the diagonal) so each position sees itself and the past
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)

    # If there are cached key/values, prepend zero columns so cached positions stay visible
    if past_key_values_length > 0:
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    # Expand the mask to four dimensions, one copy per batch element
    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
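
A quick sanity check (a sketch, not library code; it assumes `_make_causal_mask` and `shape_list` from this file are in scope): with three target positions and one cached step, the mask gains one all-zero column on the left.

import tensorflow as tf

mask = _make_causal_mask(tf.TensorShape([1, 3]), past_key_values_length=1)
print(mask.shape)  # (1, 1, 3, 4): the extra all-zero column keeps the cached step visible
# Row i of mask[0, 0] is 0.0 up to and including column i + 1, LARGE_NEGATIVE afterwards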


# Copied from transformers.models.bart.modeling_tf_bart._expand_mask
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    # Source sequence length
    src_len = shape_list(mask)[1]
    # Fall back to the source length when no target length is given
    tgt_len = tgt_len if tgt_len is not None else src_len
    # Constant 1.0, used for the dtype and for inverting the mask
    one_cst = tf.constant(1.0)
    # Cast the mask to float
    mask = tf.cast(mask, dtype=one_cst.dtype)
    # Tile the mask tgt_len times along a new axis, expanding it to four dimensions
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    # Invert the mask and scale by a large negative value so masked positions are suppressed
    return (one_cst - expanded_mask) * LARGE_NEGATIVE
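
For a padding mask with one padded position, the expansion looks like this (a sketch, not library code; it assumes `_expand_mask` above is in scope): 1 becomes 0.0 (attend), 0 becomes LARGE_NEGATIVE (ignore).

import tensorflow as tf

attention_mask = tf.constant([[1, 1, 0]])  # last position is padding
expanded = _expand_mask(attention_mask, tgt_len=2)
print(expanded.shape)          # (1, 1, 2, 3)
print(expanded[0, 0].numpy())  # [[0. 0. -1e8], [0. 0. -1e8]]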


class TFMarianSinusoidalPositionalEmbedding(keras.layers.Layer):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)

        # Raise an error for odd embedding dimensions
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")

        self.embedding_dim = embedding_dim
        self.num_positions = num_positions

    def build(self, input_shape: tf.TensorShape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """

        # Initialize the positional encoding weights
        weight = self._init_weight(self.num_positions, self.embedding_dim)

        # Register the weight tensor on the layer
        self.weight = self.add_weight(
            name="embeddings",
            shape=[self.num_positions, self.embedding_dim],
        )
        # Cast the initialized values to self.weight's dtype and assign them
        weight = tf.cast(weight, dtype=self.weight.dtype)
        self.weight.assign(weight)

        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]
        """
        # Build the positional encoding matrix from sine and cosine functions
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        table = np.zeros_like(position_enc)
        # Sine features fill the first half of each vector...
        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        # ...and cosine features fill the second half
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        # Convert the table to a tensor
        table = tf.convert_to_tensor(table)
        # Block gradient flow through the table
        tf.stop_gradient(table)
        return table
    def call(
        self, input_shape: tf.TensorShape, past_key_values_length: int = 0, position_ids: tf.Tensor | None = None
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        # If no position IDs are provided, derive them from the input shape and the cache length
        if position_ids is None:
            # Sequence length of the current input
            seq_len = input_shape[1]
            # Positions run from past_key_values_length to seq_len + past_key_values_length (exclusive)
            position_ids = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        # Gather the embeddings corresponding to the position IDs
        return tf.gather(self.weight, position_ids)
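
The sin/cos layout of `_init_weight` can be verified with a few lines of NumPy (a sketch, not library code, replicating the formula above on toy sizes): at position 0, the sine half is all zeros and the cosine half is all ones.

import numpy as np

n_pos, dim = 4, 6
pe = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
table = np.zeros_like(pe)
table[:, : dim // 2] = np.sin(pe[:, 0::2])
table[:, dim // 2 :] = np.cos(pe[:, 1::2])
assert np.allclose(table[0, : dim // 2], 0.0)  # sin(0) = 0 for position 0
assert np.allclose(table[0, dim // 2 :], 1.0)  # cos(0) = 1 for position 0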
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian
class TFMarianAttention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"""
    
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim  # embedding dimension of the attention block

        self.num_heads = num_heads  # number of attention heads
        self.dropout = keras.layers.Dropout(dropout)  # dropout layer
        self.head_dim = embed_dim // num_heads  # dimension of each head
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5  # scaling factor applied to attention scores
        self.is_decoder = is_decoder  # whether this attention block sits in the decoder

        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")  # key projection
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")  # query projection
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")  # value projection
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")  # output projection for the merged heads

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        # Reshape and transpose to the multi-head layout (batch, heads, seq, head_dim)
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # Forward pass logic of the attention layer
        ...

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
        # build creates each projection sublayer on demand, inside its own name scope
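
A minimal sketch (not library code) of the head-splitting layout produced by `_shape`: (batch, seq, embed) becomes (batch, heads, seq, head_dim), so each head gets its own axis.

import tensorflow as tf

bsz, seq_len, num_heads, head_dim = 2, 5, 4, 8
tensor = tf.random.normal((bsz, seq_len, num_heads * head_dim))
shaped = tf.transpose(tf.reshape(tensor, (bsz, seq_len, num_heads, head_dim)), (0, 2, 1, 3))
print(shaped.shape)  # (2, 4, 5, 8): heads become an explicit axis ahead of the sequence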



# Copied from transformers.models.bart.modeling_tf_bart.TFBartEncoderLayer with Bart->Marian
class TFMarianEncoderLayer(keras.layers.Layer):
    # Initializer: takes a MarianConfig object and extra keyword arguments
    def __init__(self, config: MarianConfig, **kwargs):
        # Call the parent initializer
        super().__init__(**kwargs)
        # Embedding dimension equals the model dimension d_model
        self.embed_dim = config.d_model
        # Self-attention layer configured from the MarianConfig
        self.self_attn = TFMarianAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
        )
        # LayerNormalization applied after self-attention
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        # Dropout layer with the configured dropout rate
        self.dropout = keras.layers.Dropout(config.dropout)
        # Activation function selected from the config
        self.activation_fn = get_tf_activation(config.activation_function)
        # Dropout applied after the activation, using activation_dropout
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
        # First feed-forward layer, output dimension encoder_ffn_dim
        self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
        # Second feed-forward layer, projecting back to embed_dim
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        # Final LayerNormalization
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        # Keep a reference to the config
        self.config = config

    # call implements the forward computation of the encoder layer
    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: np.ndarray | tf.Tensor | None,
        layer_head_mask: tf.Tensor | None,
        training: Optional[bool] = False,
    ) -> tf.Tensor:
        """
        Args:
            hidden_states (`tf.Tensor`): 输入层的张量,形状为 `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): 注意力掩码,形状为 `(batch, 1, tgt_len, src_len)`,
                                          其中填充元素由非常大的负值指示。
            layer_head_mask (`tf.Tensor`): 给定层的注意力头掩码,形状为 `(encoder_attention_heads,)`
            training (`Optional[bool]`, optional): 是否处于训练模式,默认为False。
        Returns:
            tf.Tensor: 返回处理后的张量,形状为 `(batch, seq_len, embed_dim)`
        """
        # 保存输入的隐藏状态作为残差连接的一部分
        residual = hidden_states
        # 使用自注意力层处理隐藏状态,得到处理后的隐藏状态、注意力权重和附加信息
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

        # 断言:确保自注意力层没有修改查询的形状
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        # 对处理后的隐藏状态应用Dropout层,根据training参数决定是否使用训练模式
        hidden_states = self.dropout(hidden_states, training=training)
        # 将残差连接到处理后的隐藏状态上
        hidden_states = residual + hidden_states
        # 使用自注意力层的LayerNormalization层对结果进行归一化处理
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # 保存处理后的隐藏状态作为新的残差连接的一部分
        residual = hidden_states
        # 使用激活函数处理第一个全连接层的结果
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # 对处理后的结果应用激活函数的Dropout层,根据training参数决定是否使用训练模式
        hidden_states = self.activation_dropout(hidden_states, training=training)
        # 使用第二个全连接层处理结果
        hidden_states = self.fc2(hidden_states)
        # 对处理后的结果应用Dropout层,根据training参数决定是否使用训练模式
        hidden_states = self.dropout(hidden_states, training=training)
        # 将残差连接到处理后的结果上
        hidden_states = residual + hidden_states
        # 使用最终的LayerNormalization层对结果进行归一化处理
        hidden_states = self.final_layer_norm(hidden_states)

        # 返回处理后的隐藏状态和自注意力权重
        return hidden_states, self_attn_weights
    # build method: constructs the layer's sublayers
    def build(self, input_shape=None):
        # Return early if the layer was already built
        if self.built:
            return
        # Mark the layer as built
        self.built = True

        # Build the self-attention layer inside its own name scope
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)

        # Build the self-attention LayerNormalization, input shape [None, None, self.embed_dim]
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])

        # Build the first feed-forward layer, input shape [None, None, self.embed_dim]
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])

        # Build the second feed-forward layer, input shape [None, None, self.config.encoder_ffn_dim]
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.encoder_ffn_dim])

        # Build the final LayerNormalization, input shape [None, None, self.embed_dim]
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
# Copied from transformers.models.bart.modeling_tf_bart.TFBartDecoderLayer with Bart->Marian
class TFMarianDecoderLayer(keras.layers.Layer):
    def __init__(self, config: MarianConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = config.d_model  # embedding dimension equals the model dimension
        self.self_attn = TFMarianAttention(  # self-attention layer
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="self_attn",
            is_decoder=True,  # marks this as a decoder attention layer
        )
        self.dropout = keras.layers.Dropout(config.dropout)  # dropout layer
        self.activation_fn = get_tf_activation(config.activation_function)  # activation function
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)  # dropout after the activation

        # LayerNormalization applied after self-attention
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")

        self.encoder_attn = TFMarianAttention(  # cross-attention over the encoder outputs
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="encoder_attn",
            is_decoder=True,
        )
        # LayerNormalization applied after cross-attention
        self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")

        self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")  # first feed-forward layer
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")  # second feed-forward layer

        # LayerNormalization applied to the final output
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")

        self.config = config  # keep a reference to the config

    def call(
        self,
        hidden_states: tf.Tensor,  # input hidden states
        attention_mask: np.ndarray | tf.Tensor | None = None,  # decoder attention mask
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,  # encoder hidden states
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,  # encoder attention mask
        layer_head_mask: tf.Tensor | None = None,  # per-layer head mask
        cross_attn_layer_head_mask: tf.Tensor | None = None,  # per-layer cross-attention head mask
        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,  # cached key/value states
        training: Optional[bool] = False,  # training mode flag
    ):
        # Forward pass logic of the decoder layer
        ...
    # build method: constructs the layer's sublayers
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        # Mark the layer as built
        self.built = True

        # Build the self-attention layer inside its own name scope
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)

        # Build the self-attention LayerNormalization
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])

        # Build the cross-attention layer
        if getattr(self, "encoder_attn", None) is not None:
            with tf.name_scope(self.encoder_attn.name):
                self.encoder_attn.build(None)

        # Build the cross-attention LayerNormalization
        if getattr(self, "encoder_attn_layer_norm", None) is not None:
            with tf.name_scope(self.encoder_attn_layer_norm.name):
                self.encoder_attn_layer_norm.build([None, None, self.embed_dim])

        # Build the first feed-forward layer
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])

        # Build the second feed-forward layer; its input dimension is decoder_ffn_dim
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.decoder_ffn_dim])

        # Build the final LayerNormalization
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
# TFMarianPreTrainedModel inherits from TFPreTrainedModel
class TFMarianPreTrainedModel(TFPreTrainedModel):
    # Configuration class for this model
    config_class = MarianConfig
    # Prefix used for the base model's weights
    base_model_prefix = "model"


# Docstring constant MARIAN_START_DOCSTRING
MARIAN_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`MarianConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
MARIAN_GENERATION_EXAMPLE = r"""
        TF version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. Available
        models are listed [here](https://huggingface.co/models?search=Helsinki-NLP).

        Examples:

        ```
        >>> from transformers import AutoTokenizer, TFMarianMTModel
        >>> from typing import List

        >>> src = "fr"  # source language
        >>> trg = "en"  # target language
        >>> sample_text = "où est l'arrêt de bus ?"
        >>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"

        >>> model = TFMarianMTModel.from_pretrained(model_name)
        >>> tokenizer = AutoTokenizer.from_pretrained(model_name)
        >>> batch = tokenizer([sample_text], return_tensors="tf")
        >>> gen = model.generate(**batch)
        >>> tokenizer.batch_decode(gen, skip_special_tokens=True)
        "Where is the bus stop ?"
        ```
"""

MARIAN_INPUTS_DOCSTRING = r"""
"""

# Custom Keras layer TFMarianEncoder, marked as serializable
@keras_serializable
class TFMarianEncoder(keras.layers.Layer):
    # Configuration class for this layer
    config_class = MarianConfig

    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`TFMarianEncoderLayer`].

    Args:
        config: MarianConfig
    """
    
    # Initializer
    def __init__(self, config: MarianConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config  # store the config
        self.dropout = keras.layers.Dropout(config.dropout)  # dropout layer with the configured rate
        self.layerdrop = config.encoder_layerdrop  # encoder LayerDrop probability
        self.padding_idx = config.pad_token_id  # index of the padding token
        self.max_source_positions = config.max_position_embeddings  # maximum number of source positions
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0  # embedding scale factor

        self.embed_tokens = embed_tokens  # token embeddings
        self.embed_positions = TFMarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )  # sinusoidal positional embeddings

        # Stack of Transformer encoder layers
        self.layers = [TFMarianEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]

    # Return the token embeddings
    def get_embed_tokens(self):
        return self.embed_tokens

    # Set the token embeddings
    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens

    # Unpack the inputs and run the forward pass
    @unpack_inputs
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ):
        # Encoder forward pass logic
        ...

    # build method: constructs the encoder's sublayers for a given input shape
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        # Mark the layer as built
        self.built = True

        # Build the positional embeddings inside their own name scope
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)

        # Build each encoder layer inside its own name scope
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
@keras_serializable
class TFMarianDecoder(keras.layers.Layer):
    # Configuration class for this layer
    config_class = MarianConfig
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFMarianDecoderLayer`]

    Args:
        config: MarianConfig
        embed_tokens: output embedding
    """

    def __init__(self, config: MarianConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config  # store the config
        self.padding_idx = config.pad_token_id  # index of the padding token
        self.embed_tokens = embed_tokens  # token embeddings
        self.layerdrop = config.decoder_layerdrop  # decoder LayerDrop probability
        self.embed_positions = TFMarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )  # sinusoidal positional embeddings
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0  # scale factor, set by scale_embedding
        self.layers = [TFMarianDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]  # stack of decoder layers

        self.dropout = keras.layers.Dropout(config.dropout)  # dropout layer

    def get_embed_tokens(self):
        return self.embed_tokens  # return the token embeddings

    def set_embed_tokens(self, embed_tokens):
        self.embed_tokens = embed_tokens  # set the token embeddings

    @unpack_inputs
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        position_ids: tf.Tensor | None = None,
        encoder_hidden_states: tf.Tensor | None = None,
        encoder_attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ):
        # Forward pass of the Transformer decoder
        ...

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)  # 构建位置嵌入对象
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)  # 逐层构建解码器层对象


@keras_serializable
class TFMarianMainLayer(keras.layers.Layer):
    config_class = MarianConfig
    # Initializer: takes a MarianConfig object and extra keyword arguments
    def __init__(self, config: MarianConfig, **kwargs):
        # Call the parent initializer
        super().__init__(**kwargs)

        # Store the config
        self.config = config

        # Embedding layer shared between the encoder and the decoder
        self.shared = keras.layers.Embedding(
            input_dim=config.vocab_size,   # vocabulary size
            output_dim=config.d_model,     # model dimension
            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),  # truncated-normal initialization
            name="model.shared",
        )
        
        # Additional attribute recording the layer's expected name scope (for loading/storing weights)
        self.shared.load_weight_prefix = "model.shared"

        # Encoder, sharing the embedding layer
        self.encoder = TFMarianEncoder(config, self.shared, name="encoder")

        # Decoder, sharing the embedding layer
        self.decoder = TFMarianDecoder(config, self.shared, name="decoder")

    # Return the input embedding layer
    def get_input_embeddings(self):
        return self.shared

    # Replace the input embedding layer and propagate it to the encoder and decoder
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings  # update the shared embedding layer
        self.encoder.embed_tokens = self.shared  # update the encoder's embeddings
        self.decoder.embed_tokens = self.shared  # update the decoder's embeddings

    # call method decorated with unpack_inputs: the model's forward pass
    @unpack_inputs
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        decoder_input_ids: tf.Tensor | None = None,
        decoder_attention_mask: tf.Tensor | None = None,
        decoder_position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        decoder_head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] = None,
        inputs_embeds: tf.Tensor | None = None,
        decoder_inputs_embeds: tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
        **kwargs,
    ):
        # Do not use the cache when no decoder inputs are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            use_cache = False

        # Fall back to the config default when output_hidden_states is not set
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Run the encoder forward pass when no encoder outputs were supplied
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                training=training,
            )
        # If the user passed a tuple for encoder_outputs, wrap it in a TFBaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, TFBaseModelOutput):
            encoder_outputs = TFBaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        # If the user passed a TFBaseModelOutput for encoder_outputs, unwrap it to a tuple when return_dict=False
        elif not return_dict and not isinstance(encoder_outputs, tuple):
            encoder_outputs = encoder_outputs.to_tuple()

        # Run the decoder
        decoder_outputs = self.decoder(
            decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # With return_dict=False, return decoder and encoder outputs as one tuple
        if not return_dict:
            return decoder_outputs + encoder_outputs

        # With return_dict=True, assemble a TFSeq2SeqModelOutput from both halves
        return TFSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
    def build(self, input_shape=None):
        # Return early if the model was already built
        if self.built:
            return
        # Mark the model as built
        self.built = True

        # The shared/tied weights expect to be in the model base namespace.
        # Adding "/" to the end (not the start!) of a name scope puts it in the
        # root namespace rather than the current one.
        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
            # Build the shared embedding layer
            self.shared.build(None)

        # Build the encoder inside its own name scope, if present
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)

        # Build the decoder inside its own name scope, if present
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)
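
The tuple-vs-`TFBaseModelOutput` normalization in the call above can be illustrated in isolation (a sketch, not library code, with toy shapes): a bare tuple from a manual encoder call is wrapped so downstream code can use attribute access.

import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput

encoder_outputs = (tf.zeros((1, 3, 16)),)  # bare tuple, e.g. from a manual encoder call
wrapped = TFBaseModelOutput(
    last_hidden_state=encoder_outputs[0],
    hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
    attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
print(wrapped.last_hidden_state.shape)  # (1, 3, 16), now reachable by attribute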
# Bare MARIAN model outputting raw hidden states without any task-specific head.
# TensorFlow implementation, inheriting from TFMarianPreTrainedModel.
class TFMarianModel(TFMarianPreTrainedModel):
    def __init__(self, config: MarianConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        # Main layer of the model, named "model"
        self.model = TFMarianMainLayer(config, name="model")

    # Return the model's encoder
    def get_encoder(self):
        return self.model.encoder

    # Return the model's decoder
    def get_decoder(self):
        return self.model.decoder

    # Forward call: accepts many optional inputs and forwards them to the main layer
    @unpack_inputs
    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSeq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: tf.Tensor | None = None,
        attention_mask: tf.Tensor | None = None,
        decoder_input_ids: tf.Tensor | None = None,
        decoder_attention_mask: tf.Tensor | None = None,
        decoder_position_ids: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        decoder_head_mask: tf.Tensor | None = None,
        cross_attn_head_mask: tf.Tensor | None = None,
        encoder_outputs: tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
        inputs_embeds: tf.Tensor | None = None,
        decoder_inputs_embeds: tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ) -> Tuple[tf.Tensor] | TFSeq2SeqModelOutput:
        # Pass all arguments through to the main layer's call and collect its outputs
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Return the model outputs
        return outputs

    # Copied from transformers.models.bart.modeling_tf_bart.TFBartModel.serving_output
    # Post-process the model output for serving
    def serving_output(self, output):
        # Keep the past key/values only when the config enables caching
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # Convert decoder hidden states to a tensor when hidden-state output is enabled
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # Convert decoder attentions to a tensor when attention output is enabled
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # Convert cross-attentions to a tensor when attention output is enabled
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # Convert encoder hidden states to a tensor when hidden-state output is enabled
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # Convert encoder attentions to a tensor when attention output is enabled
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # Return a TFSeq2SeqModelOutput built from the processed values
        return TFSeq2SeqModelOutput(
            last_hidden_state=output.last_hidden_state,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # build method for the model
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        # Mark as built
        self.built = True
        # Build the main layer inside its own name scope
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        # Register a weight holding this layer's bias, so it is serialized with the model
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        # Add the bias to the input tensor x
        return x + self.bias
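
A minimal usage sketch (not library code; it assumes the `BiasLayer` defined above is in scope): the layer simply adds its registered weight to the input, which is exactly what lets the final-logits bias survive `save_weights`.

import tensorflow as tf

bias_layer = BiasLayer(name="final_logits_bias", shape=[1, 5], initializer="zeros", trainable=False)
logits = tf.ones((2, 5))
print(bias_layer(logits).numpy())  # identical to the input: the bias starts at zero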


@add_start_docstrings(
    "The MARIAN Model with a language modeling head. Can be used for summarization.",
    MARIAN_START_DOCSTRING,
)
class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
    _keys_to_ignore_on_load_unexpected = [
        r"model.encoder.embed_tokens.weight",
        r"model.decoder.embed_tokens.weight",
    ]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # 创建 MARIAN 模型的主体部分,并命名为 "model"
        self.model = TFMarianMainLayer(config, name="model")
        self.use_cache = config.use_cache
        # 创建一个偏置层 BiasLayer,用于模型的最终 logits 的偏置,保持不可训练状态以保持一致性
        # 这里的 final_logits_bias 在 PyTorch 中作为缓冲区注册,因此在 TensorFlow 中保持不可训练以便正确序列化
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )

    def get_decoder(self):
        # Return the model's decoder
        return self.model.decoder

    def get_encoder(self):
        # Return the model's encoder
        return self.model.encoder

    def get_output_embeddings(self):
        # The output embeddings are tied to the input embeddings
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        # Setting the output embeddings updates the (tied) input embeddings
        self.set_input_embeddings(value)

    def get_bias(self):
        # Return the model's bias terms; here only the final-logits bias
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        # Replace the existing bias layer so layers containing a bias (de)serialize correctly
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        # Assign the new bias values
        self.bias_layer.bias.assign(value["final_logits_bias"])

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
    # `call` runs the forward pass / inference; each argument below may be None.
    def call(
        self,
        input_ids: tf.Tensor | None = None,  # input token IDs
        attention_mask: tf.Tensor | None = None,  # mask selecting the tokens the model attends to
        decoder_input_ids: tf.Tensor | None = None,  # decoder input token IDs
        decoder_attention_mask: tf.Tensor | None = None,  # mask selecting the tokens the decoder attends to
        decoder_position_ids: tf.Tensor | None = None,  # decoder position IDs
        head_mask: tf.Tensor | None = None,  # mask disabling selected encoder attention heads
        decoder_head_mask: tf.Tensor | None = None,  # mask disabling selected decoder attention heads
        cross_attn_head_mask: tf.Tensor | None = None,  # mask disabling selected cross-attention heads
        encoder_outputs: TFBaseModelOutput | None = None,  # precomputed encoder outputs
        past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,  # cached key/value states for autoregressive generation
        inputs_embeds: tf.Tensor | None = None,  # embeddings passed instead of input token IDs
        decoder_inputs_embeds: tf.Tensor | None = None,  # embeddings passed instead of decoder input token IDs
        use_cache: bool | None = None,  # whether to use the key/value cache
        output_attentions: bool | None = None,  # whether to return attention weights
        output_hidden_states: bool | None = None,  # whether to return hidden states
        return_dict: bool | None = None,  # whether to return a structured output instead of a tuple
        labels: tf.Tensor | None = None,  # training labels
        training: bool = False,  # whether the model runs in training mode
    ) -> Tuple[tf.Tensor] | TFSeq2SeqLMOutput:
        r"""
        labels (`tf.tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            Depending on `return_dict`, either a tuple or `TFSeq2SeqLMOutput`.

        """

        # Adjust labels for masked language modeling loss computation
        if labels is not None:
            labels = tf.where(
                labels == self.config.pad_token_id,
                tf.fill(shape_list(labels), tf.cast(-100, labels.dtype)),
                labels,
            )
            # Caching cannot be used when a loss is computed from labels
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Shift labels to the right for decoder input
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Forward pass through the model with specified inputs and parameters
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Compute logits for masked language modeling
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        # Return outputs either as a tuple or TFSeq2SeqLMOutput based on `return_dict`
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return TFSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,  # Past key values from model outputs
            decoder_hidden_states=outputs.decoder_hidden_states,  # Decoder hidden states from model outputs
            decoder_attentions=outputs.decoder_attentions,  # Decoder attentions from model outputs
            cross_attentions=outputs.cross_attentions,  # Cross attentions from model outputs
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # Encoder last hidden state from encoder outputs
            encoder_hidden_states=outputs.encoder_hidden_states,  # Encoder hidden states from encoder outputs
            encoder_attentions=outputs.encoder_attentions,  # Encoder attentions from encoder outputs
        )
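
The `tf.matmul(..., transpose_b=True)` in the logits computation above is the weight-tying trick: the decoder output is multiplied against the shared embedding matrix, so no separate output projection is needed. A minimal sketch with toy shapes (not library code):

import tensorflow as tf

vocab_size, d_model = 10, 4
shared_weights = tf.random.normal((vocab_size, d_model))  # stands in for model.shared's matrix
hidden_states = tf.random.normal((2, 3, d_model))         # stands in for the decoder output
lm_logits = tf.matmul(hidden_states, shared_weights, transpose_b=True)
print(lm_logits.shape)  # (2, 3, 10): one logit per vocabulary entry, no extra projection weights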

    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.serving_output
    # Post-process the model output for serving
    def serving_output(self, output):
        # Keep the past key/values only when the config enables caching
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # Convert decoder hidden states to a tensor when hidden-state output is enabled
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # Convert decoder attentions to a tensor when attention output is enabled
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # Convert cross-attentions to a tensor when attention output is enabled
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # Convert encoder hidden states to a tensor when hidden-state output is enabled
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # Convert encoder attentions to a tensor when attention output is enabled
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # Return a TFSeq2SeqLMOutput with logits, cache, attentions and hidden states
        return TFSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # When a cache is available, only the last decoder input ID is needed
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        # Compute decoder position IDs: from the attention mask (XLA), from the cache length, or from the sequence length
        if decoder_attention_mask is not None:  # xla
            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
        elif past_key_values is not None:  # no xla + past_key_values
            decoder_position_ids = past_key_values[0][0].shape[2]
        else:  # no xla + no past_key_values
            decoder_position_ids = tf.range(decoder_input_ids.shape[1])

        # Return everything the generation loop needs for the next forward pass
        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_position_ids": decoder_position_ids,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    # Build decoder input IDs from the labels: shift right, handling pad_token_id and decoder_start_token_id
    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
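
The XLA branch above derives position IDs from the attention mask; an exclusive cumsum gives each token its zero-based position, and the last column is the position of the token currently being generated. A minimal sketch (not library code):

import tensorflow as tf

decoder_attention_mask = tf.constant([[1, 1, 1, 0]])
positions = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)
print(positions.numpy())          # [[0 1 2 3]]
print(positions[:, -1:].numpy())  # [[3]] -- the position used for the token being generated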
    # build method: returns immediately when already built
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        # Mark as built
        self.built = True
        # Build the nested main layer inside its own name scope
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        # Build the bias layer inside its own name scope
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)