Transformers Source Code Analysis (Part 13)

.\models\bart\modeling_flax_bart.py

# coding=utf-8
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Bart model."""

import math  # math utilities (used for the embedding scale)
import random  # Python-level RNG (used for LayerDrop)
from functools import partial  # partial application helper
from typing import Callable, Optional, Tuple  # type hints

import flax.linen as nn  # Flax's Linen module system
import jax  # JAX core
import jax.numpy as jnp  # JAX's NumPy-compatible API
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze  # frozen parameter dictionaries
from flax.linen import combine_masks, make_causal_mask  # mask-construction helpers
from flax.linen.attention import dot_product_attention_weights  # attention-weight computation
from flax.traverse_util import flatten_dict, unflatten_dict  # flatten/unflatten nested dicts
from jax import lax  # low-level JAX primitives
from jax.random import PRNGKey  # pseudo-random number generator key type

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,  # base model output
    FlaxBaseModelOutputWithPastAndCrossAttentions,  # base model output with past key values and cross-attentions
    FlaxCausalLMOutputWithCrossAttentions,  # causal LM output with cross-attentions
    FlaxSeq2SeqLMOutput,  # sequence-to-sequence LM output
    FlaxSeq2SeqModelOutput,  # sequence-to-sequence model output
    FlaxSeq2SeqQuestionAnsweringModelOutput,  # sequence-to-sequence question-answering output
    FlaxSeq2SeqSequenceClassifierOutput,  # sequence-to-sequence sequence-classifier output
)
from ...modeling_flax_utils import (
    ACT2FN,  # mapping from activation-function names to functions
    FlaxPreTrainedModel,  # base class for Flax pretrained models
    append_call_sample_docstring,  # appends a call example to a docstring
    append_replace_return_docstrings,  # appends/replaces return docstrings
    overwrite_call_docstring,  # overwrites a call docstring
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings  # docstring and logging utilities
from .configuration_bart import BartConfig  # BART configuration

logger = logging.get_logger(__name__)  # module-level logger

_CHECKPOINT_FOR_DOC = "facebook/bart-base"  # checkpoint referenced in the docs
_CONFIG_FOR_DOC = "BartConfig"  # config class referenced in the docs

BART_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BartConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of the
            model parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

# BART inputs docstring
BART_INPUTS_DOCSTRING = r"""
"""


# BART encoder inputs docstring
BART_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# BART decoder inputs docstring
BART_DECODE_INPUTS_DOCSTRING = r"""
"""


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])  # shift everything one position to the right
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)  # put the decoder start token in front

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)  # replace -100 (label padding) with the pad token
    return shifted_input_ids
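
# A quick illustration of the shift (hypothetical toy values, not from the source):
#
#   input_ids = jnp.array([[10, 11, 12, 2]])   # 2 = eos
#   shift_tokens_right(input_ids, pad_token_id=1, decoder_start_token_id=2)
#   # -> [[ 2, 10, 11, 12]]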


class FlaxBartAttention(nn.Module):
    """
    Multi-head attention layer used by both the BART encoder and decoder.
    """
    config: BartConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        # per-head dimension of the attention
        self.head_dim = self.embed_dim // self.num_heads
        # embed_dim must be evenly divisible by num_heads, otherwise raise a ValueError
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # partially apply nn.Dense so the q/k/v/out projections share the same settings
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # query, key, value and output projections
        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        # dropout applied to the attention weights
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        # for causal (decoder self-)attention, build the causal mask once
        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    # split the hidden states into (num_heads, head_dim)
    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    # merge the heads back into a single embed_dim axis
    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    # nn.compact lets this method create its own variables (the decoding cache)
    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect whether we are initializing by checking for the existence of the "cache" variables
        is_initialized = self.has_variable("cache", "cached_key")
        # fetch (or zero-initialize) the cached keys and values
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        # fetch (or initialize) the cache index, the position at which this update starts
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            # shape of the cached tensors: batch dims, max sequence length, heads, depth per head
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update the key/value caches with the new 1d spatial slice
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            # write the updated keys and values back into the cache
            cached_key.value = key
            cached_value.value = value
            # advance the cache index by the number of vectors just written
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for the cached keys: each query position may only attend to key
            # positions that have already been generated and cached, not the remaining zeros
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            # combine the existing attention mask with the generated causal mask
            attention_mask = combine_masks(pad_mask, attention_mask)

        # return the updated keys, values and attention mask
        return key, value, attention_mask
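
# Sketch of the cache write above: `lax.dynamic_update_slice` overwrites a slice of the
# pre-allocated cache starting at the current index (hypothetical toy shapes):
#
#   cache = jnp.zeros((1, 4, 1, 2))      # (batch, max_length, heads, head_dim)
#   new_kv = jnp.ones((1, 1, 1, 2))      # one freshly projected token
#   cache = lax.dynamic_update_slice(cache, new_kv, (0, 2, 0, 0))  # write at position 2
#   # cache[0, 2] is now all ones; positions 0, 1 and 3 are untouched

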
class FlaxBartEncoderLayer(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # embedding dimension of the encoder layer, taken from the model config
        self.embed_dim = self.config.d_model
        # self-attention block
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # layer norm applied after self-attention
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # dropout on the residual branches
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # activation function selected by the config
        self.activation_fn = ACT2FN[self.config.activation_function]
        # dropout applied after the activation
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
        # first feed-forward layer, projecting up to the encoder FFN dimension
        self.fc1 = nn.Dense(
            self.config.encoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # second feed-forward layer, projecting back to the embedding dimension
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # final layer norm
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        # keep the input for the residual connection
        residual = hidden_states
        # self-attention: new hidden states plus attention weights
        hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)

        # dropout on the attention output
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # add the residual connection
        hidden_states = residual + hidden_states
        # layer norm after self-attention
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # feed-forward block with its own residual connection
        residual = hidden_states
        # first dense layer followed by the activation function
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # dropout after the activation
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # second dense layer
        hidden_states = self.fc2(hidden_states)
        # dropout on the feed-forward output
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # add the residual connection
        hidden_states = residual + hidden_states
        # final layer norm
        hidden_states = self.final_layer_norm(hidden_states)

        # the hidden states are always the first element of the output
        outputs = (hidden_states,)

        # optionally also return the attention weights
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
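
# A minimal smoke test of one encoder layer (a sketch; the config values below are
# illustrative, not from the source):
#
#   from transformers import BartConfig
#
#   config = BartConfig(d_model=16, encoder_attention_heads=4, encoder_ffn_dim=32)
#   layer = FlaxBartEncoderLayer(config=config)
#   hidden = jnp.zeros((2, 5, 16))                 # (batch, seq_len, d_model)
#   mask = jnp.ones((2, 5))                        # attention mask over the 5 tokens
#   variables = layer.init(jax.random.PRNGKey(0), hidden, mask)
#   out, attn_weights = layer.apply(variables, hidden, mask)   # out: (2, 5, 16)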


class FlaxBartEncoderLayerCollection(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # one FlaxBartEncoderLayer per encoder layer, each named by its index
        self.layers = [
            FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
        ]
        # probability of dropping an entire layer (LayerDrop)
        self.layerdrop = self.config.encoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # tuples collecting attention weights / hidden states, if requested
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                # record the hidden states entering this layer
                all_hidden_states = all_hidden_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            # during training, skip the whole layer with probability `layerdrop`
            if not deterministic and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                # otherwise run the encoder layer
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
            # the first element of the layer output is the new hidden states
            hidden_states = layer_outputs[0]
            # optionally record this layer's attention weights
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # append the hidden states of the last encoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        # without return_dict, drop the None entries and return a tuple
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # dict-like base model output
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxBartDecoderLayer(nn.Module):
    config: BartConfig  # model configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        # embedding dimension, taken from the config's d_model
        self.embed_dim = self.config.d_model
        # causal self-attention over the decoder inputs
        self.self_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        # dropout on the residual branches
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
        # activation function selected by the config
        self.activation_fn = ACT2FN[self.config.activation_function]
        # dropout applied after the activation
        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)

        # layer norm after self-attention
        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # encoder-decoder (cross-)attention
        self.encoder_attn = FlaxBartAttention(
            config=self.config,
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        # layer norm after cross-attention
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        # first feed-forward layer, projecting up to the decoder FFN dimension
        self.fc1 = nn.Dense(
            self.config.decoder_ffn_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # second feed-forward layer, projecting back to the embedding dimension
        self.fc2 = nn.Dense(
            self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # final layer norm
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # forward pass of the decoder layer
    def __call__(
        self,
        hidden_states: jnp.ndarray,  # decoder hidden states
        attention_mask: jnp.ndarray,  # decoder self-attention mask
        encoder_hidden_states: Optional[jnp.ndarray] = None,  # encoder hidden states for cross-attention
        encoder_attention_mask: Optional[jnp.ndarray] = None,  # encoder attention mask
        init_cache: bool = False,  # whether to initialize the auto-regressive cache
        output_attentions: bool = True,  # whether to return attention weights
        deterministic: bool = True,  # disable dropout when True
    ) -> Tuple[jnp.ndarray]:
        # keep the input for the residual connection
        residual = hidden_states

        # causal self-attention (optionally writing into the decoding cache)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
        )
        # dropout on the attention output
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # add the residual connection
        hidden_states = residual + hidden_states
        # layer norm after self-attention
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # cross-attention block (only when encoder hidden states are given)
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            # keep the current hidden states for the residual connection
            residual = hidden_states

            # attend over the encoder outputs
            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            # dropout on the cross-attention output
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            # add the residual connection
            hidden_states = residual + hidden_states
            # layer norm after cross-attention
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # feed-forward block with its own residual connection
        residual = hidden_states
        # first dense layer followed by the activation function
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        # dropout after the activation
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        # second dense layer
        hidden_states = self.fc2(hidden_states)
        # dropout on the feed-forward output
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        # add the residual connection
        hidden_states = residual + hidden_states
        # final layer norm
        hidden_states = self.final_layer_norm(hidden_states)

        # the hidden states are always the first element of the output
        outputs = (hidden_states,)

        # optionally also return the self- and cross-attention weights
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


class FlaxBartDecoderLayerCollection(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # one FlaxBartDecoderLayer per decoder layer, each named by its index
        self.layers = [
            FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
        ]
        # probability of dropping an entire layer (LayerDrop)
        self.layerdrop = self.config.decoder_layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # tuples collecting hidden states / self-attentions / cross-attentions, if requested
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                # record the hidden states entering this layer
                all_hidden_states += (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            # during training, skip the whole layer with probability `layerdrop`
            if not deterministic and (dropout_probability < self.layerdrop):
                layer_outputs = (None, None, None)
            else:
                # otherwise run the decoder layer
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    init_cache=init_cache,
                    output_attentions=output_attentions,
                    deterministic=deterministic,
                )

            # the first element of the layer output is the new hidden states
            hidden_states = layer_outputs[0]
            if output_attentions:
                # record the self-attention weights
                all_self_attns += (layer_outputs[1],)

                # record the cross-attention weights when encoder states were given
                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # append the hidden states of the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # assemble the outputs, returning a tuple or a dict-like output class
        outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        # dict-like model output with past key values and cross-attentions
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


class FlaxBartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    config: BartConfig
    inner_dim: int
    num_classes: int
    pooler_dropout: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # dense projection to `inner_dim`, weights drawn from a normal distribution
        # with standard deviation `config.init_std`
        self.dense = nn.Dense(
            self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )
        # dropout with the pooler's dropout rate
        self.dropout = nn.Dropout(rate=self.pooler_dropout)
        # output projection to `num_classes` logits
        self.out_proj = nn.Dense(
            self.num_classes,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
        # dropout -> dense -> tanh -> dropout -> output projection
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.dense(hidden_states)
        hidden_states = jnp.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
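
# A minimal standalone sketch of the head (hypothetical dimensions; `config` as in the
# earlier encoder-layer sketch):
#
#   head = FlaxBartClassificationHead(config=config, inner_dim=16, num_classes=3, pooler_dropout=0.0)
#   x = jnp.zeros((2, 16))
#   variables = head.init(jax.random.PRNGKey(0), x, deterministic=True)
#   logits = head.apply(variables, x, deterministic=True)   # shape (2, 3)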


class FlaxBartEncoder(nn.Module):
    config: BartConfig
    embed_tokens: nn.Embed  # shared token-embedding table
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # dropout on the embedding output
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        # index of the padding token
        self.padding_idx = self.config.pad_token_id
        # maximum source sequence length
        self.max_source_positions = self.config.max_position_embeddings
        # scale the embeddings by sqrt(d_model) if the config asks for it
        self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

        # Bart is set up so that if padding_idx is specified then the embedding ids are
        # offset by 2 and num_embeddings is adjusted accordingly. Other models don't have
        # this hack.
        self.offset = 2
        # learned position embeddings, sized for the offset
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,  # maximum positions plus the offset
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),  # normal initialization
            dtype=self.dtype,
        )
        # stack of encoder layers
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        # layer norm applied to the embedding output
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
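
    # Note on the offset: position id p is looked up at row p + 2 of the embedding table,
    # so rows 0 and 1 are effectively reserved (kept for compatibility with the original
    # fairseq checkpoint layout). Sketch:
    #
    #   position_ids = jnp.arange(seq_len)                             # 0, 1, 2, ...
    #   embed_pos = self.embed_positions(position_ids + self.offset)   # rows 2, 3, 4, ...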

    # forward pass of the encoder
    def __call__(
        self,
        input_ids,  # input token ids
        attention_mask,  # attention mask
        position_ids,  # position ids
        output_attentions: bool = False,  # whether to return attention weights
        output_hidden_states: bool = False,  # whether to return all hidden states
        return_dict: bool = True,  # whether to return a dict-like output class
        deterministic: bool = True,  # disable dropout when True
    ):
        # flatten the input ids to (batch, seq_len)
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        # look up the token embeddings and apply the embedding scale
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # look up the position embeddings (shifted by the offset)
        embed_pos = self.embed_positions(position_ids + self.offset)

        # initial hidden states: token embeddings plus position embeddings
        hidden_states = inputs_embeds + embed_pos
        # layer norm on the embedding output
        hidden_states = self.layernorm_embedding(hidden_states)
        # dropout on the embedding output
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        # run the stack of encoder layers
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # without return_dict, pass the tuple straight through
        if not return_dict:
            return outputs

        # dict-like base model output
        return FlaxBaseModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxBartDecoder(nn.Module):
    config: BartConfig
    embed_tokens: nn.Embed  # shared token-embedding table
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # dropout on the embedding output
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        # index of the padding token and maximum target sequence length
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        # scale the embeddings by sqrt(d_model) if the config asks for it
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

        # Bart is set up so that if padding_idx is specified then the embedding ids are
        # offset by 2 and num_embeddings is adjusted accordingly. Other models don't have
        # this hack.
        self.offset = 2
        # learned position embeddings, sized for the offset
        self.embed_positions = nn.Embed(
            self.config.max_position_embeddings + self.offset,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )

        # stack of decoder layers
        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        # layer norm applied to the embedding output
        self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    # forward pass of the decoder
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # flatten the input ids to (batch, seq_len)
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        # look up the token embeddings and apply the embedding scale
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # look up the position embeddings (shifted by the offset)
        positions = self.embed_positions(position_ids + self.offset)

        # initial hidden states: token embeddings plus position embeddings
        hidden_states = inputs_embeds + positions
        # layer norm on the embedding output
        hidden_states = self.layernorm_embedding(hidden_states)

        # dropout on the embedding output
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        # run the stack of decoder layers
        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # without return_dict, pass the tuple straight through
        if not return_dict:
            return outputs

        # dict-like output including past key values and cross-attentions
        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


class FlaxBartModule(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )

        # encoder and decoder share the token-embedding table
        self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
        self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)

    def _get_encoder_module(self):
        return self.encoder

    def _get_decoder_module(self):
        return self.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # run the encoder
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # run the decoder, cross-attending over the encoder outputs
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # without return_dict, concatenate the decoder and encoder tuples
        if not return_dict:
            return decoder_outputs + encoder_outputs

        # dict-like seq2seq output
        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # instantiate the underlying module with the given config
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # initialize the model parameters with the given rng, input shape and optional params
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # dummy integer input tensor of zeros
        input_ids = jnp.zeros(input_shape, dtype="i4")
        # make sure the initialization pass works for FlaxBartForSequenceClassificationModule:
        # set the last position of input_ids to the eos token id
        input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
        # all-ones attention mask of the same shape
        attention_mask = jnp.ones_like(input_ids)
        # for initialization, the decoder inputs simply mirror the encoder inputs
        decoder_input_ids = input_ids
        decoder_attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        # position ids 0..sequence_length-1, broadcast over the batch
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # split the rng into one stream for the parameters and one for dropout
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # initialize random parameters via the module's init
        random_params = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            position_ids,
            decoder_position_ids,
        )["params"]

        # if initial params were passed in, merge the randomly initialized ones into them
        if params is not None:
            # flatten and unfreeze both parameter trees
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            # copy any missing keys over from the random parameters
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            # freeze and return the merged parameters
            return freeze(unflatten_dict(params))
        else:
            # otherwise return the randomly initialized parameters as-is
            return random_params

    def init_cache(self, batch_size, max_length, encoder_outputs):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
                `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
                `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
                is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
        """
        # dummy decoder inputs used to trace the cache shapes
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        # initialize the module variables, requesting the cache collection
        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],  # last hidden state of the encoder
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        # unfreeze and return only the cache variables
        return unfreeze(init_variables["cache"])
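
    # Typical use during fast auto-regressive decoding (a sketch; `model` is any
    # FlaxBartPreTrainedModel subclass and `encoder_outputs` comes from `model.encode`):
    #
    #   encoder_outputs = model.encode(input_ids=input_ids, attention_mask=attention_mask)
    #   past_key_values = model.init_cache(input_ids.shape[0], max_length=64,
    #                                      encoder_outputs=encoder_outputs)
    #   # `past_key_values` is then passed to `decode(...)` one new token at a time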

    @add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig)
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```
        >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

        >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)
        ```"""
        # fall back to the config defaults when the output flags are not given
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # default to an all-ones attention mask
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        # default to position ids 0..sequence_length-1, broadcast over the batch
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # handle any PRNG needed for dropout
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # forward function that only runs the encoder module
        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        # apply the Flax module, dispatching to the encoder-only forward
        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # decoding method for auto-regressive generation (body omitted in this walkthrough);
        # see FlaxBartForConditionalGeneration.decode below
        pass

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # fall back to the config defaults when the output flags are not given
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare the encoder inputs
        if attention_mask is None:
            # default to an all-ones attention mask of the input shape
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            # default to position ids 0..sequence_length-1, broadcast over the batch
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # prepare the decoder inputs
        if decoder_input_ids is None:
            # default decoder inputs: the encoder inputs shifted right,
            # starting with the decoder start token
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:
            # default to an all-ones mask of the decoder input shape
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            # default decoder position ids, broadcast over the batch
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # handle any PRNG needed for dropout
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},  # use the passed params, else the model's own
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,  # deterministic unless training
            rngs=rngs,
        )


@add_start_docstrings(
    "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class FlaxBartModel(FlaxBartPreTrainedModel):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    module_class = FlaxBartModule


# attach a call example to FlaxBartModel
append_call_sample_docstring(FlaxBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
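
# End-to-end usage of the base model (mirrors the attached call example, using the
# documented checkpoint):
#
#   from transformers import AutoTokenizer, FlaxBartModel
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
#   model = FlaxBartModel.from_pretrained("facebook/bart-base")
#   inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
#   outputs = model(**inputs)
#   last_hidden_states = outputs.last_hidden_state   # (batch, seq_len, d_model)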


class FlaxBartForConditionalGenerationModule(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        # the underlying seq2seq model
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        # LM head projecting onto the shared vocabulary size, without bias
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # learned bias of shape (1, vocab_size) added to the final logits
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

    # return the encoder submodule
    def _get_encoder_module(self):
        return self.model.encoder

    # return the decoder submodule
    def _get_decoder_module(self):
        return self.model.decoder

    # forward pass of the conditional-generation module
    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # run the underlying seq2seq model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        # last decoder hidden states
        hidden_states = outputs[0]

        # with tied word embeddings, reuse the shared embedding matrix as the LM-head kernel
        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            # otherwise use the LM head's own kernel
            lm_logits = self.lm_head(hidden_states)

        # add the final logits bias (excluded from gradients)
        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))

        # without return_dict, return a tuple of logits plus the remaining outputs
        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        # dict-like seq2seq LM output with logits, hidden states and attentions
        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
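
# Weight tying above: with `tie_word_embeddings`, the decoder hidden states are projected
# onto the vocabulary with the *transposed* shared embedding matrix, so no separate
# LM-head kernel is learned. A toy equivalent (hypothetical shapes):
#
#   vocab, d_model = 50265, 768
#   embedding = jnp.zeros((vocab, d_model))   # plays the role of model.shared.embedding
#   hidden = jnp.zeros((1, 7, d_model))
#   lm_logits = hidden @ embedding.T          # shape (1, 7, vocab)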


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
    module_class = FlaxBartForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        # decoding method for auto-regressive generation (body omitted in this walkthrough)
        pass

    # prepare the inputs needed by the generation loop (cache, masks, position ids)
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initialize the cache that will hold past key/value states
        batch_size, seq_length = decoder_input_ids.shape
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

        # Usually one would have to put 0's in the attention mask for positions beyond the
        # current input, but since the decoder uses a causal mask those positions are masked
        # anyway. So a static all-ones mask is created here and the given
        # decoder_attention_mask is written into its front slice.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # dictionary of inputs consumed by the generation loop
        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    # after each generation step: carry the cache forward and advance the decoder
    # position ids by one slot
    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
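
# These two hooks are what `generate` drives under the hood: the cache and masks are
# prepared once, then the position ids advance one slot per decoding step. From user
# code the whole machinery is triggered simply by (a sketch):
#
#   outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=20)
#   summary_ids = outputs.sequences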


# docstring for Flax BART conditional generation, with summarization and
# mask-filling usage examples
FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:

    Summarization example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"]).sequences
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
    ```

    Mask filling example:

    ```python
    >>> import jax
    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

    >>> TXT = "My friends are <mask> but they eat too many carbs."
    >>> input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]

    >>> # per-token vocabulary scores
    >>> logits = model(input_ids).logits
    >>> # locate the masked position
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
    >>> # softmax over the vocabulary at the masked position
    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
    >>> # take the highest-probability prediction
    >>> values, predictions = jax.lax.top_k(probs, k=1)

    >>> tokenizer.decode(predictions).split()
    ```
"""
# overwrite the call docstring with the combination of the BART inputs docstring and
# the conditional-generation examples
overwrite_call_docstring(
    FlaxBartForConditionalGeneration, BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING
)
# append/replace the return docstring of FlaxBartForConditionalGeneration
append_replace_return_docstrings(
    FlaxBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBartForSequenceClassificationModule(nn.Module):
    config: BartConfig  # BART configuration
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    num_labels: Optional[int] = None  # optional number of labels

    def setup(self):
        # the underlying seq2seq model
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        # classification head applied to the pooled <eos> representation
        self.classification_head = FlaxBartClassificationHead(
            config=self.config,
            inner_dim=self.config.d_model,
            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
            pooler_dropout=self.config.classifier_dropout,
        )

    # return the encoder submodule
    def _get_encoder_module(self):
        return self.model.encoder

    # return the decoder submodule
    def _get_decoder_module(self):
        return self.model.decoder

    # forward pass: run the seq2seq model, then classify from the <eos> hidden state
    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
            # 调用模型进行推理,获取输出结果
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                position_ids=position_ids,
                decoder_position_ids=decoder_position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                deterministic=deterministic,
            )

            # 获取模型输出中的最后一个隐藏状态
            hidden_states = outputs[0]  # 最后一个隐藏状态

            # 创建一个掩码,标记输入中的 <eos> 位置
            eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

            # 处理特定的 JAX 编译错误类型,确保避免 JIT 编译中的错误
            if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
                # 检查每个示例中 <eos> 标记的数量是否一致
                if len(jnp.unique(eos_mask.sum(1))) > 1:
                    raise ValueError("所有示例必须具有相同数量的 <eos> 标记。")

                # 检查是否有示例缺少 <eos> 标记
                if any(eos_mask.sum(1) == 0):
                    raise ValueError("输入中缺少 <eos> 标记。")

                # 为每个示例保留最后一个 <eos> 标记
                eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
                eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)

            # 使用 eos_mask 对隐藏状态进行加权求和,以获得句子表示
            sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)

            # 将句子表示传递给分类头,获取分类 logits
            logits = self.classification_head(sentence_representation, deterministic=deterministic)

            # 如果不需要返回字典,则返回输出的元组
            if not return_dict:
                output = (logits,) + outputs[1:]
                return output

            # 构造 FlaxSeq2SeqSequenceClassifierOutput 对象,封装模型输出
            return FlaxSeq2SeqSequenceClassifierOutput(
                logits=logits,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )
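
上面 __call__ 中“只保留每个样本最后一个 <eos>”的掩码技巧比较隐晦:先给 eos_mask 加上随位置递增的微小噪声,再取每行最大值所在位置即可。下面是一个最小的演示性示例(玩具数据纯属假设,仅说明该技巧的效果):

import jax.numpy as jnp

hidden_states = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)  # (batch, seq, dim)
input_ids = jnp.array([[5, 2, 7, 2],   # 假设 eos_token_id = 2,第一行出现两次 <eos>
                       [9, 8, 2, 6]])
eos_mask = jnp.where(input_ids == 2, 1, 0)
# 加上与位置成正比的微小噪声,使同一行中最后一个 <eos> 的取值最大
eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
print(eos_mask)  # [[0 0 0 1], [0 0 1 0]]:每行只剩最后一个 <eos>
# 用掩码加权求和,取出对应位置的隐藏状态作为句子表示
sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
print(sentence_representation.shape)  # (2, 3)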
# 使用自定义的 docstring 添加起始注释给 FlaxBartForSequenceClassification 类,指定其用途和应用场景
@add_start_docstrings(
    """
    Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    BART_START_DOCSTRING,  # 引用预定义的 Bart 模型的起始注释
)
class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel):
    module_class = FlaxBartForSequenceClassificationModule  # 设定模型类
    dtype = jnp.float32  # 设置数据类型


# 向 FlaxBartForSequenceClassification 类添加调用样例的文档字符串
append_call_sample_docstring(
    FlaxBartForSequenceClassification,
    _CHECKPOINT_FOR_DOC,  # 引用检查点文档
    FlaxSeq2SeqSequenceClassifierOutput,  # 引用输出类文档
    _CONFIG_FOR_DOC,  # 引用配置文档
)


# 定义 FlaxBartForQuestionAnsweringModule 类,继承自 nn.Module
class FlaxBartForQuestionAnsweringModule(nn.Module):
    config: BartConfig  # 使用 BartConfig 配置
    dtype: jnp.dtype = jnp.float32  # 设置数据类型为 float32
    num_labels = 2  # 设定标签数量为 2

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)  # 使用配置和数据类型初始化模型
        self.qa_outputs = nn.Dense(  # 定义问题-回答输出层
            self.num_labels,  # 输出层标签数量
            dtype=self.dtype,  # 输出层数据类型
            kernel_init=jax.nn.initializers.normal(self.config.init_std),  # 使用正态分布初始化权重
        )

    def _get_encoder_module(self):
        return self.model.encoder  # 获取编码器模块

    def _get_decoder_module(self):
        return self.model.decoder  # 获取解码器模块

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        # 调用模型进行正向传播
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        sequence_output = outputs[0]  # 提取序列输出

        logits = self.qa_outputs(sequence_output)  # 通过问题-回答输出层计算 logits
        start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)  # 分割 logits 得到起始和结束 logits
        start_logits = start_logits.squeeze(-1)  # 压缩起始 logits 的最后一维
        end_logits = end_logits.squeeze(-1)  # 压缩结束 logits 的最后一维

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]  # 如果不返回字典,则将输出整合为元组
            return output

        # 返回字典格式的输出
        return FlaxSeq2SeqQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
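
下面用一个极小的假设性示例演示上述 start/end logits 的拆分方式(数据为全零占位,仅看形状变化):

import jax.numpy as jnp

logits = jnp.zeros((2, 5, 2))  # (batch, seq_len, num_labels=2)
start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)  # 两个 (2, 5, 1)
start_logits = start_logits.squeeze(-1)  # (2, 5)
end_logits = end_logits.squeeze(-1)      # (2, 5)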


# 使用自定义的 docstring 添加起始注释给 FlaxBartForQuestionAnswering 类,指定其用途和应用场景
@add_start_docstrings(
    """
    BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    # 在隐藏状态输出之上接一个线性层,用于计算“span 起始位置 logits”和“span 结束位置 logits”
    BART_START_DOCSTRING,
    # 使用预定义的 BART_START_DOCSTRING 常量作为文档字符串的起始部分
)

# 定义一个类,继承自FlaxBartPreTrainedModel,用于问答任务
class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel):
    # 模块类设置为FlaxBartForQuestionAnsweringModule
    module_class = FlaxBartForQuestionAnsweringModule
    # 数据类型设置为32位浮点数
    dtype = jnp.float32

# 向FlaxBartForQuestionAnswering类附加一个函数调用样例的文档字符串
append_call_sample_docstring(
    FlaxBartForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)

# 定义一个类,继承自FlaxPreTrainedModel,用于BART解码器预训练模型
class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
    # 配置类设置为BartConfig
    config_class = BartConfig
    # 基础模型前缀设置为"model"
    base_model_prefix: str = "model"
    # 模块类初始化为None
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # 设置配置为解码器模式
        config.is_decoder = True
        # 设置不是编码器-解码器模式
        config.is_encoder_decoder = False
        # 使用配置和数据类型初始化模块
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # 调用父类初始化方法
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # 初始化输入张量
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        # 获取批量大小和序列长度
        batch_size, sequence_length = input_ids.shape
        # 生成位置编码张量
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # 分割随机数生成器
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        # 初始化编码器隐藏状态和注意力掩码
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
        encoder_attention_mask = attention_mask
        # 调用模块的初始化方法
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
        # 返回模块初始化的参数
        return module_init_outputs["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                用于快速自回归解码的批量大小,定义了初始化缓存的批量大小。
            max_length (`int`):
                自回归解码的最大可能长度,定义了初始化缓存的序列长度。
        """
        # 初始化用于检索缓存的输入变量
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # 调用模块的初始化方法,设置init_cache=True以获取缓存
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        # 解冻并返回初始化变量的缓存部分
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(BART_DECODE_INPUTS_DOCSTRING)
    # 定义一个特殊方法 __call__,使得对象可以被调用
    def __call__(
        self,
        # 参数 input_ids: 接受一个 jnp.ndarray,用于输入模型的标识符
        input_ids: jnp.ndarray,
        # 参数 attention_mask: 可选参数,用于指定哪些标识符需要被注意
        attention_mask: Optional[jnp.ndarray] = None,
        # 参数 position_ids: 可选参数,用于指定输入标识符的位置信息
        position_ids: Optional[jnp.ndarray] = None,
        # 参数 encoder_hidden_states: 可选参数,编码器的隐藏状态
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        # 参数 encoder_attention_mask: 可选参数,编码器的注意力掩码
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        # 参数 output_attentions: 可选参数,指示是否返回注意力权重
        output_attentions: Optional[bool] = None,
        # 参数 output_hidden_states: 可选参数,指示是否返回所有隐藏状态
        output_hidden_states: Optional[bool] = None,
        # 参数 return_dict: 可选参数,指示是否返回结果字典形式
        return_dict: Optional[bool] = None,
        # 参数 train: 布尔类型参数,指示当前是否处于训练模式
        train: bool = False,
        # 参数 params: 字典类型参数,用于存储额外的参数信息
        params: dict = None,
        # 参数 past_key_values: 字典类型参数,用于存储过去的键值信息
        past_key_values: dict = None,
        # 参数 dropout_rng: PRNGKey 类型参数,用于控制 dropout 行为的随机数生成器
        dropout_rng: PRNGKey = None,
        ):
            # 如果 output_attentions 参数未指定,则使用模型配置中的默认值
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            # 如果 output_hidden_states 参数未指定,则使用模型配置中的默认值
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            # 如果 return_dict 参数未指定,则使用模型配置中的默认值
            return_dict = return_dict if return_dict is not None else self.config.return_dict

            # 如果 encoder_hidden_states 存在且未提供 encoder_attention_mask,则创建一个全为 1 的注意力掩码
            if encoder_hidden_states is not None and encoder_attention_mask is None:
                batch_size, sequence_length = encoder_hidden_states.shape[:2]
                encoder_attention_mask = jnp.ones((batch_size, sequence_length))

            # 准备解码器的输入
            # 如果 attention_mask 未提供,则创建一个与 input_ids 形状相同的全为 1 的注意力掩码
            if attention_mask is None:
                attention_mask = jnp.ones_like(input_ids)
            # 如果 position_ids 未提供,则根据 input_ids 的形状创建位置 ID
            if position_ids is None:
                batch_size, sequence_length = input_ids.shape
                position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

            # 处理需要的随机数生成器(PRNG)
            rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

            inputs = {"params": params or self.params}

            # 如果传入了 past_key_values,则将其作为 cache 输入,同时设置 mutable 标志确保 cache 可变
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            # 调用模型的 apply 方法,传递各种输入参数
            outputs = self.module.apply(
                inputs,
                input_ids=jnp.array(input_ids, dtype="i4"),
                attention_mask=jnp.array(attention_mask, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                deterministic=not train,
                rngs=rngs,
                mutable=mutable,
            )

            # 将更新后的 cache 添加到模型输出中(仅在 return_dict=True 且 past_key_values 不为空时执行)
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                # 在输出的第一个元素后添加解冻的 past_key_values["cache"],用于非字典返回模式
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

            # 返回模型的输出
            return outputs
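
上面的 init_cache 与 __call__ 中的 cache 处理配合,构成自回归解码的基本循环。下面是一个示意性的用法草图(假设 model 是 FlaxBartForCausalLM 这类 FlaxBartDecoderPreTrainedModel 子类的已初始化实例,变量名均为示例):

# 先为 batch_size=2、最多解码 16 步初始化缓存
past_key_values = model.init_cache(batch_size=2, max_length=16)
# 缓存解码时,attention_mask 需覆盖整个 max_length(参见下文 prepare_inputs_for_generation)
attention_mask = jnp.ones((2, 16), dtype="i4")
position_ids = jnp.zeros((2, 1), dtype="i4")      # 第 0 步的位置
outputs = model(
    input_ids=jnp.array([[0], [0]], dtype="i4"),  # 每步只送入一个新 token
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
)
past_key_values = outputs["past_key_values"]      # 传回下一步继续解码
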
class FlaxBartDecoderWrapper(nn.Module):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    config: BartConfig  # 定义一个成员变量 config,类型为 BartConfig,用于存储模型的配置信息
    dtype: jnp.dtype = jnp.float32  # 定义一个成员变量 dtype,指定数据类型为 jnp.float32,默认值为 jnp.float32

    def setup(self):
        embed_dim = self.config.d_model  # 从 config 中获取模型的 embedding 维度
        embed_tokens = nn.Embed(  # 创建一个嵌入层,用于处理模型的词汇表和 embedding 维度
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),  # 使用正态分布初始化嵌入层权重
            dtype=self.dtype,
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)
        # 初始化一个 FlaxBartDecoder 对象,传入配置、嵌入层和数据类型

    def __call__(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
        # 调用 FlaxBartDecoder 对象的 __call__ 方法,将参数传递给 decoder


class FlaxBartForCausalLMModule(nn.Module):
    config: BartConfig  # 定义一个成员变量 config,类型为 BartConfig,用于存储模型的配置信息
    dtype: jnp.dtype = jnp.float32  # 定义一个成员变量 dtype,指定数据类型为 jnp.float32,默认值为 jnp.float32

    def setup(self):
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        # 初始化一个 FlaxBartDecoderWrapper 对象,传入配置和数据类型
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),  # 使用正态分布初始化 Dense 层的权重
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(  # 调用 self.model 对象,传递所有参数
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # 获取模型输出的隐藏状态

        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            # 如果配置指定共享词嵌入,则从模型的变量中获取共享的嵌入层
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            # 应用共享的嵌入层权重计算 LM logits
        else:
            lm_logits = self.lm_head(hidden_states)
            # 否则直接计算 LM logits

        if not return_dict:
            return (lm_logits,) + outputs[1:]
            # 如果不返回字典,则返回 LM logits 和其他输出项

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=lm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
        # 返回带交叉注意力的因果语言模型输出
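
tie_word_embeddings 分支的核心是:直接把词嵌入矩阵的转置当作 lm_head(nn.Dense)的 kernel 使用。下面是一个独立的小示例(形状与变量均为假设,仅演示这种参数复用方式):

import jax.numpy as jnp
import flax.linen as nn

vocab_size, d_model = 10, 4
lm_head = nn.Dense(vocab_size, use_bias=False)
shared_embedding = jnp.ones((vocab_size, d_model))  # 词嵌入矩阵,形状 (vocab, dim)
hidden_states = jnp.ones((2, 3, d_model))
# Dense 的 kernel 形状为 (in_features, out_features),恰好是嵌入矩阵的转置
logits = lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
print(logits.shape)  # (2, 3, 10)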


@add_start_docstrings(
    """
    Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
    e.g for autoregressive tasks.
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    module_class = FlaxBartForCausalLMModule
    # 定义一个 FlaxBartForCausalLM 类,继承自 FlaxBartDecoderPreTrainedModel,指定模块类为 FlaxBartForCausalLMModule
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        # 获取输入张量的批量大小和序列长度
        batch_size, seq_length = input_ids.shape

        # 使用模型的方法初始化缓存,返回过去的键值对
        past_key_values = self.init_cache(batch_size, max_length)
        
        # 注意:通常需要为超出输入长度和缓存长度之外的位置在 attention_mask 中填入 0
        # 但由于解码器使用因果掩码,这些位置已经被掩码了
        # 因此,我们可以在这里创建一个静态的 attention_mask,这样更有效率
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        
        # 如果提供了 attention_mask,则计算位置 ids
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            # 使用 lax.dynamic_update_slice 将 attention_mask 更新到 extended_attention_mask 中
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            # 否则,广播创建位置 ids
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        # 返回准备好的输入字典,包括过去的键值对、扩展的注意力掩码和位置 ids
        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # 更新生成阶段的输入参数
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
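
prepare_inputs_for_generation 中由 attention_mask 推导 position_ids 的 cumsum 技巧可以用一个小例子直观验证(数据为假设):

import jax.numpy as jnp

attention_mask = jnp.array([[0, 0, 1, 1, 1]])  # 前两位是左侧 padding
position_ids = attention_mask.cumsum(axis=-1) - 1
print(position_ids)  # [[-1 -1  0  1  2]]:真实 token 的位置从 0 开始,padding 位为 -1
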
# 调用函数 append_call_sample_docstring,用于为指定模型和相关对象添加示例文档字符串
append_call_sample_docstring(
    FlaxBartForCausalLM,               # 参数1: FlaxBartForCausalLM 模型类
    _CHECKPOINT_FOR_DOC,               # 参数2: _CHECKPOINT_FOR_DOC 常量,表示检查点
    FlaxCausalLMOutputWithCrossAttentions,  # 参数3: FlaxCausalLMOutputWithCrossAttentions 类,带有跨注意力的输出
    _CONFIG_FOR_DOC,                   # 参数4: _CONFIG_FOR_DOC 常量,表示配置
)

.\models\bart\modeling_tf_bart.py

def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # ...(前半部分省略:在序列开头插入 decoder_start_token_id,并将 labels 中的 -100 替换为 pad_token_id)
    # "Verify that `labels` has only positive values and -100"
    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))

    # Make sure the assertion op is called by wrapping the result in an identity no-op
    with tf.control_dependencies([assert_gte0]):
        # 确保断言操作被调用,并返回与输入相同的 shifted_input_ids 张量
        shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
    # 创建用于自回归(因果)自注意力的掩码。
    bsz = input_ids_shape[0]  # 获取批次大小
    tgt_len = input_ids_shape[1]  # 获取目标长度
    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE  # 创建一个初始掩码矩阵,用负无穷大填充
    
    mask_cond = tf.range(shape_list(mask)[-1])  # 创建一个与掩码矩阵最后一个维度大小相等的序列
    
    # 将掩码矩阵的下三角部分置零,实现因果性,确保每个位置只能依赖于它之前的位置
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
    
    if past_key_values_length > 0:
        # 如果存在过去的键值对长度,则在掩码矩阵左侧填充零,以匹配过去键值对的长度
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
    
    # 使用 tf.tile 扩展掩码矩阵的维度以匹配输入的批次大小,并返回结果
    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
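
可以用一个极小的输入形状观察该因果掩码的结构(假设性示例;LARGE_NEGATIVE 在本文件中定义为一个约 -1e8 的大负数):

import tensorflow as tf

mask = _make_causal_mask(tf.constant([1, 3]))
print(mask.shape)   # (1, 1, 3, 3)
# mask[0, 0] 的上三角为 LARGE_NEGATIVE(屏蔽未来位置),对角线及以下为 0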
def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    # 获取注意力掩码的序列长度
    src_len = shape_list(mask)[1]
    # 如果未指定目标长度,则使用源长度作为目标长度
    tgt_len = tgt_len if tgt_len is not None else src_len
    # 创建常数张量,值为 1.0
    one_cst = tf.constant(1.0)
    # 将注意力掩码转换为与 one_cst 相同数据类型的张量
    mask = tf.cast(mask, dtype=one_cst.dtype)
    # 在第二维和第三维上对注意力掩码进行复制扩展
    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

    # 返回加性掩码:被关注位置为 0,未关注位置为 (1 - mask) * LARGE_NEGATIVE(大负数)
    return (one_cst - expanded_mask) * LARGE_NEGATIVE
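
同样给 _expand_mask 一个小输入,观察其加性掩码语义(数据为假设):

padding_mask = tf.constant([[1.0, 1.0, 0.0]])  # 最后一位是 padding
additive_mask = _expand_mask(padding_mask)
print(additive_mask.shape)  # (1, 1, 3, 3)
# 被关注的位置为 0,padding 对应的列为 LARGE_NEGATIVE,可直接加到注意力分数上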


class TFBartLearnedPositionalEmbedding(keras.layers.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        # 如果 padding_idx 被指定,Bart 模型会偏移嵌入的 id 值并相应调整 num_embeddings
        # 这是一个针对 Bart 模型的特殊处理,其他模型不需要
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim, **kwargs)

    def call(
        self,
        input_shape: Optional[tf.TensorShape] = None,
        past_key_values_length: int = 0,
        position_ids: tf.Tensor | None = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        if position_ids is None:
            # 如果未提供位置 id,则根据输入形状中的序列长度创建位置 id
            seq_len = input_shape[1]
            position_ids = tf.range(seq_len, delta=1, name="range")
            position_ids += past_key_values_length

        # 确定位置 id 的数据类型,并将其与偏移量相加后传递给父类的调用方法
        offset_dtype = position_ids.dtype if isinstance(position_ids, tf.Tensor) else tf.int32
        return super().call(position_ids + tf.constant(self.offset, dtype=offset_dtype))
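
offset=2 意味着位置 i 实际查询的是嵌入表的第 i+2 行(前两行对应 Bart 从 fairseq 继承下来的 padding 偏移)。下面的草图直接调用 call 方法验证这一点(绕过 Keras 的 __call__ 封装,仅作示意):

pos_emb = TFBartLearnedPositionalEmbedding(num_embeddings=1024, embedding_dim=16)
pos_emb.build(None)  # Embedding 的权重与输入形状无关,可直接构建
out = pos_emb.call(input_shape=tf.TensorShape([2, 5]))  # 位置 0..4 -> 查表行 2..6
print(out.shape)  # (5, 16),随后会在模型中广播到 batch 维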


class TFBartAttention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # 初始化注意力层的参数
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = keras.layers.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        # 确保 embed_dim 能够被 num_heads 整除
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 缩放因子,用于缩放注意力分数
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        # 初始化线性变换层
        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        # 重塑张量形状,以便进行多头注意力计算
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
    # 定义 call 方法,用于执行多头注意力计算(方法体较长,此处省略)
    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        ...

    def build(self, input_shape=None):
        # 如果已经构建过,则直接返回,不再重复构建
        if self.built:
            return
        # 将构建标志设置为 True,表示已经构建
        self.built = True
        # 如果存在 self.k_proj 属性,则构建 k_proj 层
        if getattr(self, "k_proj", None) is not None:
            # 在名为 self.k_proj 的命名作用域下,构建 k_proj 层
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        # 如果存在 self.q_proj 属性,则构建 q_proj 层
        if getattr(self, "q_proj", None) is not None:
            # 在名为 self.q_proj 的命名作用域下,构建 q_proj 层
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        # 如果存在 self.v_proj 属性,则构建 v_proj 层
        if getattr(self, "v_proj", None) is not None:
            # 在名为 self.v_proj 的命名作用域下,构建 v_proj 层
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        # 如果存在 self.out_proj 属性,则构建 out_proj 层
        if getattr(self, "out_proj", None) is not None:
            # 在名为 self.out_proj 的命名作用域下,构建 out_proj 层
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
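
_shape 的 reshape+transpose 可以用小张量直观验证多头拆分后的布局(数据为假设):

bsz, seq_len, embed_dim, num_heads = 2, 4, 8, 2
x = tf.reshape(tf.range(bsz * seq_len * embed_dim, dtype=tf.float32), (bsz, seq_len, embed_dim))
heads = tf.transpose(tf.reshape(x, (bsz, seq_len, num_heads, embed_dim // num_heads)), (0, 2, 1, 3))
print(heads.shape)  # (2, 2, 4, 4),即 (batch, num_heads, seq_len, head_dim)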
class TFBartEncoderLayer(keras.layers.Layer):
    # TFBartEncoderLayer 类定义,继承自 keras.layers.Layer
    def __init__(self, config: BartConfig, **kwargs):
        # 初始化函数,接受一个 BartConfig 类型的配置对象和其他关键字参数
        super().__init__(**kwargs)
        # 调用父类初始化方法

        # 设置嵌入维度为配置对象中的 d_model 属性
        self.embed_dim = config.d_model

        # 创建自注意力层,使用 TFBartAttention 类,配置包括嵌入维度、注意力头数、注意力层的名字为 "self_attn"
        self.self_attn = TFBartAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
        )

        # 创建自注意力层的 LayerNormalization 层,设置 epsilon 为 1e-5,名字为 "self_attn_layer_norm"
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")

        # 创建 dropout 层,使用配置对象中的 dropout 率
        self.dropout = keras.layers.Dropout(config.dropout)

        # 获取激活函数,根据配置对象中的激活函数名获取对应的 TensorFlow 激活函数
        self.activation_fn = get_tf_activation(config.activation_function)

        # 创建激活 dropout 层,使用配置对象中的激活 dropout 率
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)

        # 创建第一个全连接层,输出维度为配置对象中的 encoder_ffn_dim,名字为 "fc1"
        self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")

        # 创建第二个全连接层,输出维度为嵌入维度,名字为 "fc2"
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")

        # 创建最终的 LayerNormalization 层,设置 epsilon 为 1e-5,名字为 "final_layer_norm"
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")

        # 保存配置对象
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: np.ndarray | tf.Tensor | None,
        layer_head_mask: tf.Tensor | None,
        training: Optional[bool] = False,
    ) -> tf.Tensor:
        """
        Args:
            hidden_states (`tf.Tensor`): 输入到该层的张量,形状为 `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): 注意力掩码张量,形状为 `(batch, 1, tgt_len, src_len)`,用大负值表示填充元素
            layer_head_mask (`tf.Tensor`): 给定层中注意力头的掩码张量,形状为 `(encoder_attention_heads,)`
            training (`Optional[bool]`): 是否处于训练模式,默认为 False
        """
        # 保存输入的原始状态作为残差连接的一部分
        residual = hidden_states

        # 调用自注意力层进行操作,返回处理后的 hidden_states、注意力权重和额外信息
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

        # 断言保证自注意力层没有修改查询的形状
        tf.debugging.assert_equal(
            shape_list(hidden_states),
            shape_list(residual),
            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
        )

        # 应用 dropout 操作
        hidden_states = self.dropout(hidden_states, training=training)

        # 添加残差连接
        hidden_states = residual + hidden_states

        # 应用 LayerNormalization 操作
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # 保存当前状态作为新的残差
        residual = hidden_states

        # 应用激活函数和第一个全连接层
        hidden_states = self.activation_fn(self.fc1(hidden_states))

        # 应用激活 dropout 操作
        hidden_states = self.activation_dropout(hidden_states, training=training)

        # 应用第二个全连接层
        hidden_states = self.fc2(hidden_states)

        # 应用 dropout 操作
        hidden_states = self.dropout(hidden_states, training=training)

        # 添加残差连接
        hidden_states = residual + hidden_states

        # 应用最终的 LayerNormalization 操作
        hidden_states = self.final_layer_norm(hidden_states)

        # 返回处理后的 hidden_states 和自注意力权重
        return hidden_states, self_attn_weights
    # 构建神经网络层,如果已经构建过则直接返回
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果存在 self_attn 属性,则构建 self_attn 层
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        # 如果存在 self_attn_layer_norm 属性,则构建 self_attn_layer_norm 层
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])
        # 如果存在 fc1 属性,则构建 fc1 层
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])
        # 如果存在 fc2 属性,则构建 fc2 层
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.encoder_ffn_dim])
        # 如果存在 final_layer_norm 属性,则构建 final_layer_norm 层
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
# 定义 TFBartDecoderLayer 类,继承自 keras.layers.Layer,用于实现 BART 解码器的一个层
class TFBartDecoderLayer(keras.layers.Layer):
    # 初始化方法,接受配置参数 config 和其他关键字参数
    def __init__(self, config: BartConfig, **kwargs):
        # 调用父类的初始化方法
        super().__init__(**kwargs)
        # 设置层的嵌入维度为配置中的模型维度
        self.embed_dim = config.d_model
        # 创建自注意力机制层 self_attn,使用 TFBartAttention 类,配置包括嵌入维度、注意力头数、注意力 dropout 等
        self.self_attn = TFBartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="self_attn",
            is_decoder=True,
        )
        # 创建 Dropout 层,用于 self_attn 层的输出
        self.dropout = keras.layers.Dropout(config.dropout)
        # 获取激活函数,并将其赋值给 activation_fn
        self.activation_fn = get_tf_activation(config.activation_function)
        # 创建用于激活函数 dropout 的 Dropout 层
        self.activation_dropout = keras.layers.Dropout(config.activation_dropout)

        # 创建 LayerNormalization 层,用于自注意力机制的输出
        self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        
        # 创建编码器注意力机制层 encoder_attn,配置与 self_attn 类似,用于处理编码器的输出
        self.encoder_attn = TFBartAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="encoder_attn",
            is_decoder=True,
        )
        # 创建 LayerNormalization 层,用于编码器注意力机制的输出
        self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
        
        # 创建全连接层 fc1,用于进行维度变换和非线性变换
        self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
        # 创建全连接层 fc2,输出维度与嵌入维度相同
        self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
        
        # 创建 LayerNormalization 层,用于最终输出的规范化
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
        
        # 保存配置对象
        self.config = config

    # 定义 call 方法,实现层的前向传播逻辑
    def call(
        self,
        hidden_states: tf.Tensor,  # 输入的隐藏状态张量
        attention_mask: np.ndarray | tf.Tensor | None = None,  # 注意力遮罩,可选参数
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,  # 编码器的隐藏状态张量,可选参数
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,  # 编码器的注意力遮罩,可选参数
        layer_head_mask: tf.Tensor | None = None,  # 层级头部掩码,可选参数
        cross_attn_layer_head_mask: tf.Tensor | None = None,  # 交叉注意力层级头部掩码,可选参数
        past_key_value: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,  # 过去的键值对,可选参数
        training: Optional[bool] = False,
    ):
        # 方法主体较长,此处省略:依次执行自注意力、编码器-解码器交叉注意力和前馈网络
        ...
    # 定义模型的构建方法,如果已经构建过则直接返回
    def build(self, input_shape=None):
        if self.built:
            return
        # 标记模型已构建
        self.built = True
        
        # 如果存在 self_attn 属性,则构建 self attention 层
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        
        # 如果存在 self_attn_layer_norm 属性,则构建 self attention 层的 layer normalization 层
        if getattr(self, "self_attn_layer_norm", None) is not None:
            with tf.name_scope(self.self_attn_layer_norm.name):
                self.self_attn_layer_norm.build([None, None, self.embed_dim])
        
        # 如果存在 encoder_attn 属性,则构建 encoder-decoder attention 层
        if getattr(self, "encoder_attn", None) is not None:
            with tf.name_scope(self.encoder_attn.name):
                self.encoder_attn.build(None)
        
        # 如果存在 encoder_attn_layer_norm 属性,则构建 encoder-decoder attention 层的 layer normalization 层
        if getattr(self, "encoder_attn_layer_norm", None) is not None:
            with tf.name_scope(self.encoder_attn_layer_norm.name):
                self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
        
        # 如果存在 fc1 属性,则构建第一个全连接层
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build([None, None, self.embed_dim])
        
        # 如果存在 fc2 属性,则构建第二个全连接层
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build([None, None, self.config.decoder_ffn_dim])
        
        # 如果存在 final_layer_norm 属性,则构建最终的 layer normalization 层
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
class TFBartClassificationHead(keras.layers.Layer):
    """Head for sentence-level classification tasks."""

    def __init__(self, inner_dim: int, num_classes: int, pooler_dropout: float, name: str, **kwargs):
        super().__init__(name=name, **kwargs)
        # 定义一个全连接层,输出维度为 inner_dim
        self.dense = keras.layers.Dense(inner_dim, name="dense")
        # 定义一个 dropout 层,用于在训练过程中随机失活部分神经元
        self.dropout = keras.layers.Dropout(pooler_dropout)
        # 定义一个全连接层,输出维度为 num_classes,用于分类任务的输出
        self.out_proj = keras.layers.Dense(num_classes, name="out_proj")
        # 记录输入维度和内部维度,这些参数在构建模型时会用到
        self.input_dim = inner_dim
        self.inner_dim = inner_dim

    def call(self, inputs):
        # 对输入进行 dropout 处理
        hidden_states = self.dropout(inputs)
        # 经过全连接层 dense 处理
        hidden_states = self.dense(hidden_states)
        # 使用 tanh 激活函数处理隐藏状态
        hidden_states = keras.activations.tanh(hidden_states)
        # 再次进行 dropout 处理
        hidden_states = self.dropout(hidden_states)
        # 最后经过全连接层 out_proj 输出最终的分类结果
        hidden_states = self.out_proj(hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 构建模型,如果已经构建则直接返回
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # 使用 input_dim 构建 dense 层
                self.dense.build([None, None, self.input_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                # 使用 inner_dim 构建 out_proj 层
                self.out_proj.build([None, None, self.inner_dim])


class TFBartPretrainedModel(TFPreTrainedModel):
    config_class = BartConfig
    base_model_prefix = "model"

    @property
    def dummy_inputs(self):
        dummy_inputs = super().dummy_inputs
        # 将虚拟输入中的 token id 乘以 2,避免取到默认值 1(即 padding token),以绕过相关断言检查
        dummy_inputs["input_ids"] = dummy_inputs["input_ids"] * 2
        if "decoder_input_ids" in dummy_inputs:
            dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"] * 2
        return dummy_inputs

    def tf_to_pt_weight_rename(self, tf_weight):
        # 将 TF 的权重名称转换为 PyTorch 风格的权重名称
        if tf_weight == "model.shared.weight":
            return tf_weight, "model.decoder.embed_tokens.weight"
        else:
            return (tf_weight,)


BART_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
"""
    # 使用 `BartConfig` 类型的配置参数 `config`,该类包含了模型的所有参数设定
    config ([`BartConfig`]): Model configuration class with all the parameters of the model.
        # 通过配置文件初始化模型不会加载模型的权重,只加载配置信息
        Initializing with a config file does not load the weights associated with the model, only the
        configuration.
        # 使用 [`~TFPreTrainedModel.from_pretrained`] 方法可以加载模型的权重
        Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
"""


BART_GENERATION_EXAMPLE = r"""
    Summarization example:

    ```
    >>> from transformers import AutoTokenizer, TFBartForConditionalGeneration

    >>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="tf")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
    ```

    Mask filling example:

    ```
    >>> from transformers import AutoTokenizer, TFBartForConditionalGeneration

    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    >>> TXT = "My friends are <mask> but they eat too many carbs."

    >>> model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large")
    >>> input_ids = tokenizer([TXT], return_tensors="tf")["input_ids"]
    >>> logits = model(input_ids).logits
    >>> probs = tf.nn.softmax(logits[0])
    >>> # probs[5] is associated with the mask token
    ```
"""


BART_INPUTS_DOCSTRING = r"""
"""


@keras_serializable
class TFBartEncoder(keras.layers.Layer):
    config_class = BartConfig
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`TFBartEncoderLayer`].

    Args:
        config: BartConfig
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dropout = keras.layers.Dropout(config.dropout)
        self.layerdrop = config.encoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0

        self.embed_tokens = embed_tokens
        self.embed_positions = TFBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )
        self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
        self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
        self.embed_dim = config.d_model

    @unpack_inputs
    # 这是一个装饰器,用于解包输入参数,使其可以作为函数的参数使用
    # 定义类方法 `call`,用于模型的前向传播
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # 前向传播的具体实现较长,此处省略
        ...

    def build(self, input_shape=None):
        # 如果模型已经构建完成,则直接返回,避免重复构建
        if self.built:
            return
        # 设置模型已经构建的标志
        self.built = True

        # 如果存在 `embed_positions` 属性,则构建它
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)

        # 如果存在 `layernorm_embedding` 属性,则构建它
        if getattr(self, "layernorm_embedding", None) is not None:
            with tf.name_scope(self.layernorm_embedding.name):
                self.layernorm_embedding.build([None, None, self.embed_dim])

        # 如果存在 `layers` 属性,则逐层构建每一层
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
@keras_serializable
class TFBartDecoder(keras.layers.Layer):
    # 定义一个可序列化的 Keras 层,用于实现 BART 解码器
    config_class = BartConfig
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFBartDecoderLayer`]

    Args:
        config: BartConfig
        embed_tokens: output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        # 初始化函数,设置配置、填充索引、嵌入 tokens 和其他参数
        self.config = config
        self.padding_idx = config.pad_token_id
        self.embed_tokens = embed_tokens
        self.layerdrop = config.decoder_layerdrop
        self.embed_positions = TFBartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            name="embed_positions",
        )
        self.embed_scale = tf.math.sqrt(float(config.d_model)) if config.scale_embedding else 1.0
        self.layers = [TFBartDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
        self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

        self.dropout = keras.layers.Dropout(config.dropout)

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ):
        # 解码器的调用方法,接受多种输入和参数
        ...

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 如果层已经构建,则直接返回

        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        # 如果存在嵌入位置信息,构建嵌入位置层

        if getattr(self, "layernorm_embedding", None) is not None:
            with tf.name_scope(self.layernorm_embedding.name):
                self.layernorm_embedding.build([None, None, self.config.d_model])
        # 如果存在层归一化层,构建层归一化层

        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
        # 构建解码器的每一层

@keras_serializable
class TFBartMainLayer(keras.layers.Layer):
    # 定义一个可序列化的 Keras 主层,用于实现 BART 主要层
    config_class = BartConfig
    def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # 创建一个共享的嵌入层,用于输入的词汇表大小和模型维度
        self.shared = keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.d_model,
            # 使用 TruncatedNormal 初始化器初始化嵌入层权重
            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
            name="model.shared",
        )
        # 设置加载/存储权重时的预期名称空间
        self.shared.load_weight_prefix = "model.shared" if load_weight_prefix is None else load_weight_prefix

        # 创建编码器和解码器对象
        self.encoder = TFBartEncoder(config, self.shared, name="encoder")
        self.decoder = TFBartDecoder(config, self.shared, name="decoder")

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        # 更新共享的嵌入层权重
        self.shared = new_embeddings
        # 更新编码器和解码器的嵌入层权重
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ):
        # 模型调用的具体实现较长,此处省略(涉及编码器/解码器的输入处理与输出组装)
        ...

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # 共享权重(shared/tied weights)的名称空间预期位于模型的基础名称空间中
        # 在 tf.name_scope 的末尾(而不是开头!)添加 "/" 会把作用域置于根名称空间而非当前名称空间
        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
            self.shared.build(None)
        # 如果存在编码器对象,则在其名称空间下构建
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # 如果存在解码器对象,则在其名称空间下构建
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)
# 添加 BART 模型的文档字符串,描述该类用于输出没有特定头部的原始隐藏状态
@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
# 定义 TFBartModel 类,继承自 TFBartPretrainedModel
class TFBartModel(TFBartPretrainedModel):
    # 表示需要加载权重前缀
    _requires_load_weight_prefix = True

    # 初始化方法,接受 BartConfig 类型的配置对象和其他输入参数
    def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
        # 调用父类的初始化方法
        super().__init__(config, *inputs, **kwargs)
        
        # 创建 TFBartMainLayer 实例作为 self.model,用于处理 BART 的主体部分
        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")

    # 返回 encoder 部分
    def get_encoder(self):
        return self.model.encoder

    # 返回 decoder 部分
    def get_decoder(self):
        return self.model.decoder

    # 调用方法,用于执行 BART 模型的前向传播
    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSeq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
        **kwargs,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # 调用 self.model 的前向传播,传递所有参数
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # 返回模型的输出
        return outputs
    # 定义一个方法用于处理模型的输出
    def serving_output(self, output):
        # 如果配置中使用缓存,则提取 past_key_values 元组中索引为 1 的元素(即缓存本身)
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # 如果配置中输出隐藏状态,则将输出的解码器隐藏状态转换为张量
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # 如果配置中输出注意力权重,则将输出的解码器注意力权重转换为张量
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # 如果配置中输出交叉注意力权重,则将输出的交叉注意力权重转换为张量
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # 如果配置中输出隐藏状态,则将输出的编码器隐藏状态转换为张量
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # 如果配置中输出注意力权重,则将输出的编码器注意力权重转换为张量
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # 返回一个 TFSeq2SeqModelOutput 对象,包括最后隐藏状态、过去键值对、解码器隐藏状态、解码器注意力权重、
        # 交叉注意力权重、编码器最后隐藏状态、编码器隐藏状态和编码器注意力权重
        return TFSeq2SeqModelOutput(
            last_hidden_state=output.last_hidden_state,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # 定义一个方法用于构建模型
    def build(self, input_shape=None):
        # 如果模型已经构建完成,则直接返回
        if self.built:
            return
        # 将模型标记为已构建
        self.built = True
        # 如果对象中存在模型属性,则在模型的命名空间下构建模型
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
class BiasLayer(keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # 添加一个权重变量作为偏置,用于神经网络层的偏置项
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        # 在前向传播中,将输入张量和偏置相加并返回
        return x + self.bias
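
BiasLayer 的行为可以用几行代码验证:它只是给输入逐元素加上一个注册为层权重的偏置(示例数据为假设):

bias_layer = BiasLayer(shape=[1, 4], initializer="zeros", trainable=False, name="demo_bias")
logits = tf.ones((2, 4))
print(bias_layer(logits).shape)  # (2, 4):偏置 (1, 4) 按广播规则加到每一行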


@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.",
    BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
    _keys_to_ignore_on_load_missing = [r"final_logits_bias"]
    _requires_load_weight_prefix = True

    def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # 创建 BART 模型的主体部分,包括编码器和解码器
        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
        self.use_cache = config.use_cache
        # 创建一个偏置层用于处理最终输出的偏置 logits
        # 在 PyTorch 中,final_logits_bias 被注册为一个缓冲区,为了保持一致性,这里设置为不可训练
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )

    def get_decoder(self):
        # 返回 BART 模型的解码器
        return self.model.decoder

    def get_encoder(self):
        # 返回 BART 模型的编码器
        return self.model.encoder

    def get_output_embeddings(self):
        # 返回输入嵌入层,用于获取输出的词汇表嵌入
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        # 设置输入嵌入层,用于设置输出的词汇表嵌入
        self.set_input_embeddings(value)

    def get_bias(self):
        # 返回偏置层的偏置值,用于模型保存和加载
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        # 替换现有的包含偏置的层,确保正确的序列化和反序列化
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        self.bias_layer.bias.assign(value["final_logits_bias"])

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    @add_end_docstrings(BART_GENERATION_EXAMPLE)
    @unpack_inputs
    # 定义 call 方法,用于调用模型。以下参数均有特定的类型注解和默认值。
    def call(
        self,
        # 输入序列的标识符
        input_ids: TFModelInputType | None = None,
        # 注意力掩码
        attention_mask: np.ndarray | tf.Tensor | None = None,
        # 解码器输入序列的标识符
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        # 解码器的注意力掩码
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        # 解码器的位置标识符
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        # 头部掩码,用于控制每个注意力头部的屏蔽情况
        head_mask: np.ndarray | tf.Tensor | None = None,
        # 解码器头部掩码
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        # 交叉注意力头部掩码,用于控制编码器-解码器注意力每个头部的屏蔽情况
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        # 编码器输出
        encoder_outputs: Optional[TFBaseModelOutput] = None,
        # 缓存的键值对
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        # 输入的嵌入向量
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        # 解码器的输入嵌入向量
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        # 是否使用缓存
        use_cache: Optional[bool] = None,
        # 是否输出注意力权重
        output_attentions: Optional[bool] = None,
        # 是否输出隐藏状态
        output_hidden_states: Optional[bool] = None,
        # 是否返回字典格式的结果
        return_dict: Optional[bool] = None,
        # 标签,用于计算掩码语言建模损失
        labels: tf.Tensor | None = None,
        # 是否处于训练模式,默认为 False
        training: Optional[bool] = False,
        ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
            A `TFSeq2SeqLMOutput` or a tuple of `tf.Tensor`, depending on the `return_dict` parameter.

        """

        if labels is not None:
            # Replace tokens equal to pad_token_id with -100 for loss computation
            labels = tf.where(
                labels == self.config.pad_token_id,
                tf.cast(tf.fill(shape_list(labels), -100), labels.dtype),
                labels,
            )
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Shift labels to the right to obtain decoder_input_ids
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        # Forward pass through the model
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        
        # Calculate logits and apply bias layer
        lm_logits = tf.matmul(outputs[0], self.model.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        
        # Compute masked language modeling loss if labels are provided
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        # Prepare output based on return_dict flag
        if not return_dict:
            # Return tuple of lm_logits and other outputs
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        else:
            # Return TFSeq2SeqLMOutput object with specified attributes
            return TFSeq2SeqLMOutput(
                loss=masked_lm_loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,  # index 1 of d outputs
                decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
                decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
                cross_attentions=outputs.cross_attentions,  # index 4 of d outputs
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
                encoder_hidden_states=outputs.encoder_hidden_states,  # index 1 of encoder outputs
                encoder_attentions=outputs.encoder_attentions,  # index 2 of encoder outputs
            )
    # Build the serving output object from the given model outputs
    def serving_output(self, output):
        # Keep the past key values only if the config enables the cache
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # Decoder hidden states as a tensor, only if hidden states are requested in the config
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # Decoder attentions as a tensor, only if attentions are requested in the config
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # Cross attentions as a tensor, only if attentions are requested in the config
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # Encoder hidden states as a tensor, only if hidden states are requested in the config
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # Encoder attentions as a tensor, only if attentions are requested in the config
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # Return a TFSeq2SeqLMOutput holding the post-processed outputs
        return TFSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # Prepare the inputs needed at each generation step
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # If past_key_values is set, only the last decoder input token is needed
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        if decoder_attention_mask is not None:  # xla
            # The position of the current token is the exclusive cumulative sum of the mask at the last slot
            decoder_position_ids = tf.math.cumsum(decoder_attention_mask, axis=-1, exclusive=True)[:, -1:]
        elif past_key_values is not None:  # no xla + past_key_values
            # Without XLA, the cache length (sequence dimension of the first past key) gives the position
            decoder_position_ids = past_key_values[0][0].shape[2]
        else:
            # No cache at all: use the full range of decoder input positions
            decoder_position_ids = tf.range(decoder_input_ids.shape[1])

        # Return a dict with everything generation needs for the next step
        return {
            "input_ids": None,  # input_ids are not needed once encoder_outputs is defined
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_position_ids": decoder_position_ids,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # 修改此项以避免缓存(可能用于调试目的)
        }
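
The XLA branch above deserves a concrete illustration: with `exclusive=True`, the cumulative sum at each slot counts the attended tokens before it, so the last slot yields the position id of the token currently being generated. A minimal sketch with toy mask values:

```
import tensorflow as tf

# 1 marks attended positions; here 4 real tokens followed by 2 padding slots
mask = tf.constant([[1, 1, 1, 1, 0, 0]])
pos = tf.math.cumsum(mask, axis=-1, exclusive=True)
print(pos.numpy())          # [[0 1 2 3 4 4]]
print(pos[:, -1:].numpy())  # [[4]] -- the position id for the next decoder step
```
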

    # Build decoder input ids from the labels
    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        # Shift the labels one position to the right, filling with pad_token_id and inserting decoder_start_token_id
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
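
`shift_tokens_right` is defined earlier in this file; the standalone sketch below mimics its behavior with toy token ids (pad id 1 and decoder start id 2 are assumptions for the demo, not taken from a real config):

```
import tensorflow as tf

def shift_right(labels, pad_token_id, decoder_start_token_id):
    # Prepend the decoder start token and drop the last label
    start = tf.fill((tf.shape(labels)[0], 1), decoder_start_token_id)
    shifted = tf.concat([start, labels[:, :-1]], axis=-1)
    # Positions labelled -100 are ignored by the loss but cannot be fed to the
    # decoder, so map them back to the pad token
    return tf.where(shifted == -100, tf.fill(tf.shape(shifted), pad_token_id), shifted)

labels = tf.constant([[42, 43, 44, -100]])
print(shift_right(labels, pad_token_id=1, decoder_start_token_id=2).numpy())
# [[ 2 42 43 44]]
```
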
    # Build the model's sublayers
    def build(self, input_shape=None):
        # Return early if the model has already been built, to avoid building twice
        if self.built:
            return
        self.built = True

        # Build the main layer inside its own name scope, if it exists
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)

        # Build the bias layer inside its own name scope, if it exists
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)
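
Putting the class together, a typical summarization call looks like this (a minimal sketch; it downloads the `facebook/bart-large-cnn` checkpoint on first use):

```
from transformers import BartTokenizer, TFBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

article = "The tower is 324 metres tall, about the same height as an 81-storey building."
inputs = tokenizer([article], max_length=1024, truncation=True, return_tensors="tf")
# generate() drives prepare_inputs_for_generation() shown above at every decoding step
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```
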
@add_start_docstrings(
    """
    Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    BART_START_DOCSTRING,
)
class TFBartForSequenceClassification(TFBartPretrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # The main BART layer (encoder and decoder), optionally loading weights under a prefix
        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
        # Classification head for sequence-level tasks: a dense layer on top of the
        # pooled output, with dropout rate config.classifier_dropout
        self.classification_head = TFBartClassificationHead(
            config.d_model, config.num_labels, config.classifier_dropout, name="classification_head"
        )

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSeq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        cross_attn_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: Optional[TFBaseModelOutput] = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: tf.Tensor | None = None,
        training: Optional[bool] = False,
    ):
        # Each argument is described in detail in BART_INPUTS_DOCSTRING

    # Convert the model outputs into serving-friendly tensors
    def serving_output(self, output):
        # Convert the logits to a TensorFlow tensor
        logits = tf.convert_to_tensor(output.logits)
        # Keep the past key values only if the config enables the cache
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        # Decoder hidden states as a tensor, only if hidden states are requested in the config
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        # Decoder attentions as a tensor, only if attentions are requested in the config
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        # Cross attentions as a tensor, only if attentions are requested in the config
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        # Encoder hidden states as a tensor, only if hidden states are requested in the config
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        # Encoder attentions as a tensor, only if attentions are requested in the config
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None

        # Return the converted outputs as a TFSeq2SeqSequenceClassifierOutput
        return TFSeq2SeqSequenceClassifierOutput(
            logits=logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
        )

    # Build the model's sublayers
    def build(self, input_shape=None):
        # Return early if already built
        if self.built:
            return
        self.built = True
        # Build the main layer inside its own name scope, if present
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        # Build the classification head inside its own name scope, if present
        if getattr(self, "classification_head", None) is not None:
            with tf.name_scope(self.classification_head.name):
                self.classification_head.build(None)
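
For reference, a minimal usage sketch of this class with the `facebook/bart-large-mnli` checkpoint, whose head was trained on NLI (the label order in the comment follows that checkpoint's config):

```
import tensorflow as tf
from transformers import BartTokenizer, TFBartForSequenceClassification

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")
model = TFBartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")

premise = "A soccer game with multiple males playing."
hypothesis = "Some men are playing a sport."
inputs = tokenizer(premise, hypothesis, return_tensors="tf")
logits = model(**inputs).logits
# id2label for this checkpoint: 0=contradiction, 1=neutral, 2=entailment
print(tf.math.argmax(logits, axis=-1).numpy())
```
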

.\models\bart\tokenization_bart.py

# `bytes_to_unicode` is decorated with `@lru_cache()` so its result is computed once and cached
@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a mapping to unicode strings. In particular, it avoids mapping
    whitespace/control characters, which the BPE code would choke on.

    The reversible BPE (Byte-Pair Encoding) codes work on unicode strings, which means you need a large number of
    unicode characters in your vocab.
    """
    # Start from the printable ASCII range plus two extended Latin-1 ranges;
    # these bytes keep their own codepoint
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    # cs starts as a copy of bs
    cs = bs[:]
    # Counter for bytes that still need a codepoint
    n = 0
    # Walk over every possible 8-bit byte value
    for b in range(2**8):
        # Any byte not covered above gets mapped to an unused codepoint beyond 255
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    # Turn the codepoints into single-character unicode strings
    cs = [chr(n) for n in cs]
    # Return the byte -> unicode-string mapping
    return dict(zip(bs, cs))
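
A quick check of the mapping; the expected characters follow directly from the construction above (printable bytes map to themselves, the remaining ones are shifted past 255):

```
byte_encoder = bytes_to_unicode()
print(byte_encoder[ord("A")])   # 'A'  -- printable bytes keep their own codepoint
print(byte_encoder[ord(" ")])   # 'Ġ'  -- space (byte 32) becomes chr(256 + 32)
print(byte_encoder[ord("\n")])  # 'Ċ'  -- newline (byte 10) becomes chr(256 + 10)
```
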
def get_pairs(word):
    """
    Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    # Collect the symbol pairs in a set
    pairs = set()
    # Start with the first symbol as the previous character
    prev_char = word[0]
    # Walk the remaining symbols, recording each adjacent pair
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    # Return the set of pairs
    return pairs
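
For example (the result is a set, so the printed order may vary):

```
>>> get_pairs(("h", "e", "l", "l", "o"))
{('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}
```
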


class BartTokenizer(PreTrainedTokenizer):
    """
    Constructs a BART tokenizer, which is similar to the RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import BartTokenizer

    >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer will add a space before each word (even the first one).

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    """
    # 构造函数,用于初始化一个 BART 分词器对象
    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens_dict=None, max_len=None, **kwargs):
        # 调用父类构造函数初始化 BART 分词器
        super().__init__(vocab_file, merges_file, errors=errors, special_tokens_dict=special_tokens_dict, **kwargs)
        # 设置最大长度属性
        self.max_len = max_len

    # 实现从预训练模型加载 BART 分词器的类方法
    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        # 调用父类的类方法加载预训练模型
        return super().from_pretrained(*inputs, **kwargs)

    # 重写父类方法,根据文本生成输入 ID 列表
    def __call__(self, text, **kwargs):
        # 调用父类方法生成输入 ID 列表
        return super().__call__(text, **kwargs)
    # 定义一个函数的参数列表,用于初始化一个类的实例或调用函数时传递参数。
    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering. It is also used as the
            last token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The BART tokenizer detects the beginning of words by the preceding space).
    """
    
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    def __init__(
        self,
        vocab_file,
        merges_file,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        **kwargs,
    ):
        # If the initial special tokens are plain strings, wrap them in AddedToken, leaving surrounding spaces intact
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token

        # The mask token behaves like a normal word, i.e. it includes the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # Load the vocabulary (a utf-8 encoded JSON file) into self.encoder
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        # self.decoder is the inverse mapping, used to turn indices back into tokens
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.errors = errors  # how to handle errors in decoding

        # Byte -> unicode-string mapping produced by bytes_to_unicode
        self.byte_encoder = bytes_to_unicode()

        # self.byte_decoder is the inverse mapping, from unicode strings back to bytes
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # Read merges_file (utf-8), split it into lines, and drop the version header and the trailing empty line
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]

        # self.bpe_ranks maps each merge pair to its priority, i.e. its index in the merges file
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

        # Cache for already-computed BPE segmentations
        self.cache = {}

        # Whether to add a leading space before the first token
        self.add_prefix_space = add_prefix_space

        # Regex that splits text into contraction suffixes, runs of letters, runs of digits,
        # other symbols, and whitespace
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        # Finally, initialize the superclass with everything above
        super().__init__(
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @property
    def vocab_size(self):
        # The vocabulary size is the number of entries in self.encoder
        return len(self.encoder)

    def get_vocab(self):
        # Return a dict merging self.encoder with any added tokens
        return dict(self.encoder, **self.added_tokens_encoder)
    def _tokenize(self, text):
        """Tokenize a string."""
        bpe_tokens = []
        # Find every chunk matching self.pat and run it through byte-level BPE
        for token in re.findall(self.pat, text):
            token = "".join(
                self.byte_encoder[b] for b in token.encode("utf-8")
            )  # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
            # Split the BPE output on spaces and collect the resulting sub-tokens
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Fall back to the id of self.unk_token when the token is unknown
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Join the tokens, then map the unicode characters back to bytes and decode as utf-8
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Bail out with an error if the save path is not a directory
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Build the vocabulary file path
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        # Build the merges file path
        merge_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        # Write the vocabulary to disk
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        # Write the merges to disk, sorted by their BPE rank
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    # Warn when the BPE merge indices are not consecutive
                    logger.warning(
                        f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        # Return the paths of the saved vocabulary and merges files
        return vocab_file, merge_file

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BART sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # Single sequence: just wrap it in the special tokens
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        # Pair of sequences: join them with the double separator
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
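
With the standard BART vocabulary, where `<s>`/cls is id 0 and `</s>`/sep is id 2, the two formats come out as follows (a minimal sketch; it downloads the facebook/bart-base vocabulary on first use):

```
from transformers import BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-base")
print(tok.build_inputs_with_special_tokens([31414]))         # [0, 31414, 2]
print(tok.build_inputs_with_special_tokens([31414], [232]))  # [0, 31414, 2, 2, 232, 2]
```
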

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        # If the token list already has special tokens, delegate to the superclass method
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # If there's only one token list provided, return a mask with special tokens added at both ends
        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        # If two token lists are provided, return a mask with special tokens at both ends of each sequence
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BART does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        # Initialize special tokens for separator and class, but BART does not use token type ids
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If there's only one sequence, return a list of zeros with the length of cls + token_ids_0 + sep
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        # If two sequences are provided, return a list of zeros with the length of cls + token_ids_0 + sep + sep + token_ids_1 + sep
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Prepares text for tokenization by optionally adding a prefix space based on conditions.

        Args:
            text (str): The input text to be tokenized.
            is_split_into_words (bool, optional): Whether the text is already split into words.
            **kwargs: Additional keyword arguments.

        Returns:
            tuple: A tuple containing the modified text and remaining keyword arguments.
        """
        # Check if prefix space should be added based on conditions
        add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
        if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
            text = " " + text
        return (text, kwargs)

.\models\bart\tokenization_bart_fast.py

# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json  # used to (de)serialize the tokenizer component states
from typing import List, Optional, Tuple  # typing helpers

from tokenizers import pre_tokenizers, processors  # pre-tokenizer and post-processor components of the tokenizers library

from ...tokenization_utils_base import AddedToken, BatchEncoding  # base tokenization utilities
from ...tokenization_utils_fast import PreTrainedTokenizerFast  # fast tokenizer base class
from ...utils import logging  # logging utilities

from .tokenization_bart import BartTokenizer  # the slow BART tokenizer from this package


logger = logging.get_logger(__name__)  # module-level logger

# File names used to store the tokenizer's vocabulary
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

# Map from pretrained model names to their vocabulary file URLs
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/vocab.json",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/vocab.json",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/vocab.json",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/vocab.json",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/vocab.json",
    },
    "merges_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/merges.txt",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/merges.txt",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/merges.txt",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/merges.txt",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/merges.txt",
    },
    # Map each BART model to the URL of its tokenizer.json file
    "tokenizer_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/tokenizer.json",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/tokenizer.json",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/tokenizer.json",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/tokenizer.json",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/tokenizer.json",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/tokenizer.json",
    },
}

# Sizes of the pretrained positional embeddings, i.e. the maximum input length per model
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/bart-base": 1024,
    "facebook/bart-large": 1024,
    "facebook/bart-large-mnli": 1024,
    "facebook/bart-large-cnn": 1024,
    "facebook/bart-large-xsum": 1024,
    "yjernite/bart_eli5": 1024,
}


class BartTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" BART tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2
    tokenizer, using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import BartTokenizerFast

    >>> tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
    >>> tokenizer("Hello world")["input_ids"]
    [0, 31414, 232, 2]

    >>> tokenizer(" Hello world")["input_ids"]
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The BART tokenizer detects the beginning of words by the preceding space).
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether the post processing step should trim offsets to avoid including whitespaces.
    """
    # Names of the vocabulary files expected by this tokenizer
    vocab_files_names = VOCAB_FILES_NAMES
    # Map of pretrained vocabulary files
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Maximum input sizes of the pretrained models
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Names of the model inputs: input ids and attention mask
    model_input_names = ["input_ids", "attention_mask"]
    # The matching slow tokenizer class
    slow_tokenizer_class = BartTokenizer

    # Constructor for a new tokenizer instance
    def __init__(
        self,
        vocab_file=None,         # path to the vocabulary file
        merges_file=None,        # path to the merges file
        tokenizer_file=None,     # path to a serialized tokenizer file to load
        errors="replace",        # how to handle decoding errors
        bos_token="<s>",         # beginning-of-sequence token
        eos_token="</s>",        # end-of-sequence token
        sep_token="</s>",        # separator token
        cls_token="<s>",         # classifier token
        unk_token="<unk>",       # unknown token
        pad_token="<pad>",       # padding token
        mask_token="<mask>",     # mask token
        add_prefix_space=False,  # whether to add a leading space before the first word
        trim_offsets=True,       # whether to trim offsets to exclude whitespace
        **kwargs,                # remaining keyword arguments
    ):
        # If mask_token is a plain string, wrap it in an AddedToken with lstrip=True and normalized=True
        mask_token = (
            AddedToken(mask_token, lstrip=True, normalized=True, special=True)
            if isinstance(mask_token, str)
            else mask_token
        )
        # Initialize the fast tokenizer base class with everything above
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            trim_offsets=trim_offsets,
            **kwargs,
        )

        # Inspect the current pre_tokenizer state and check whether add_prefix_space needs updating
        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            # Rebuild the pre_tokenizer with the requested add_prefix_space value
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        # Remember add_prefix_space on the instance
        self.add_prefix_space = add_prefix_space

        # Check and, if needed, update the state of the tokenizer's post-processor component
        tokenizer_component = "post_processor"
        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
        if tokenizer_component_instance:
            state = json.loads(tokenizer_component_instance.__getstate__())

            # The "sep" and "cls" entries are serialized as lists; convert them back to tuples
            if "sep" in state:
                state["sep"] = tuple(state["sep"])
            if "cls" in state:
                state["cls"] = tuple(state["cls"])

            changes_to_apply = False

            # Check whether add_prefix_space needs updating
            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
                state["add_prefix_space"] = add_prefix_space
                changes_to_apply = True

            # Check whether trim_offsets needs updating
            if state.get("trim_offsets", trim_offsets) != trim_offsets:
                state["trim_offsets"] = trim_offsets
                changes_to_apply = True

            # If anything changed, build a fresh post-processor and install it
            if changes_to_apply:
                component_class = getattr(processors, state.pop("type"))
                new_value = component_class(**state)
                setattr(self.backend_tokenizer, tokenizer_component, new_value)
    @property
    def mask_token(self) -> str:
        """
        `str`: The mask token used when training this model with masked language modeling. Logs an error if it has
        not been set.

        The BART tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will
        greedily comprise the space before the *<mask>*.
        """
        if self._mask_token is None:
            if self.verbose:
                logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        Overriding the default behavior of the mask token to have it eat the space before it.

        This is needed to preserve backward compatibility with all the previously used models based on BART.
        """
        # The mask token behaves like a normal word, i.e. it includes the space before it, so we set lstrip=True
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        return super()._encode_plus(*args, **kwargs)
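
The two guards above are easy to trip over with pre-tokenized input; the sketch below shows the supported path (`tokens()` and `word_ids()` are the standard fast-tokenizer helpers for inspecting the result):

```
from transformers import BartTokenizerFast

# Without add_prefix_space=True, passing is_split_into_words=True raises the ValueError above
tok = BartTokenizerFast.from_pretrained("facebook/bart-base", add_prefix_space=True)
enc = tok(["Hello", "world"], is_split_into_words=True)
print(enc.tokens())    # tokens produced by the backend tokenizer
print(enc.word_ids())  # maps each token back to the index of its source word
```
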

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary to the given directory.

        Delegates to the underlying tokenizer model's save method and returns the list of saved files.
        """
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs that include the special tokens.

        Prepends bos_token_id and appends eos_token_id to token_ids_0; if token_ids_1 is given, it is appended as
        well, wrapped in eos_token_id on both sides.
        """
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output

        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BART does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros, with a length matching the full input including special tokens.
        """
        # IDs of the separator and classifier tokens
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # Single sequence: zeros over cls + token_ids_0 + sep
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]

        # Sequence pair: zeros over cls + token_ids_0 + sep + sep + token_ids_1 + sep
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

.\models\bart\__init__.py

# TYPE_CHECKING lets us import modules only while type checking
from typing import TYPE_CHECKING

# Utilities for optional dependencies and lazy module loading
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
)

# Import structure: maps each submodule name to the classes and functions it exports
_import_structure = {
    "configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig", "BartOnnxConfig"],
    "tokenization_bart": ["BartTokenizer"],
}

# Register the fast tokenizer only if the Tokenizers library is available
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tokenization_bart_fast"] = ["BartTokenizerFast"]

# Register the PyTorch models only if Torch is available
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_bart"] = [
        "BART_PRETRAINED_MODEL_ARCHIVE_LIST",
        "BartForCausalLM",
        "BartForConditionalGeneration",
        "BartForQuestionAnswering",
        "BartForSequenceClassification",
        "BartModel",
        "BartPreTrainedModel",
        "BartPretrainedModel",
        "PretrainedBartModel",
    ]

# Register the TensorFlow models only if TensorFlow is available
try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_bart"] = [
        "TFBartForConditionalGeneration",
        "TFBartForSequenceClassification",
        "TFBartModel",
        "TFBartPretrainedModel",
    ]

# Register the Flax models only if Flax is available
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flax_bart"] = [
        "FlaxBartDecoderPreTrainedModel",
        "FlaxBartForCausalLM",
        "FlaxBartForConditionalGeneration",
        "FlaxBartForQuestionAnswering",
        "FlaxBartForSequenceClassification",
        "FlaxBartModel",
        "FlaxBartPreTrainedModel",
    ]

# Under type checking, import the concrete symbols directly
if TYPE_CHECKING:
    from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, BartOnnxConfig
    from .tokenization_bart import BartTokenizer

    # Import the fast tokenizer only if the Tokenizers library is available
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tokenization_bart_fast import BartTokenizerFast

    # Import the PyTorch models only if Torch is available
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the pretrained model archive list and the various model classes
        from .modeling_bart import (
            BART_PRETRAINED_MODEL_ARCHIVE_LIST,
            BartForCausalLM,
            BartForConditionalGeneration,
            BartForQuestionAnswering,
            BartForSequenceClassification,
            BartModel,
            BartPreTrainedModel,
            BartPretrainedModel,
            PretrainedBartModel,
        )

    # Import the TensorFlow models only if TensorFlow is available
    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the TensorFlow BART model classes
        from .modeling_tf_bart import (
            TFBartForConditionalGeneration,
            TFBartForSequenceClassification,
            TFBartModel,
            TFBartPretrainedModel,
        )

    # Import the Flax models only if Flax is available
    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Import the Flax BART model classes
        from .modeling_flax_bart import (
            FlaxBartDecoderPreTrainedModel,
            FlaxBartForCausalLM,
            FlaxBartForConditionalGeneration,
            FlaxBartForQuestionAnswering,
            FlaxBartForSequenceClassification,
            FlaxBartModel,
            FlaxBartPreTrainedModel,
        )
else:
    import sys

    # Replace this module in sys.modules with a _LazyModule wrapper so submodules are imported on demand
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
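
The effect of this registration can be observed directly. In versions of the library that use this lazy mechanism, the package object is a `_LazyModule`, and a submodule is only imported when one of its attributes is first accessed (a minimal sketch):

```
from transformers.models import bart

print(type(bart).__name__)  # '_LazyModule' -- nothing heavy has been imported yet
config = bart.BartConfig()  # first attribute access triggers the import of configuration_bart
print(config.d_model)       # 1024, the default hidden size
```
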

.\models\barthez\tokenization_barthez.py

# coding=utf-8
# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for the BARThez model."""


import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm  # the SentencePiece library

from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging  # logging utilities


logger = logging.get_logger(__name__)  # module-level logger

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}  # vocabulary file names

# Map from pretrained model names to their vocabulary file URLs
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "moussaKam/mbarthez": "https://huggingface.co/moussaKam/mbarthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez": "https://huggingface.co/moussaKam/barthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez-orangesum-title": (
            "https://huggingface.co/moussaKam/barthez-orangesum-title/resolve/main/sentencepiece.bpe.model"
        ),
    },
}

# Sizes of the pretrained positional embeddings
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "moussaKam/mbarthez": 1024,
    "moussaKam/barthez": 1024,
    "moussaKam/barthez-orangesum-title": 1024,
}

SPIECE_UNDERLINE = "▁"  # the special SentencePiece underline marker

# TODO this class is useless. This is the most standard sentencepiece model. Let's find which one is closest and nuke this.


class BarthezTokenizer(PreTrainedTokenizer):
    """
    Adapted from `CamembertTokenizer` and `BartTokenizer`. Construct a BARThez tokenizer. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from `PreTrainedTokenizer` which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The SentencePiece processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES  # vocabulary file names
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP  # pretrained vocabulary file map
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES  # pretrained positional embedding sizes
    model_input_names = ["input_ids", "attention_mask"]  # names of the model inputs

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",  # beginning-of-sequence token
        eos_token="</s>",  # end-of-sequence token
        sep_token="</s>",  # separator token
        cls_token="<s>",  # classifier token
        unk_token="<unk>",  # unknown token
        pad_token="<pad>",  # padding token
        mask_token="<mask>",  # mask token
        sp_model_kwargs: Optional[Dict[str, Any]] = None,  # extra arguments for the SentencePiece model
        **kwargs,  # remaining keyword arguments
    ) -> None:
        """
        Initialize a new BarthezTokenizer.

        Args:
            vocab_file (`str`):
                Path to the vocabulary file.
            bos_token (`str`, *optional*):
                Special token used as the beginning of sequence.
            eos_token (`str`, *optional*):
                Special token used as the end of sequence.
            sep_token (`str`, *optional*):
                Special token used as a separator.
            cls_token (`str`, *optional*):
                Special token used as the classifier token.
            unk_token (`str`, *optional*):
                Special token used for unknown tokens.
            pad_token (`str`, *optional*):
                Special token used for padding.
            mask_token (`Union[str, AddedToken]`):
                Special token used as the mask token. When passed as a string it is wrapped with lstrip=True and
                special=True.
            sp_model_kwargs (`Optional[Dict]`, *optional*):
                Extra arguments passed to the SentencePiece model; defaults to an empty dict.
            **kwargs:
                Remaining arguments forwarded to the superclass constructor.
        """
        # Wrap a plain-string mask token in an AddedToken with lstrip=True and special=True
        mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token

        # Default sp_model_kwargs to an empty dict
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # Remember the vocabulary file path
        self.vocab_file = vocab_file
        # Create the SentencePiece processor and load the vocabulary
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))
        # Initialize the superclass with everything above
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BARThez sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`Optional[List[int]]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the appropriate special tokens.
        """

        if token_ids_1 is None:
            # Single sequence: wrap it in the cls and sep tokens
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        # Pair of sequences: join them with the double separator
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        返回包含特殊标记的掩码列表,用于指示输入中的特殊标记位置。

        Args:
            token_ids_0 (`List[int]`):
                输入序列的 ID 列表。
            token_ids_1 (`Optional[List[int]]`, *optional*):
                第二个序列的 ID 列表,用于序列对输入。
            already_has_special_tokens (`bool`, *optional*):
                如果输入已包含特殊标记,则为 True。

        Returns:
            `List[int]`: 标记了特殊标记位置的掩码列表。
        """
    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            # If the token list already contains special tokens, delegate to superclass method
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            # Return a mask with 1 for the added special tokens (CLS and SEP) and 0 for sequence tokens
            return [1] + ([0] * len(token_ids_0)) + [1]
        else:
            # Return a mask with 1 for the added special tokens (CLS, SEP) for both sequences and 0s for sequence tokens
            return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]  # Get the separator token ID
        cls = [self.cls_token_id]  # Get the classification token ID

        if token_ids_1 is None:
            # Return a list of zeros of the length of cls + token_ids_0 + sep
            return len(cls + token_ids_0 + sep) * [0]
        else:
            # Return a list of zeros of the length of cls + token_ids_0 + sep + sep + token_ids_1 + sep
            return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        # Return the size of the vocabulary
        return len(self.sp_model)

    def get_vocab(self):
        # Generate a vocabulary dictionary mapping tokens to IDs
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # Update with additional tokens
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        # Tokenize input text into a list of strings (tokens)
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an ID using the vocabulary."""
        return self.sp_model.PieceToId(token)

    def _convert_id_to_token(self, index):
        """Converts an ID (integer) into a token (str) using the vocabulary."""
        return self.sp_model.IdToPiece(index)

    # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        # Buffer of ordinary sub-tokens awaiting decoding
        current_sub_tokens = []
        out_string = ""
        # Whether the previous token was a special token
        prev_is_special = False
        for token in tokens:
            # Special tokens must not be decoded with the sentencepiece model
            if token in self.all_special_tokens:
                # Insert a space unless the previous token was also special
                if not prev_is_special:
                    out_string += " "
                # Decode the buffered sub-tokens with sentencepiece, then append the special token itself
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                # Reset the buffer for the next run of ordinary tokens
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        # Decode whatever remains in the buffer
        out_string += self.sp_model.decode(current_sub_tokens)
        # Return the result without leading/trailing whitespace
        return out_string.strip()
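
A runnable sketch of the loop above. The real `self.sp_model.decode` needs a trained model, so it is replaced here with a join-based stand-in (an assumption for illustration only):

SPIECE_UNDERLINE = "▁"

def fake_decode(pieces):
    # Stand-in for sp_model.decode: join the pieces and turn "▁" back into spaces
    return "".join(pieces).replace(SPIECE_UNDERLINE, " ").strip()

def tokens_to_string(tokens, special_tokens):
    current_sub_tokens, out_string, prev_is_special = [], "", False
    for token in tokens:
        if token in special_tokens:
            if not prev_is_special:
                out_string += " "
            out_string += fake_decode(current_sub_tokens) + token
            prev_is_special, current_sub_tokens = True, []
        else:
            current_sub_tokens.append(token)
            prev_is_special = False
    out_string += fake_decode(current_sub_tokens)
    return out_string.strip()

print(tokens_to_string(["<s>", "▁Bon", "jour", "</s>"], {"<s>", "</s>"}))  # "<s> Bonjour</s>"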

    def __getstate__(self):
        # Copy the instance dict
        state = self.__dict__.copy()
        # Drop the unpicklable sp_model from the serialized state
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        # Restore the instance dict from d
        self.__dict__ = d

        # For backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Recreate sp_model from sp_model_kwargs and reload the vocabulary file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)
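
The same drop-and-rebuild pickling pattern, shown with a toy class holding an unpicklable resource; this is an illustrative stand-in, not the tokenizer itself:

import pickle

class Reloadable:
    def __init__(self, path):
        self.path = path
        self.handle = open(path, "rb")       # live resource that cannot be pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        state["handle"] = None               # drop the resource before pickling
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.handle = open(self.path, "rb")  # rebuild it from the stored path

# obj = Reloadable("sentencepiece.bpe.model")  # assumes such a file exists locally
# restored = pickle.loads(pickle.dumps(obj))   # the handle is reopened on unpickling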

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Ensure the target is an existing directory
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Build the output vocabulary file path
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Copy the current vocabulary file if it exists and differs from the target path
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # If the current vocabulary file does not exist, write the serialized sentencepiece model instead
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        # Return the saved vocabulary file path as a tuple
        return (out_vocab_file,)
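
Hypothetical usage of `save_vocabulary`, assuming `sentencepiece` is installed and the `moussaKam/barthez` checkpoint is reachable:

import tempfile
from transformers import BarthezTokenizer

tok = BarthezTokenizer.from_pretrained("moussaKam/barthez")  # downloads the vocab file
with tempfile.TemporaryDirectory() as d:
    (vocab_path,) = tok.save_vocabulary(d)
    print(vocab_path)  # .../sentencepiece.bpe.model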

.\models\barthez\tokenization_barthez_fast.py

# coding=utf-8
# Copyright 2020 Ecole Polytechnique and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for the BARThez model."""


import os
from shutil import copyfile
from typing import List, Optional, Tuple

from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging

# Check whether sentencepiece is installed
if is_sentencepiece_available():
    from .tokenization_barthez import BarthezTokenizer
else:
    BarthezTokenizer = None

# Logger instance for this module
logger = logging.get_logger(__name__)

# Vocabulary file names
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "tokenizer_file": "tokenizer.json"}

# Pretrained vocabulary file map
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "moussaKam/mbarthez": "https://huggingface.co/moussaKam/mbarthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez": "https://huggingface.co/moussaKam/barthez/resolve/main/sentencepiece.bpe.model",
        "moussaKam/barthez-orangesum-title": (
            "https://huggingface.co/moussaKam/barthez-orangesum-title/resolve/main/sentencepiece.bpe.model"
        ),
    },
    "tokenizer_file": {
        "moussaKam/mbarthez": "https://huggingface.co/moussaKam/mbarthez/resolve/main/tokenizer.json",
        "moussaKam/barthez": "https://huggingface.co/moussaKam/barthez/resolve/main/tokenizer.json",
        "moussaKam/barthez-orangesum-title": (
            "https://huggingface.co/moussaKam/barthez-orangesum-title/resolve/main/tokenizer.json"
        ),
    },
}

# Positional embedding sizes of the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "moussaKam/mbarthez": 1024,
    "moussaKam/barthez": 1024,
    "moussaKam/barthez-orangesum-title": 1024,
}

# Prefix used by SentencePiece for word-initial pieces
SPIECE_UNDERLINE = "▁"

class BarthezTokenizerFast(PreTrainedTokenizerFast):
    """
    Adapted from `CamembertTokenizer` and `BartTokenizer`. Construct a "fast" BARThez tokenizer, backed by
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizerFast`], which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        additional_special_tokens (`List[str]`, *optional*, defaults to `["<s>NOTUSED", "</s>NOTUSED"]`):
            Additional special tokens used by the tokenizer.
    """

    # Map attribute names to the vocabulary file name constants
    vocab_files_names = VOCAB_FILES_NAMES
    # Map to the pretrained vocabulary file URLs
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    # Maximum model input sizes, taken from the pretrained positional embedding sizes
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    # Names of the model inputs
    model_input_names = ["input_ids", "attention_mask"]
    # Corresponding slow tokenizer class
    slow_tokenizer_class = BarthezTokenizer

    # Constructor accepting the vocabulary file, tokenizer file, and the various special tokens
    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs,
    ):
        # If mask_token is a plain string, wrap it in an AddedToken that strips the space on its left but not its right
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # Call the parent constructor with the required arguments and keyword arguments
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

        # Remember the vocabulary file so the slow tokenizer can be re-exported later
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # True only if self.vocab_file is set and points to an existing file
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        通过添加特殊 token 构建用于序列分类任务的模型输入。BARThez 序列的格式如下:

        - 单个序列: `<s> X </s>`
        - 序列对: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                需要添加特殊 token 的 ID 列表。
            token_ids_1 (`List[int]`, *optional*):
                第二个序列的 ID 列表(对序列任务时使用)。

        Returns:
            `List[int]`: 包含适当特殊 token 的输入 ID 列表。
        """

        if token_ids_1 is None:
            # 如果没有第二个序列,返回包含 cls_token_id, token_ids_0 和 sep_token_id 的列表
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        # 返回包含 cls_token_id, token_ids_0, sep_token_id, sep_token_id, token_ids_1 和 sep_token_id 的列表
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        从两个传入的序列创建一个用于序列对分类任务的掩码。

        Args:
            token_ids_0 (`List[int]`):
                第一个序列的 ID 列表。
            token_ids_1 (`List[int]`, *optional*):
                第二个序列的 ID 列表(对序列任务时使用)。

        Returns:
            `List[int]`: 全为零的列表。
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            # 如果没有第二个序列,返回长度为 cls_token_id, token_ids_0 和 sep 的列表,所有元素为零
            return len(cls + token_ids_0 + sep) * [0]
        # 返回长度为 cls_token_id, token_ids_0, sep, sep, token_ids_1 和 sep 的列表,所有元素为零
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # A fast tokenizer loaded without the original sentencepiece file cannot export the slow vocabulary
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        # Log an error and return if the save directory does not exist
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Build the output vocabulary file path
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Copy the current vocabulary file to the output path if the two differ
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        # Return the saved vocabulary file path as a tuple
        return (out_vocab_file,)
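
A hypothetical round trip through the fast tokenizer, again assuming network access; `can_save_slow_tokenizer` should be true here because `from_pretrained` also fetches the sentencepiece file:

import tempfile
from transformers import BarthezTokenizerFast

tok = BarthezTokenizerFast.from_pretrained("moussaKam/barthez")
assert tok.can_save_slow_tokenizer    # vocab_file was downloaded as well
with tempfile.TemporaryDirectory() as d:
    (path,) = tok.save_vocabulary(d)  # copies sentencepiece.bpe.model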

.\models\barthez\__init__.py

# Import TYPE_CHECKING to detect static type-checking runs
from typing import TYPE_CHECKING

# Import the OptionalDependencyNotAvailable exception and the lazy-loading helper _LazyModule
from ...utils import OptionalDependencyNotAvailable, _LazyModule

# Import the functions that check whether sentencepiece and tokenizers are available
from ...utils import is_sentencepiece_available, is_tokenizers_available

# Empty dict describing which names each submodule exports
_import_structure = {}

# If sentencepiece is unavailable, raise OptionalDependencyNotAvailable and skip the entry
try:
    if not is_sentencepiece_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, map "tokenization_barthez" to ["BarthezTokenizer"] in _import_structure
    _import_structure["tokenization_barthez"] = ["BarthezTokenizer"]

# If tokenizers is unavailable, raise OptionalDependencyNotAvailable and skip the entry
try:
    if not is_tokenizers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # If available, map "tokenization_barthez_fast" to ["BarthezTokenizerFast"] in _import_structure
    _import_structure["tokenization_barthez_fast"] = ["BarthezTokenizerFast"]

# During static type checking, perform the real imports
if TYPE_CHECKING:
    # Check sentencepiece availability, raising OptionalDependencyNotAvailable if missing
    try:
        if not is_sentencepiece_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import BarthezTokenizer from the tokenization_barthez module
        from .tokenization_barthez import BarthezTokenizer

    # Check tokenizers availability, raising OptionalDependencyNotAvailable if missing
    try:
        if not is_tokenizers_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # If available, import BarthezTokenizerFast from the tokenization_barthez_fast module
        from .tokenization_barthez_fast import BarthezTokenizerFast

# At runtime (TYPE_CHECKING is False), replace this module with a lazy-loading instance
else:
    import sys

    # Dynamically register the current module as a _LazyModule instance
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
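
A minimal sketch of the idea behind `_LazyModule` (not the actual implementation): the submodule import is deferred until one of its exported names is first accessed.

import importlib
import types

class TinyLazyModule(types.ModuleType):
    """Toy lazy module: defers submodule imports until an attribute is read."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {submodule: [names]} into {name: submodule}
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        mod_name = self._name_to_module[attr]  # e.g. "tokenization_barthez"
        module = importlib.import_module("." + mod_name, self.__name__)
        value = getattr(module, attr)
        setattr(self, attr, value)             # cache so __getattr__ is not hit again
        return value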

.\models\bartpho\tokenization_bartpho.py

# coding=utf-8
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0
""" Tokenization classes for BARTpho-syllable model."""

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm  # sentencepiece library

# Common tokenization utilities and logging
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging

# Logger for this module
logger = logging.get_logger(__name__)

# Special prefix used by the SentencePiece tokenizer
SPIECE_UNDERLINE = "▁"

# Vocabulary file names
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model", "monolingual_vocab_file": "dict.txt"}

# Pretrained vocabulary file map
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "vinai/bartpho-syllable": "https://huggingface.co/vinai/bartpho-syllable/resolve/main/sentencepiece.bpe.model",
    },
    "monolingual_vocab_file": {
        "vinai/bartpho-syllable": "https://huggingface.co/vinai/bartpho-syllable/resolve/main/dict.txt",
    },
}

# Positional embedding sizes of the pretrained models
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"vinai/bartpho-syllable": 1024}


class BartphoTokenizer(PreTrainedTokenizer):
    """
    Adapted from [`XLMRobertaTokenizer`]. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`], which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The SentencePiece processor used for every conversion (string, tokens and IDs).
    """

    # Vocabulary file names
    vocab_files_names = VOCAB_FILES_NAMES

    # Pretrained vocabulary file map
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

    # Maximum model input sizes
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    # Names of the model inputs
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        monolingual_vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Use lstrip=True and rstrip=False so the added mask token behaves like a normal word, keeping the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # Default sp_model_kwargs to an empty dict when None
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # Record the vocabulary file and monolingual vocabulary file paths
        self.vocab_file = vocab_file
        self.monolingual_vocab_file = monolingual_vocab_file

        # Create the SentencePieceProcessor with the given sp_model_kwargs
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        # Load the SentencePiece model from vocab_file
        self.sp_model.Load(str(vocab_file))

        # Load the reduced vocabulary

        # Keep the order of special tokens for backward compatibility
        self.fairseq_tokens_to_ids = {}
        cnt = 0
        # Walk the special tokens; add each one that is not yet in fairseq_tokens_to_ids
        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
            if str(token) not in self.fairseq_tokens_to_ids:
                self.fairseq_tokens_to_ids[str(token)] = cnt
                cnt += 1

        # Read the first word of each line in monolingual_vocab_file and add it to fairseq_tokens_to_ids
        with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
            for line in f.readlines():
                token = line.strip().split()[0]
                self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)

        # Add mask_token if it is not yet in fairseq_tokens_to_ids
        if str(mask_token) not in self.fairseq_tokens_to_ids:
            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)

        # Invert fairseq_tokens_to_ids to build fairseq_ids_to_tokens
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        # Call the parent constructor with the required arguments and keyword arguments
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )
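
A self-contained sketch of the reduced-vocabulary construction above, with made-up `dict.txt` contents (the Vietnamese tokens are illustrative only):

specials = ["<s>", "<pad>", "</s>", "<unk>"]  # sep/cls repeat <s>/</s>, so they are skipped
monolingual = ["Xin", "chào", "Việt"]         # first column of a hypothetical dict.txt

token_to_id = {}
for tok in specials + monolingual:
    if tok not in token_to_id:
        token_to_id[tok] = len(token_to_id)
token_to_id.setdefault("<mask>", len(token_to_id))  # the mask token goes last
print(token_to_id)
# {'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, 'Xin': 4, 'chào': 5, 'Việt': 6, '<mask>': 7}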

    def __getstate__(self):
        # Copy the instance dict
        state = self.__dict__.copy()
        # Drop the unpicklable sp_model from the serialized state
        state["sp_model"] = None
        # Store the serialized SentencePiece model proto in sp_model_proto instead
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d):
        # Restore the instance dict from d
        self.__dict__ = d

        # For backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Recreate sp_model from sp_model_kwargs and load the serialized model proto
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. A BARTpho sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of input IDs with the appropriate special tokens.
        """

        if token_ids_1 is None:
            # Return a single sequence with added special tokens: <s> token_ids_0 </s>
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        
        # For sequence pairs, concatenate tokens with special tokens between and at the end: <s> token_ids_0 </s></s> token_ids_1 </s>
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence IDs from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers indicating the presence of special tokens (1) or sequence tokens (0).
        """

        if already_has_special_tokens:
            # If tokens already have special tokens, delegate to superclass method
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            # For a single sequence, mark positions of special tokens: <s> token_ids_0 </s>
            return [1] + ([0] * len(token_ids_0)) + [1]
        
        # For sequence pairs, mark positions of special tokens: <s> token_ids_0 </s></s> token_ids_1 </s>
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BARTPho does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.

        """
        # Define the separator and class tokens based on the model's special token IDs
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        # If token_ids_1 is None, return a list of zeros representing the mask for token_ids_0
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        # Otherwise, return a list of zeros representing the mask for the concatenated sequence pairs
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        # Return the size of the vocabulary based on the number of entries in fairseq_ids_to_tokens
        return len(self.fairseq_ids_to_tokens)

    def get_vocab(self):
        # Construct and return a dictionary mapping token strings to their respective IDs
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # Update the vocabulary with any additional tokens from added_tokens_encoder
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        # Tokenize the input text using the SentencePiece model and return the tokens as a list of strings
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) into an ID using the vocabulary."""
        # Check if the token exists in fairseq_tokens_to_ids; if not, return the unknown token ID
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        else:
            return self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocabulary."""
        # Return the token corresponding to the index from fairseq_ids_to_tokens
        return self.fairseq_ids_to_tokens[index]
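
A small sketch of the lookup behavior with a hypothetical mapping: out-of-vocabulary tokens fall back to `<unk>`, mirroring `_convert_token_to_id` above.

token_to_id = {"<unk>": 3, "chào": 5}      # hypothetical reduced vocabulary
unk_id = token_to_id["<unk>"]

def token2id(token):
    return token_to_id.get(token, unk_id)  # same fallback as _convert_token_to_id

assert token2id("chào") == 5
assert token2id("nước") == unk_id          # out-of-vocabulary -> <unk>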

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) into a single string."""
        # Concatenate the tokens into a string, replacing SPIECE_UNDERLINE with space and stripping any surrounding whitespace
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string
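
The joining rule above is simple enough to run without a model:

SPIECE_UNDERLINE = "▁"
tokens = ["▁Xin", "▁chào", "▁Việt", "▁Nam"]
print("".join(tokens).replace(SPIECE_UNDERLINE, " ").strip())  # "Xin chào Việt Nam"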
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Log an error and return if the save directory does not exist
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        # Build the output vocabulary file path from the optional prefix and the predefined file name
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Build the output monolingual vocabulary file path from the optional prefix and the predefined file name
        out_monolingual_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
        )

        # Copy the current vocabulary file to the output path if it exists and differs from the target
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        # If the current vocabulary file does not exist, write the serialized sp_model content instead
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        # Copy the current monolingual vocabulary file to the output path if it exists and differs from the target
        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
            out_monolingual_vocab_file
        ) and os.path.isfile(self.monolingual_vocab_file):
            copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
        # If the current monolingual vocabulary file does not exist, write the non-special tokens from fairseq_tokens_to_ids
        elif not os.path.isfile(self.monolingual_vocab_file):
            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
                for token in self.fairseq_tokens_to_ids:
                    if token not in self.all_special_tokens:
                        fp.write(f"{str(token)} \n")

        # Return the saved vocabulary file path and monolingual vocabulary file path
        return out_vocab_file, out_monolingual_vocab_file