Transformers 源码解析(一)

.\activations.py

# 导入必要的库
import math
from collections import OrderedDict

# 导入 PyTorch 相关模块
import torch
from packaging import version
from torch import Tensor, nn

# 导入自定义的日志记录工具
from .utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义一个 PyTorch 模块,实现了一个高效的 GELU tanh 近似激活函数
class PytorchGELUTanh(nn.Module):
    """
    A fast C implementation of the tanh approximation of the GeLU activation function. See
    https://arxiv.org/abs/1606.08415.

    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
    match due to rounding errors.
    """

    def __init__(self):
        super().__init__()
        # 检查所需的 PyTorch 版本是否满足要求
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
                "PytorchGELUTanh. Please upgrade torch."
            )

    def forward(self, input: Tensor) -> Tensor:
        # 使用 PyTorch 的内置函数实现 GELU tanh 近似激活
        return nn.functional.gelu(input, approximate="tanh")


class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        # 实现 GELU 激活函数的计算公式
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
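
As noted in PytorchGELUTanh's docstring above, the tanh approximation computed here matches the built-in C implementation up to rounding. A quick illustrative check (not part of this file; assumes torch>=1.12 so that `approximate="tanh"` is available):

```python
import torch
from transformers.activations import NewGELUActivation, PytorchGELUTanh

x = torch.randn(4, 8)
ref = PytorchGELUTanh()(x)      # fast C implementation of the tanh approximation
py = NewGELUActivation()(x)     # pure-Python formula above
# The two agree to within floating-point rounding, but are not guaranteed to be bit-identical
print(torch.allclose(ref, py, atol=1e-6))
```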


class GELUActivation(nn.Module):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        # 根据参数选择使用 Python 实现的 GELU 函数还是 PyTorch 内置的函数
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        # Python 实现的 GELU 函数
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        # 调用选择的 GELU 函数进行前向传播
        return self.act(input)


class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    # 前向传播函数,接收一个张量作为输入,返回处理后的张量
    def forward(self, input: Tensor) -> Tensor:
        # 使用 GELU 近似函数计算
        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


class QuickGELUActivation(nn.Module):
    """
    Applies a fast but approximate version of GELU activation.

    Reference: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        # Implementing GELU approximation using a sigmoid function
        return input * torch.sigmoid(1.702 * input)


class ClippedGELUActivation(nn.Module):
    """
    Applies GELU activation with output clipped to a specified range [min, max].

    This is useful for quantization purposes to handle negative values in the GELU spectrum.

    References:
    - https://arxiv.org/abs/2004.09602
    """

    def __init__(self, min: float, max: float):
        if min > max:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
        # Applying GELU activation and clipping the output
        return torch.clip(gelu(x), self.min, self.max)


class AccurateGELUActivation(nn.Module):
    """
    Applies a more accurate version of GELU activation compared to QuickGELU.

    Reference: https://github.com/hendrycks/GELUs

    Implemented in the context of MEGA (Moving Average Equipped Gated Attention).
    """

    def __init__(self):
        super().__init__()
        self.precomputed_constant = math.sqrt(2 / math.pi)

    def forward(self, input: Tensor) -> Tensor:
        # Implementing the accurate GELU activation formula
        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))


class MishActivation(nn.Module):
    """
    Applies the Mish activation function, a self-regularized non-monotonic activation.

    Reference: https://arxiv.org/abs/1908.08681
    """

    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.9.0"):
            self.act = self._mish_python
        else:
            self.act = nn.functional.mish

    def _mish_python(self, input: Tensor) -> Tensor:
        # Implementing Mish activation using Python function
        return input * torch.tanh(nn.functional.softplus(input))

    def forward(self, input: Tensor) -> Tensor:
        # Applying Mish activation function
        return self.act(input)


class LinearActivation(nn.Module):
    """
    Applies the linear activation function, i.e., forwarding input directly to output.
    """

    def forward(self, input: Tensor) -> Tensor:
        # Identity function; returns input unchanged
        return input


class LaplaceActivation(nn.Module):
    """
    Applies an elementwise activation based on the Laplace function, introduced in MEGA for attention.

    This activation is inspired by squared ReLU but offers a bounded range and gradient for improved stability.

    Reference: https://arxiv.org/abs/2209.10655
    """
    """
    此方法用于计算正向传播过程中的操作,对输入进行标准化处理后,应用误差函数。
    :param input: 输入张量
    :param mu: 均值参数,默认为0.707107
    :param sigma: 标准差参数,默认为0.282095
    :return: 处理后的张量

    将输入张量标准化,减去均值 mu 并除以标准差乘以 sqrt(2.0)
    input = (input - mu).div(sigma * math.sqrt(2.0))
    应用误差函数,计算误差函数的正向传播结果,返回结果
    return 0.5 * (1.0 + torch.erf(input))
    """
# 定义一个自定义的激活函数 ReLUSquaredActivation,继承自 nn.Module
class ReLUSquaredActivation(nn.Module):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    # 定义前向传播方法,接受输入 input
    def forward(self, input):
        # 应用 ReLU 激活函数到输入
        relu_applied = nn.functional.relu(input)
        # 对经过 ReLU 激活后的结果进行平方操作
        squared = torch.square(relu_applied)
        # 返回平方后的结果作为输出
        return squared


# 定义一个名为 ClassInstantier 的类,继承自 OrderedDict
class ClassInstantier(OrderedDict):
    # 重写 __getitem__ 方法,接受键 key 作为输入
    def __getitem__(self, key):
        # 调用父类 OrderedDict 的 __getitem__ 方法获取键对应的值 content
        content = super().__getitem__(key)
        # 如果值 content 是一个元组,则将其解包为 cls 和 kwargs;否则将 cls 设为 content,kwargs 设为一个空字典
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        # 返回使用 cls 和 kwargs 创建的类实例
        return cls(**kwargs)


# 定义一个名为 ACT2CLS 的字典,将字符串映射到对应的激活函数类或者类与参数元组
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": PytorchGELUTanh,
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,  # 引用了之前定义的 ReLUSquaredActivation 激活函数类
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": nn.SiLU,  # SiLU 激活函数类,也称作 Swish
    "swish": nn.SiLU,  # 同上,SiLU 激活函数
    "tanh": nn.Tanh,
}

# 使用 ClassInstantier 类创建 ACT2FN 字典,将字符串映射为对应的激活函数类实例
ACT2FN = ClassInstantier(ACT2CLS)


# 定义一个函数 get_activation,接受一个激活函数字符串作为参数
def get_activation(activation_string):
    # 如果 activation_string 存在于 ACT2FN 字典中,则返回对应的激活函数类实例
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        # 否则抛出 KeyError,指示找不到对应的激活函数字符串
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")


# 创建几个全局变量,用于快速访问不同的激活函数实例
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
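
A small usage sketch (illustrative, not part of the file): configuration strings are resolved through `get_activation`, and every lookup in `ACT2FN` builds a fresh module instance via `ClassInstantier.__getitem__`:

```python
import torch
from transformers.activations import ACT2FN, get_activation

act = get_activation("gelu_new")          # returns a NewGELUActivation() instance
out = act(torch.randn(2, 3))

# Each __getitem__ call instantiates a new object, so the two lookups below are distinct
assert ACT2FN["gelu_10"] is not ACT2FN["gelu_10"]
```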

.\activations_tf.py

# 引入数学库
import math

# 引入 TensorFlow 库
import tensorflow as tf

# 引入版本解析工具
from packaging.version import parse

# 尝试引入 tf_keras 库,如果失败则引入 keras 库
try:
    import tf_keras as keras
except (ModuleNotFoundError, ImportError):
    import keras

    # 检查 keras 版本是否大于 2,如果是则抛出异常
    if parse(keras.__version__).major > 2:
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )


def _gelu(x):
    """
    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
    https://arxiv.org/abs/1606.08415
    """
    # 将输入转换为 TensorFlow 张量
    x = tf.convert_to_tensor(x)
    # 计算高斯误差线性单元的输出
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))

    return x * cdf


def _gelu_new(x):
    """
    Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation

    Returns:
        `x` with the GELU activation applied.
    """
    # 将输入转换为 TensorFlow 张量
    x = tf.convert_to_tensor(x)
    # 定义 pi 和系数
    pi = tf.cast(math.pi, x.dtype)
    coeff = tf.cast(0.044715, x.dtype)
    # 计算平滑的 GELU 输出
    cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))

    return x * cdf


def mish(x):
    # 将输入转换为 TensorFlow 张量
    x = tf.convert_to_tensor(x)
    # 计算 Mish 激活函数的输出
    return x * tf.tanh(tf.math.softplus(x))


def gelu_fast(x):
    # 将输入转换为 TensorFlow 张量
    x = tf.convert_to_tensor(x)
    # 定义系数
    coeff1 = tf.cast(0.044715, x.dtype)
    coeff2 = tf.cast(0.7978845608, x.dtype)
    # 计算快速 GELU 的输出
    return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))


def quick_gelu(x):
    # 将输入转换为 TensorFlow 张量
    x = tf.convert_to_tensor(x)
    # 定义系数
    coeff = tf.cast(1.702, x.dtype)
    # 计算快速 GELU 的输出
    return x * tf.math.sigmoid(coeff * x)


def gelu_10(x):
    """
    Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
    it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to
    https://arxiv.org/abs/2004.09602

    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    """
    # 截断 GeLU 输出的范围在 [-10, 10] 之间
    # 这对于量化目的非常有用,因为它允许在 GeLU 光谱中映射 2 个负值。有关此技巧的更多信息,请参阅链接
    """
    对输入的张量 x 应用改进的 GELU(Gaussian Error Linear Unit)激活函数,并进行值裁剪。

    GELU 函数的数学表达式是:
    0.5 * x * (1 + tanh(math.sqrt(2 / pi) * (x + 0.044715 * x^3)))

    这里使用了一个 TensorFlow 的内置函数 _gelu 来实现 GELU 激活函数。

    参数 x: 输入的张量
    返回值: 应用 GELU 激活函数后的张量,裁剪在 [-10, 10] 的范围内
    """
    return tf.clip_by_value(_gelu(x), -10, 10)
def glu(x, axis=-1):
    """
    Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
    the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).

    Args:
        `x`: float Tensor to perform activation
        `axis`: dimension across which `x` be split in half

    Returns:
        `x` with the GLU activation applied (with its size halved across the dimension `axis`).
    """
    # 将输入 `x` 沿指定轴 `axis` 分成两半,命名为 A 和 B
    a, b = tf.split(x, 2, axis=axis)
    # 返回 A * sigmoid(B) 的结果,即 GLU 激活函数的计算结果
    return a * tf.math.sigmoid(b)


if parse(tf.version.VERSION) >= parse("2.4"):

    def approximate_gelu_wrap(x):
        # 使用 Keras 中的 approximate gelu 函数来计算 gelu 激活
        return keras.activations.gelu(x, approximate=True)

    # 设置 gelu 和 gelu_new 激活函数
    gelu = keras.activations.gelu
    gelu_new = approximate_gelu_wrap
else:
    # 如果 TensorFlow 版本低于 2.4,则使用自定义的 _gelu 和 _gelu_new 函数
    gelu = _gelu
    gelu_new = _gelu_new


# 定义激活函数名称到对应函数的映射字典
ACT2FN = {
    "gelu": gelu,
    "gelu_10": gelu_10,
    "gelu_fast": gelu_fast,
    "gelu_new": gelu_new,
    "glu": glu,
    "mish": mish,
    "quick_gelu": quick_gelu,
    "relu": keras.activations.relu,
    "sigmoid": keras.activations.sigmoid,
    "silu": keras.activations.swish,
    "swish": keras.activations.swish,
    "tanh": keras.activations.tanh,
}


def get_tf_activation(activation_string):
    # 根据激活函数名称 `activation_string` 在 ACT2FN 字典中查找对应的函数
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        # 如果未找到对应的函数,则抛出 KeyError 异常
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")

.\audio_utils.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
and remove unnecessary dependencies.
"""
import warnings
from typing import Optional, Tuple, Union

import numpy as np


def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
    """
    Convert frequency from hertz to mels.

    Args:
        freq (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in hertz (Hz).
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.

    Returns:
        `float` or `np.ndarray`: The frequencies on the mel scale.
    """

    if mel_scale not in ["slaney", "htk", "kaldi"]:
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
    
    # Convert frequencies to mels based on the specified mel scale
    if mel_scale == "htk":
        return 2595.0 * np.log10(1.0 + (freq / 700.0))
    elif mel_scale == "kaldi":
        return 1127.0 * np.log(1.0 + (freq / 700.0))

    # For "slaney" scale, compute mels using specific formulas
    min_log_hertz = 1000.0
    min_log_mel = 15.0
    logstep = 27.0 / np.log(6.4)
    mels = 3.0 * freq / 200.0

    # Handle cases where freq is a numpy array or a scalar
    if isinstance(freq, np.ndarray):
        log_region = freq >= min_log_hertz
        mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
    elif freq >= min_log_hertz:
        mels = min_log_mel + np.log(freq / min_log_hertz) * logstep

    return mels


def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
    """
    Convert frequency from mels to hertz.

    Args:
        mels (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in mels.
        mel_scale (`str`, *optional*, `"htk"`):
            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.

    Returns:
        `float` or `np.ndarray`: The frequencies in hertz.
    """

    if mel_scale not in ["slaney", "htk", "kaldi"]:
        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
    
    # Convert mels to frequencies based on the specified mel scale
    if mel_scale == "htk":
        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
    elif mel_scale == "kaldi":
        return 700.0 * (np.exp(mels / 1127.0) - 1.0)

    # For "slaney" scale, compute frequencies using specific formulas
    min_log_hertz = 1000.0
    min_log_mel = 15.0
    logstep = np.log(6.4) / 27.0
    freq = 200.0 * mels / 3.0

    # 如果输入的 mels 是一个 NumPy 数组
    if isinstance(mels, np.ndarray):
        # 创建一个布尔数组 log_region,标记所有 mels 中大于等于 min_log_mel 的元素位置
        log_region = mels >= min_log_mel
        # 对于 log_region 中为 True 的位置,根据公式计算对应的频率值并存储在 freq 中
        freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
    # 如果输入的 mels 不是 NumPy 数组,而是单个数值
    elif mels >= min_log_mel:
        # 根据公式计算频率值并存储在 freq 中
        freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
    
    # 返回计算得到的频率 freq
    return freq
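
`hertz_to_mel` and `mel_to_hertz` are inverses of each other for every supported scale; a quick sanity check (illustrative only):

```python
import numpy as np
from transformers.audio_utils import hertz_to_mel, mel_to_hertz

freqs = np.array([100.0, 440.0, 1000.0, 4000.0])
for scale in ("htk", "kaldi", "slaney"):
    mels = hertz_to_mel(freqs, mel_scale=scale)
    assert np.allclose(mel_to_hertz(mels, mel_scale=scale), freqs)
```
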
# 创建一个函数用于将频率从赫兹转换为分数音阶数
def hertz_to_octave(
    freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
):
    """
    Convert frequency from hertz to fractional octave numbers.
    Adapted from *librosa*.

    Args:
        freq (`float` or `np.ndarray`):
            The frequency, or multiple frequencies, in hertz (Hz).
        tuning (`float`, defaults to `0.`):
            Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
        bins_per_octave (`int`, defaults to `12`):
            Number of bins per octave.

    Returns:
        `float` or `np.ndarray`: The frequencies on the octave scale.
    """
    # 计算按照斯图加特音高(A440)偏移后的基准频率
    stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
    # 计算频率的分数音阶数
    octave = np.log2(freq / (float(stuttgart_pitch) / 16))
    return octave
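
For example, with the default tuning, A440 sits exactly 4 octaves above the 27.5 Hz reference used in the formula (an illustrative check):

```python
from transformers.audio_utils import hertz_to_octave

print(hertz_to_octave(440.0))   # 4.0
print(hertz_to_octave(27.5))    # 0.0
```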


# 创建一个函数用于生成三角形滤波器组成的滤波器组
def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
    """
    Creates a triangular filter bank.

    Adapted from *torchaudio* and *librosa*.

    Args:
        fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
            Discrete frequencies of the FFT bins in Hz.
        filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
            Center frequencies of the triangular filters to create, in Hz.

    Returns:
        `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
    """
    # 计算滤波器中心频率的差异
    filter_diff = np.diff(filter_freqs)
    # 计算每个滤波器中心频率与每个 FFT 频率 bin 之间的差值
    slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
    # 由相邻中心频率的间隔计算每个频率 bin 在每个滤波器上的两条斜边
    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
    up_slopes = slopes[:, 2:] / filter_diff[1:]
    # 取两条斜边中的较小值,并与 0 取最大,得到三角形滤波器的权重
    return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))


# 创建一个函数用于生成色度滤波器组成的滤波器组
def chroma_filter_bank(
    num_frequency_bins: int,
    num_chroma: int,
    sampling_rate: int,
    tuning: float = 0.0,
    power: Optional[float] = 2.0,
    weighting_parameters: Optional[Tuple[float]] = (5.0, 2),
    start_at_c_chroma: Optional[bool] = True,
):
    """
    Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.

    Adapted from *librosa*.
    """
    # 在默认情况下,如果从C音开始,色度数设置为12
    # 否则,设置为24
    # 每个bin中的频率
    # 获取FFT的频率bins,不包括直流分量
    frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]

    # 将频率bins转换为chroma bins,基于给定的调谐和每个八度内的bins数
    freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)

    # 为0 Hz的频率bin赋值,假设它比bin 1低1.5个八度
    # (这样chroma就会从bin 1旋转50%,bin宽度较宽)
    freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))

    # 计算每个bin的宽度
    bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))

    # 创建chroma滤波器,计算每个频率bin与每个chroma bin之间的差异
    chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T

    # 将chroma滤波器投影到范围 -num_chroma/2 .. num_chroma/2
    # 添加一个固定偏移量确保所有传递给rem的值都是正数
    num_chroma2 = np.round(float(num_chroma) / 2)
    chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2

    # 创建高斯形状的chroma滤波器,使它们更窄
    chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)

    # 如果指定了power,则对每列进行归一化
    if power is not None:
        chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)

    # 如果指定了weighting_parameters,则应用高斯加权
    if weighting_parameters is not None:
        center, half_width = weighting_parameters
        chroma_filters *= np.tile(
            np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
            (num_chroma, 1),
        )

    # 如果start_at_c_chroma为True,则将chroma_filters数组向左滚动,以从'C'音调类开始
    if start_at_c_chroma:
        chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)

    # 去除别名列,并复制以确保行连续性,返回numpy数组
    return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
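
A usage sketch (illustrative; the parameter values are arbitrary): project STFT frequency bins onto 12 chroma classes. Here `num_frequency_bins` plays the role of the FFT length used for the underlying spectrogram:

```python
from transformers.audio_utils import chroma_filter_bank

chroma = chroma_filter_bank(
    num_frequency_bins=2048,
    num_chroma=12,
    sampling_rate=22050,
)
print(chroma.shape)   # (num_chroma, 1 + num_frequency_bins // 2), i.e. (12, 1025)
```
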
def mel_filter_bank(
    num_frequency_bins: int,
    num_mel_filters: int,
    min_frequency: float,
    max_frequency: float,
    sampling_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
    triangularize_in_mel_space: bool = False,
) -> np.ndarray:
    """
    创建用于生成梅尔频谱图的频率 bin 转换矩阵,称为梅尔滤波器组。存在多种实现方式,这些方式在滤波器数量、滤波器形状、
    滤波器间距、滤波器带宽以及频谱扭曲方式上都有所不同。这些特性旨在近似人类对频率变化的非线性感知。

    文献中引入了不同的梅尔滤波器组变体。以下几种变体是支持的:

    - MFCC FB-20: 由Davis和Mermelstein于1980年引入,假设采样频率为10 kHz,语音带宽为 `[0, 4600]` Hz。
    - MFCC FB-24 HTK: 来自于剑桥HMM工具包(HTK)(1995年),使用24个滤波器的滤波器组,语音带宽为 `[0, 8000]` Hz。
      假设采样率 ≥ 16 kHz。
    - MFCC FB-40: 来自于Slaney在1998年为MATLAB编写的听觉工具箱,假设采样率为16 kHz,语音带宽为 `[133, 6854]` Hz。
      此版本还包括区域归一化。
    - HFCC-E FB-29(人因谱系数):由Skowronski和Harris于2004年提出,假设采样率为12.5 kHz,语音带宽为 `[0, 6250]` Hz。

    此代码改编自 *torchaudio* 和 *librosa*。请注意,torchaudio 的 `melscale_fbanks` 的默认参数实现了 `"htk"` 滤波器,
    而 librosa 使用 `"slaney"` 实现。

    Args:
        num_frequency_bins (`int`):
            用于计算频谱图的频率数量(应与 `stft` 中的相同)。
        num_mel_filters (`int`):
            要生成的梅尔滤波器数量。
        min_frequency (`float`):
            兴趣的最低频率(单位:Hz)。
        max_frequency (`float`):
            兴趣的最高频率(单位:Hz)。不应超过 `sampling_rate / 2`。
        sampling_rate (`int`):
            音频波形的采样率。
        norm (`str`, *optional*):
            如果是 `"slaney"`,将三角形梅尔权重除以梅尔带宽的宽度(区域归一化)。
        mel_scale (`str`, *optional*, defaults to `"htk"`):
            要使用的梅尔频率刻度,可选 `"htk"`、`"kaldi"` 或 `"slaney"`。
        triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
            如果启用此选项,则在梅尔空间而不是频率空间中应用三角形滤波器。在计算梅尔滤波器时应将其设置为 `True`,以便获得与 `torchaudio` 相同的结果。
    """
    # 在这里实现梅尔滤波器组的计算和返回
    pass
    if norm is not None and norm != "slaney":
        # 如果指定了 norm 参数但不是 "slaney",则抛出数值错误异常
        raise ValueError('norm must be one of None or "slaney"')

    # 计算三角形梅尔滤波器的中心点频率
    mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
    mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
    mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
    # 将梅尔频率转换为普通频率
    filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)

    if triangularize_in_mel_space:
        # 如果在梅尔空间中进行三角化,则使用FFT频率的梅尔频率宽度
        fft_bin_width = sampling_rate / (num_frequency_bins * 2)
        fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
        filter_freqs = mel_freqs
    else:
        # 否则使用普通的FFT频率范围
        fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

    # 创建三角形滤波器组成的滤波器组
    mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)

    if norm is not None and norm == "slaney":
        # 如果使用 Slaney 格式的梅尔滤波器,则进行能量归一化
        enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
        mel_filters *= np.expand_dims(enorm, 0)

    if (mel_filters.max(axis=0) == 0.0).any():
        # 如果有至少一个梅尔滤波器全部为零,则发出警告
        warnings.warn(
            "At least one mel filter has all zero values. "
            f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
            f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
        )

    # 返回三角形滤波器组成的矩阵,用于从频谱图到梅尔频谱图的投影
    return mel_filters
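
A typical call (illustrative values, similar to how the library's speech feature extractors build their mel filters):

```python
from transformers.audio_utils import mel_filter_bank

mel_filters = mel_filter_bank(
    num_frequency_bins=257,     # e.g. fft_length=512 -> 512 // 2 + 1 frequency bins
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=8000.0,
    sampling_rate=16000,
    norm="slaney",
    mel_scale="slaney",
)
print(mel_filters.shape)        # (257, 80): one column of triangular weights per mel filter
```
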
def optimal_fft_length(window_length: int) -> int:
    """
    Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
    already a power of two, rounds it up to the next power or two.

    The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
    of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
    is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
    it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
    """
    # 计算大于等于 `window_length` 的最小的 2 的幂次方数
    return 2 ** int(np.ceil(np.log2(window_length)))
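
Two quick examples of the rounding behaviour (illustrative):

```python
from transformers.audio_utils import optimal_fft_length

print(optimal_fft_length(400))   # 512 (rounded up to the next power of two)
print(optimal_fft_length(512))   # 512 (already a power of two)
```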


def window_function(
    window_length: int,
    name: str = "hann",
    periodic: bool = True,
    frame_length: Optional[int] = None,
    center: bool = True,
) -> np.ndarray:
    """
    Returns an array containing the specified window. This window is intended to be used with `stft`.

    The following window types are supported:

        - `"boxcar"`: a rectangular window
        - `"hamming"`: the Hamming window
        - `"hann"`: the Hann window
        - `"povey"`: the Povey window

    Args:
        window_length (`int`):
            The length of the window in samples.
        name (`str`, *optional*, defaults to `"hann"`):
            The name of the window function.
        periodic (`bool`, *optional*, defaults to `True`):
            Whether the window is periodic or symmetric.
        frame_length (`int`, *optional*):
            The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
            than the frame length, so that it will be zero-padded.
        center (`bool`, *optional*, defaults to `True`):
            Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.

    Returns:
        `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
    """
    # 如果 `periodic` 为真,则增加窗口长度以适应周期性需求
    length = window_length + 1 if periodic else window_length

    if name == "boxcar":
        # 返回一个长度为 `length` 的全一数组,即矩形窗口
        window = np.ones(length)
    elif name in ["hamming", "hamming_window"]:
        # 返回一个 Hamming 窗口
        window = np.hamming(length)
    elif name in ["hann", "hann_window"]:
        # 返回一个 Hann 窗口
        window = np.hanning(length)
    elif name in ["povey"]:
        # 返回一个经过幂次变换的 Hann 窗口
        window = np.power(np.hanning(length), 0.85)
    else:
        # 如果窗口类型未知,则抛出错误
        raise ValueError(f"Unknown window function '{name}'")

    if periodic:
        # 如果窗口需要周期性,则移除最后一个元素
        window = window[:-1]

    if frame_length is None:
        # 如果没有提供 `frame_length`,直接返回窗口数组
        return window

    if window_length > frame_length:
        # 如果窗口长度大于 `frame_length`,则抛出错误
        raise ValueError(
            f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
        )

    # 创建一个长度为 `frame_length` 的零数组,并将窗口数组放置到合适的位置
    padded_window = np.zeros(frame_length)
    offset = (frame_length - window_length) // 2 if center else 0
    padded_window[offset : offset + window_length] = window
    return padded_window
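
Example (illustrative): a periodic Hann window of 400 samples, zero-padded and centered inside a 512-sample analysis frame, ready to be passed to `spectrogram`:

```python
from transformers.audio_utils import window_function

win = window_function(window_length=400, name="hann", frame_length=512)
print(win.shape)        # (512,)
print(win[:56].sum())   # 0.0 -> the first (512 - 400) // 2 samples are zero padding
```
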
# TODO This method does not support batching yet as we are mainly focused on inference.
def spectrogram(
    waveform: np.ndarray,
    window: np.ndarray,
    frame_length: int,
    hop_length: int,
    fft_length: Optional[int] = None,
    power: Optional[float] = 1.0,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    preemphasis: Optional[float] = None,
    mel_filters: Optional[np.ndarray] = None,
    mel_floor: float = 1e-10,
    log_mel: Optional[str] = None,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
    remove_dc_offset: Optional[bool] = None,
    dtype: np.dtype = np.float32,
) -> np.ndarray:
    """
    Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.

    This function can create the following kinds of spectrograms:

      - amplitude spectrogram (`power = 1.0`)
      - power spectrogram (`power = 2.0`)
      - complex-valued spectrogram (`power = None`)
      - log spectrogram (use `log_mel` argument)
      - mel spectrogram (provide `mel_filters`)
      - log-mel spectrogram (provide `mel_filters` and `log_mel`)

    How this works:

      1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `hop_length` samples.
      2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
      3. The DFT is taken of each windowed frame.
      4. The results are stacked into a spectrogram.

    We make a distinction between the following "blocks" of sample data, each of which may have different lengths:

      - The analysis frame. This is the size of the time slices that the input waveform is split into.
      - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
      - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.

    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
    typically the next power of two.

    Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
    `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
    can be constructed.

    Returns:
        `np.ndarray` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
        `(num_mel_filters, length)` for a mel spectrogram.
    """
    # Determine the length of the window
    window_length = len(window)

    # If fft_length is not provided, set it equal to frame_length
    if fft_length is None:
        fft_length = frame_length

    # Check if frame_length is greater than fft_length; raise ValueError if true
    if frame_length > fft_length:
        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
    # 检查窗口长度与帧长度是否相等,若不相等则引发数值错误异常
    if window_length != frame_length:
        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")

    # 检查跳跃长度是否小于等于零,若是则引发数值错误异常
    if hop_length <= 0:
        raise ValueError("hop_length must be greater than zero")

    # 检查波形的维度是否为一维,若不是则引发数值错误异常,同时给出其形状信息
    if waveform.ndim != 1:
        raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")

    # 检查波形是否为复数类型对象,若是则引发数值错误异常,因为目前不支持复数类型的波形
    if np.iscomplexobj(waveform):
        raise ValueError("Complex-valued input waveforms are not currently supported")

    # 若功率参数为None且mel滤波器不为None,则引发数值错误异常,指出不支持复数谱图计算的情况
    if power is None and mel_filters is not None:
        raise ValueError(
            "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
            "Specify `power` to fix this issue."
        )

    # 若center参数为True,则在波形两端进行中心填充,填充长度为帧长度的一半
    if center:
        padding = [(int(frame_length // 2), int(frame_length // 2))]
        waveform = np.pad(waveform, padding, mode=pad_mode)

    # 将波形数据类型转换为float64,因为np.fft内部使用float64类型进行计算
    waveform = waveform.astype(np.float64)
    window = window.astype(np.float64)

    # 将波形分割为帧,每帧长度为frame_length,计算帧的数量
    num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))

    # 计算FFT后的频率bins的数量,如果是单边谱则为(fft_length // 2) + 1,否则为fft_length
    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
    # 创建一个空的复数64位数组作为谱图,大小为(num_frames, num_frequency_bins)
    spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)

    # 根据是否单边谱选择FFT函数,rfft比fft更快
    fft_func = np.fft.rfft if onesided else np.fft.fft
    # 创建一个长度为fft_length的零填充数组,用于存储FFT输入数据
    buffer = np.zeros(fft_length)

    timestep = 0
    # 遍历每一帧进行FFT计算
    for frame_idx in range(num_frames):
        # 将波形数据填充到buffer中的前frame_length位置
        buffer[:frame_length] = waveform[timestep : timestep + frame_length]

        # 若remove_dc_offset为True,则移除直流偏移
        if remove_dc_offset:
            buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

        # 若preemphasis参数不为None,则对帧进行预加重处理
        if preemphasis is not None:
            buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
            buffer[0] *= 1 - preemphasis

        # 对帧数据应用窗口函数
        buffer[:frame_length] *= window

        # 计算FFT并将结果存入spectrogram中的当前帧索引位置
        spectrogram[frame_idx] = fft_func(buffer)
        timestep += hop_length

    # 若 power 参数不为 None,则先取幅度再做 power 次幂(使用 ** 操作符比 np.power 更快)
    if power is not None:
        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power

    # 将谱图转置,使得频率bins成为第一维度
    spectrogram = spectrogram.T

    # 若mel_filters参数不为None,则进行mel滤波器的应用,并确保谱图值不低于mel_floor
    if mel_filters is not None:
        spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
    # 检查是否同时指定了 power 和 log_mel 参数
    if power is not None and log_mel is not None:
        # 如果 log_mel 参数为 "log",应用自然对数变换到 spectrogram
        if log_mel == "log":
            spectrogram = np.log(spectrogram)
        # 如果 log_mel 参数为 "log10",应用以 10 为底的对数变换到 spectrogram
        elif log_mel == "log10":
            spectrogram = np.log10(spectrogram)
        # 如果 log_mel 参数为 "dB"
        elif log_mel == "dB":
            # 根据 power 参数选择不同的转换方法
            if power == 1.0:
                spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
            elif power == 2.0:
                spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
            else:
                # 如果 power 参数不为 1.0 或 2.0,则抛出 ValueError 异常
                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
        else:
            # 如果 log_mel 参数不是 "log"、"log10" 或 "dB",则抛出 ValueError 异常
            raise ValueError(f"Unknown log_mel option: {log_mel}")

        # 将 spectrogram 转换为指定的 dtype 类型
        spectrogram = np.asarray(spectrogram, dtype)

    # 返回处理后的 spectrogram
    return spectrogram
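
Putting the pieces together, a log-mel spectrogram can be computed as follows (an illustrative sketch with arbitrary parameters, similar to what the library's audio feature extractors do):

```python
import numpy as np
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

waveform = np.random.randn(16000).astype(np.float32)   # 1 second of dummy audio at 16 kHz
window = window_function(400, "hann")
mel_filters = mel_filter_bank(
    num_frequency_bins=257, num_mel_filters=80, min_frequency=0.0,
    max_frequency=8000.0, sampling_rate=16000, norm="slaney", mel_scale="slaney",
)
log_mel = spectrogram(
    waveform, window, frame_length=400, hop_length=160, fft_length=512,
    power=2.0, mel_filters=mel_filters, log_mel="log10",
)
print(log_mel.shape)    # (80, num_frames)
```
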
# 将功率谱图转换为分贝(dB)刻度。使用基本对数属性以确保数值稳定性。
def power_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-10,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
    logarithm properties for numerical stability.

    The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
    linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
    This means that large variations in energy may not sound all that different if the sound is loud to begin with.
    This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.

    Based on the implementation of `librosa.power_to_db`.

    Args:
        spectrogram (`np.ndarray`):
            The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
        reference (`float`, *optional*, defaults to 1.0):
            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
            the loudest part to 0 dB. Must be greater than zero.
        min_value (`float`, *optional*, defaults to `1e-10`):
            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
            `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
        db_range (`float`, *optional*):
            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.

    Returns:
        `np.ndarray`: the spectrogram in decibels
    """
    # 检查参考值是否小于等于零,如果是,则引发异常
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    # 检查最小值是否小于等于零,如果是,则引发异常
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # 确保参考值不小于最小值
    reference = max(min_value, reference)

    # 将谱图限制在最小值和无上限之间
    spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
    # 计算功率谱图的分贝值
    spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))

    # 如果指定了动态范围(db_range),则进一步限制谱图在指定范围内
    if db_range is not None:
        # 检查动态范围是否小于等于零,如果是,则引发异常
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        # 将谱图限制在最小值和无上限之间
        spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)

    return spectrogram
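
A small worked example of the dB conversion and the clipping behaviour (illustrative):

```python
import numpy as np
from transformers.audio_utils import power_to_db

spec = np.array([1.0, 0.1, 1e-12])
print(power_to_db(spec))                 # approx. [0., -10., -100.]; 1e-12 is clipped to min_value=1e-10
print(power_to_db(spec, db_range=80.0))  # approx. [0., -10., -80.]; dynamic range limited to 80 dB
```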


# 将幅度谱图转换为分贝(dB)刻度。使用基本对数属性以确保数值稳定性。
def amplitude_to_db(
    spectrogram: np.ndarray,
    reference: float = 1.0,
    min_value: float = 1e-5,
    db_range: Optional[float] = None,
) -> np.ndarray:
    """
    Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
    basic logarithm properties for numerical stability.
    """
    # 如果参考值小于等于 0,抛出数值错误异常
    if reference <= 0.0:
        raise ValueError("reference must be greater than zero")
    # 如果最小值小于等于 0,抛出数值错误异常
    if min_value <= 0.0:
        raise ValueError("min_value must be greater than zero")

    # 确保参考值不小于最小值
    reference = max(min_value, reference)

    # 将频谱图限制在 [min_value, None] 的范围内,避免取 log(0)
    spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
    # 将振幅转换为分贝值:20 * log10(spectrogram / reference)
    spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))

    # 如果提供了 db_range 参数,则将动态范围限制在 db_range 分贝之内
    if db_range is not None:
        # 如果 db_range 小于等于 0,抛出数值错误异常
        if db_range <= 0.0:
            raise ValueError("db_range must be greater than zero")
        spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)

    # 返回转换后的分贝频谱图
    return spectrogram

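
Since `20 * log10(sqrt(x))` equals `10 * log10(x)`, the amplitude and power variants agree on consistent inputs (the default `min_value` of `1e-5` corresponds to `1e-10` after squaring). An illustrative check:

```python
import numpy as np
from transformers.audio_utils import amplitude_to_db, power_to_db

power_spec = np.array([1.0, 0.25, 1e-12])
amp_spec = np.sqrt(power_spec)
print(np.allclose(amplitude_to_db(amp_spec), power_to_db(power_spec)))  # True
```
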
### deprecated functions below this line ###

# 警告:此函数已弃用,将在 Transformers 版本 4.31.0 中移除
def get_mel_filter_banks(
    nb_frequency_bins: int,
    nb_mel_filters: int,
    frequency_min: float,
    frequency_max: float,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> np.array:
    # 发出未来警告,提醒函数即将被移除
    warnings.warn(
        "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    # 调用 mel_filter_bank 函数,返回梅尔滤波器组
    return mel_filter_bank(
        num_frequency_bins=nb_frequency_bins,
        num_mel_filters=nb_mel_filters,
        min_frequency=frequency_min,
        max_frequency=frequency_max,
        sampling_rate=sample_rate,
        norm=norm,
        mel_scale=mel_scale,
    )


def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
    """
    为了计算短时傅里叶变换,需要将波形分割成重叠的窗口化片段,称为“帧”。

    Args:
        waveform (`np.array` of shape `(sample_length,)`):
            将被分割成较小块的原始波形。
        hop_length (`int`, *optional*, defaults to 160):
            波形的每个窗口之间的步长。
        fft_window_size (`int`, *optional*, defaults to 400):
            窗口的大小。
        center (`bool`, defaults to `True`):
            是否将每个帧居中于帧的中间。居中通过在左右两侧反射波形来实现。

    Return:
        framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
            可供 `np.fft` 使用的帧化波形。
    """
    # 发出未来警告,提醒函数即将被移除
    warnings.warn(
        "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    # 初始化帧列表
    frames = []
    # 对波形数据进行帧分割,每一帧作为一个数据片段进行处理
    for i in range(0, waveform.shape[0] + 1, hop_length):
        # 如果指定居中处理
        if center:
            # 计算帧的一半窗口大小
            half_window = (fft_window_size - 1) // 2 + 1
            # 计算帧的起始和结束位置
            start = i - half_window if i > half_window else 0
            end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
            # 提取波形中的帧数据
            frame = waveform[start:end]
            # 如果起始位置是0,使用反射填充来扩展帧
            if start == 0:
                padd_width = (-i + half_window, 0)
                frame = np.pad(frame, pad_width=padd_width, mode="reflect")
            # 如果结束位置是波形的末尾,使用反射填充来扩展帧
            elif end == waveform.shape[0]:
                padd_width = (0, (i - waveform.shape[0] + half_window))
                frame = np.pad(frame, pad_width=padd_width, mode="reflect")
    
        # 如果不居中处理
        else:
            # 直接从波形中提取指定大小的帧数据
            frame = waveform[i : i + fft_window_size]
            # 获取帧的宽度
            frame_width = frame.shape[0]
            # 如果帧宽度小于指定的窗口大小,使用常数值填充
            if frame_width < waveform.shape[0]:
                frame = np.lib.pad(
                    frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
                )
        # 将处理好的帧数据添加到帧列表中
        frames.append(frame)
    
    # 将帧列表转换为 numpy 数组形式
    frames = np.stack(frames, 0)
    # 返回所有帧数据的 numpy 数组
    return frames
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
    """
    Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
    as `torch.stft`.

    Args:
        frames (`np.array` of dimension `(num_frames, fft_window_size)`):
            A framed audio signal obtained using `audio_utils.fram_wav`.
        windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`):
            An array representing the function used to reduce amplitude discontinuities at frame boundaries when computing STFT.
            Each frame is multiplied by this windowing function. For details on these discontinuities (Spectral leakage),
            refer to [this tutorial](https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf).
        fft_window_size (`int`, *optional*):
            Size of the window on which the Fourier transform is applied, controlling frequency resolution of the spectrogram.
            Default is `None`, where it defaults to `frame_size`. Increasing `fft_window_size` slows computation but improves resolution.

    Example:

    ```
    >>> from transformers.audio_utils import stft, fram_wave
    >>> import numpy as np

    >>> audio = np.random.rand(50)
    >>> fft_window_size = 10
    >>> hop_length = 2
    >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
    >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
    ```

    Returns:
        spectrogram (`np.ndarray`):
            Spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using STFT algorithm
    """
    warnings.warn(
        "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
        FutureWarning,
    )
    # Determine the frame size from input frames
    frame_size = frames.shape[1]

    # Set fft_window_size to frame_size if not provided
    if fft_window_size is None:
        fft_window_size = frame_size

    # Validate fft_window_size against frame_size
    if fft_window_size < frame_size:
        raise ValueError("FFT size must be greater or equal to the frame size")

    # Calculate the number of FFT bins to store
    nb_frequency_bins = (fft_window_size >> 1) + 1

    # Initialize an empty array for the spectrogram
    spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)

    # Initialize an array for the FFT signal
    fft_signal = np.zeros(fft_window_size)

    # Iterate over frames and compute STFT
    for f, frame in enumerate(frames):
        # Apply windowing function to the frame if provided
        if windowing_function is not None:
            np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
        else:
            fft_signal[:frame_size] = frame
        
        # Compute FFT and store in the spectrogram array
        spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]

    # Transpose the spectrogram and return
    return spectrogram.T

.\benchmark\benchmark.py

# coding=utf-8
# 声明文件编码格式为 UTF-8

# 版权声明,版权归 The HuggingFace Inc. 团队所有
# 版权归 NVIDIA 公司所有,保留所有权利

# 根据 Apache 许可证 2.0 版本,除非符合许可证的要求,否则不得使用此文件
# 可以在以下网址获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0

# 如果适用法律要求或书面同意,本软件按 "原样" 分发,不提供任何明示或暗示的保证或条件
# 请参阅许可证,了解详细的法律规定

"""
    在 PyTorch 中对库进行推理和训练的基准测试。
"""

# 导入计时模块
import timeit
# 导入类型提示模块
from typing import Callable, Optional

# 导入配置工具模块
from ..configuration_utils import PretrainedConfig
# 导入模型映射和带语言模型头部的模型映射
from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
# 导入工具模块,包括检测 Py3nvml 和 Torch 是否可用,以及日志记录
from ..utils import is_py3nvml_available, is_torch_available, logging
# 导入基准测试工具模块,包括内存、内存摘要、CPU 最高内存测量、内存跟踪等
from .benchmark_utils import (
    Benchmark,
    Memory,
    MemorySummary,
    measure_peak_memory_cpu,
    start_memory_tracing,
    stop_memory_tracing,
)

# 如果 Torch 可用
if is_torch_available():
    # 导入 Torch 模块
    import torch

    # 导入 PyTorch 基准测试参数
    from .benchmark_args import PyTorchBenchmarkArguments

# 如果 Py3nvml 可用
if is_py3nvml_available():
    # 导入 Py3nvml 模块
    import py3nvml.py3nvml as nvml

# 获取日志记录器
logger = logging.get_logger(__name__)

# 定义 PyTorch 基准测试类,继承自 Benchmark 类
class PyTorchBenchmark(Benchmark):
    # 声明 PyTorch 基准测试类的参数
    args: PyTorchBenchmarkArguments
    # 声明预训练配置
    configs: PretrainedConfig
    # 框架名称为 PyTorch
    framework: str = "PyTorch"

    # 框架版本属性,返回 Torch 的版本号
    @property
    def framework_version(self):
        return torch.__version__

    # 推理速度方法,返回推理的速度
    def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        # 准备推理函数
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        # 测量推理速度
        return self._measure_speed(_inference)

    # 推理内存方法,返回内存占用
    def _inference_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        # 准备推理函数
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        # 测量内存占用
        return self._measure_memory(_inference)

    # 训练速度方法,返回训练的速度
    def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        # 准备训练函数
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        # 测量训练速度
        return self._measure_speed(_train)

    # 训练内存方法,返回内存占用
    def _train_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        # 准备训练函数
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        # 测量内存占用
        return self._measure_memory(_train)
    # 定义一个方法,用于准备推理函数,接受模型名称、批大小和序列长度作为参数,并返回一个无参数的回调函数
    def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        # 从配置字典中获取指定模型名称的配置信息
        config = self.config_dict[model_name]

        # 如果设置了 torchscript 标志,则将配置中的 torchscript 属性设置为 True
        if self.args.torchscript:
            config.torchscript = True

        # 检查配置中是否包含模型类信息,并且列表不为空
        has_model_class_in_config = (
            hasattr(config, "architectures")
            and isinstance(config.architectures, list)
            and len(config.architectures) > 0
        )

        # 如果不仅仅是预训练模型且配置中包含模型类信息,则尝试实例化指定的模型类
        if not self.args.only_pretrain_model and has_model_class_in_config:
            try:
                # 获取配置中的第一个模型类名称
                model_class = config.architectures[0]
                # 动态导入 transformers 模块,并从中获取指定的模型类
                transformers_module = __import__("transformers", fromlist=[model_class])
                model_cls = getattr(transformers_module, model_class)
                # 使用模型类和配置信息实例化模型
                model = model_cls(config)
            except ImportError:
                # 抛出 ImportError 如果指定的模型类不存在
                raise ImportError(
                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
                    " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
                )
        else:
            # 根据配置中的类信息从 MODEL_MAPPING 中选择相应的模型,并实例化
            model = MODEL_MAPPING[config.__class__](config)

        # 将模型设置为评估模式
        model.eval()
        # 将模型移动到指定的设备上(GPU 或 CPU)
        model.to(self.args.device)

        # 对于 encoder-decoder 模型,词汇表大小可能会以不同方式保存
        vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
        # 创建一个随机的输入张量 input_ids,形状为 (batch_size, sequence_length),数据类型为长整型,放置在指定的设备上
        input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)

        # 如果设置了 fp16 标志,则将模型转换为半精度浮点数运行
        if self.args.fp16:
            logger.info("Running training in Mixed Precision...")
            if not self.args.is_gpu:
                # 如果不是 GPU,抛出 ValueError,因为混合精度计算只支持 GPU
                raise ValueError("Mixed precision is possible only for GPU.")
            # 将模型转换为半精度浮点数
            model.half()

        # 如果设置了 torchscript 标志,则使用 torch.jit.trace 对模型进行跟踪
        if self.args.torchscript:
            with torch.no_grad():
                inference_model = torch.jit.trace(model, input_ids)
        else:
            # 否则,直接使用原始模型
            inference_model = model

        # 定义 encoder-decoder 模型和 encoder 模型的前向推理函数
        def encoder_decoder_forward():
            with torch.no_grad():
                # 对输入数据 input_ids 进行推理,同时提供 decoder_input_ids 作为输入
                outputs = inference_model(input_ids, decoder_input_ids=input_ids)
            return outputs

        def encoder_forward():
            with torch.no_grad():
                # 对输入数据 input_ids 进行推理
                outputs = inference_model(input_ids)
            return outputs

        # 根据配置信息中是否为 encoder-decoder 模型选择对应的推理函数
        _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
        return _forward
    def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        # 获取指定模型名称对应的配置信息
        config = self.config_dict[model_name]

        # 检查配置中是否包含模型类信息
        has_model_class_in_config = (
            hasattr(config, "architectures")
            and isinstance(config.architectures, list)
            and len(config.architectures) > 0
        )

        # 如果不仅仅是预训练模型,并且配置中包含模型类信息
        if not self.args.only_pretrain_model and has_model_class_in_config:
            try:
                # 从配置中获取模型类名
                model_class = config.architectures[0]
                # 动态导入 transformers 模块中的模型类
                transformers_module = __import__("transformers", fromlist=[model_class])
                model_cls = getattr(transformers_module, model_class)
                # 使用配置创建模型实例
                model = model_cls(config)
            except ImportError:
                # 抛出导入错误,指示模型类不存在
                raise ImportError(
                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
                    " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
                )
        else:
            # 如果仅使用预定义的语言模型头部映射来创建模型
            model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)

        # 如果设置了 torchscript,目前还未实现 torchscript 的训练
        if self.args.torchscript:
            raise NotImplementedError("Training for torchscript is currently not implemented")
        else:
            # 否则直接使用原始模型进行训练
            train_model = model

        # 将模型设置为训练模式,并移动到指定的设备(GPU 或 CPU)
        model.train()
        model.to(self.args.device)

        # 对于 encoder-decoder 模型,词汇表大小可能以不同方式保存
        vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
        # 生成随机输入 ID,形状为 (batch_size, sequence_length),数据类型为 long,放置在指定设备上
        input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)

        # 如果启用混合精度训练
        if self.args.fp16:
            logger.info("Running training in Mixed Precision...")
            if not self.args.is_gpu:
                # 如果不是 GPU,不能使用混合精度
                raise ValueError("Mixed precision is possible only for GPU.")

            # 使用半精度浮点数进行训练,以减少内存使用
            model.half()

        # 定义计算损失和反向传播的函数,针对 encoder 模型
        def compute_loss_and_backprob_encoder():
            loss = train_model(input_ids, labels=input_ids)[0]
            loss.backward()
            return loss

        # 定义计算损失和反向传播的函数,针对 encoder-decoder 模型
        def compute_loss_and_backprob_encoder_decoder():
            loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
            loss.backward()
            return loss

        # 根据配置是否为 encoder-decoder 模型,选择不同的训练函数
        _train = (
            compute_loss_and_backprob_encoder_decoder
            if config.is_encoder_decoder
            else compute_loss_and_backprob_encoder
        )
        return _train
    # 定义一个方法,用于测量函数执行速度,返回一个浮点数表示执行时间
    def _measure_speed(self, func) -> float:
        try:
            if self.args.is_tpu or self.args.torchscript:
                # 如果使用 TPU 或者需要 torchscript 编译,先额外运行 5 次以稳定编译过程
                logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
                timeit.repeat(
                    func,
                    repeat=1,
                    number=5,
                )

            # 根据参数设定重复运行 func 函数,记录运行时间
            runtimes = timeit.repeat(
                func,
                repeat=self.args.repeat,
                number=10,
            )

            # 如果使用 TPU 并且开启了 torch_xla_tpu_print_metrics,则打印性能指标
            if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
                import torch_xla.debug.metrics as met

                self.print_fn(met.metrics_report())

            # 返回最小运行时间除以10的结果,以获得平均每次运行的时间
            return min(runtimes) / 10.0
        except RuntimeError as e:
            # 如果运行时出现异常,打印错误信息并返回 "N/A"
            self.print_fn(f"Doesn't fit on GPU. {e}")
            return "N/A"
    # 定义一个方法 `_measure_memory`,接收一个不接受参数并不返回任何内容的函数作为参数,
    # 返回一个元组,包含一个 `Memory` 对象和一个 `MemorySummary` 对象
    def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
        try:
            # 如果设置了逐行追踪内存使用情况
            if self.args.trace_memory_line_by_line:
                # 启动以 `transformers` 为标识的内存追踪
                trace = start_memory_tracing("transformers")

            # 如果程序运行在 TPU 上
            if self.args.is_tpu:
                # 抛出未实现错误,因为目前尚未实现 TPU 的内存基准测试
                raise NotImplementedError(
                    "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
                    " `--no-memory` or `args.memory=False`"
                )
            # 如果程序运行在 GPU 上
            elif self.args.is_gpu:
                # 如果没有安装 py3nvml 库
                if not is_py3nvml_available():
                    # 发出警告,提示未安装 py3nvml 库,无法记录 GPU 内存使用情况
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    # 将 memory 设为字符串 "N/A"
                    memory = "N/A"
                else:
                    # 记录日志,提示正在测量 GPU 设备的总体使用情况
                    logger.info(
                        "Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
                        " on the same GPU."
                    )
                    # 初始化 nvml 库
                    nvml.nvmlInit()
                    # 执行传入的函数 func
                    func()
                    # 获取指定索引的 GPU 设备句柄
                    handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
                    # 获取 GPU 设备的内存信息
                    meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                    # 获取已使用的最大字节数
                    max_bytes_in_use = meminfo.used
                    # 创建 Memory 对象,表示已使用的最大字节数
                    memory = Memory(max_bytes_in_use)
                    # 关闭 nvml 库
                    nvml.nvmlShutdown()
            # 如果程序运行在 CPU 上
            else:
                # 测量 CPU 的峰值内存使用情况
                memory_bytes = measure_peak_memory_cpu(func)
                # 如果 memory_bytes 是整数,则创建 Memory 对象,表示测量到的内存字节数,否则直接使用 memory_bytes
                memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes

            # 如果设置了逐行追踪内存使用情况
            if self.args.trace_memory_line_by_line:
                # 停止内存追踪,并获取追踪结果的汇总信息
                summary = stop_memory_tracing(trace)
            else:
                # 否则,汇总信息设为 None
                summary = None

            # 返回内存对象和汇总信息对象的元组
            return memory, summary
        # 捕获 RuntimeError 异常
        except RuntimeError as e:
            # 打印异常信息,指出 GPU 不适合执行当前任务
            self.print_fn(f"Doesn't fit on GPU. {e}")
            # 返回 "N/A" 表示不适合 GPU 执行
            return "N/A", None

.\benchmark\benchmark_args.py

# 导入必要的库和模块
from dataclasses import dataclass, field
from typing import Tuple

# 导入工具函数和变量
from ..utils import (
    cached_property,
    is_torch_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    logging,
    requires_backends,
)
# 导入基准测试参数类
from .benchmark_args_utils import BenchmarkArguments

# 如果 Torch 可用,则导入 Torch 库
if is_torch_available():
    import torch

# 如果 Torch XLA 可用,则导入 Torch XLA 核心模块
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义 PyTorch 基准测试参数类,继承自 BenchmarkArguments 类
@dataclass
class PyTorchBenchmarkArguments(BenchmarkArguments):
    # 已弃用的参数列表
    deprecated_args = [
        "no_inference",
        "no_cuda",
        "no_tpu",
        "no_speed",
        "no_memory",
        "no_env_print",
        "no_multi_process",
    ]

    def __init__(self, **kwargs):
        """
        This __init__ is here for backward compatibility. Once the deprecated arguments are fully removed, this class can simply be deleted.
        """
        # 遍历所有弃用参数
        for deprecated_arg in self.deprecated_args:
            # 如果 kwargs 中包含该弃用参数
            if deprecated_arg in kwargs:
                # 获取对应的正向参数名
                positive_arg = deprecated_arg[3:]
                # 设置实例属性为相反的值
                setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
                # 记录警告信息
                logger.warning(
                    f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
                    f" {positive_arg}={kwargs[positive_arg]}"
                )

        # 设置 torchscript 属性,如果未提供则使用默认值 False
        self.torchscript = kwargs.pop("torchscript", self.torchscript)
        # 设置 torch_xla_tpu_print_metrics 属性,如果未提供则使用默认值 False
        self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
        # 设置 fp16_opt_level 属性,如果未提供则使用默认值 "O1"
        self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
        
        # 调用父类的构造函数
        super().__init__(**kwargs)

    # 定义缓存属性的装饰器
    @cached_property
    # 设置设备初始化函数,返回一个 torch.device 对象和 GPU 数量
    def _setup_devices(self) -> Tuple["torch.device", int]:
        # 检查是否需要加载 torch 后端
        requires_backends(self, ["torch"])
        # 打印日志,标识正在设置 PyTorch 设备
        logger.info("PyTorch: setting up devices")
        # 如果不使用 CUDA,设备为 CPU,GPU 数量为 0
        if not self.cuda:
            device = torch.device("cpu")
            n_gpu = 0
        # 如果支持 Torch XLA,则设备为 TPU,GPU 数量为 0
        elif is_torch_xla_available():
            device = xm.xla_device()
            n_gpu = 0
        # 如果支持 Torch XPU,则设备为 XPU,获取 XPU 设备数量作为 GPU 数量
        elif is_torch_xpu_available():
            device = torch.device("xpu")
            n_gpu = torch.xpu.device_count()
        # 否则,默认使用 CUDA 设备,根据 CUDA 是否可用获取 GPU 数量
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            n_gpu = torch.cuda.device_count()
        # 返回设备和 GPU 数量的元组
        return device, n_gpu

    @property
    def is_tpu(self):
        # 返回当前是否支持 Torch XLA 并且启用了 TPU
        return is_torch_xla_available() and self.tpu

    @property
    def device_idx(self) -> int:
        # 获取当前 CUDA 设备的索引
        requires_backends(self, ["torch"])
        # TODO(PVP): 目前仅支持单 GPU
        return torch.cuda.current_device()

    @property
    def device(self) -> "torch.device":
        # 返回设置的主设备对象(torch.device)
        requires_backends(self, ["torch"])
        return self._setup_devices[0]

    @property
    def n_gpu(self):
        # 返回设置的 GPU 数量
        requires_backends(self, ["torch"])
        return self._setup_devices[1]

    @property
    def is_gpu(self):
        # 返回是否至少有一个 GPU 可用
        return self.n_gpu > 0
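A minimal usage sketch for these arguments (model name, batch sizes and sequence lengths below are arbitrary examples); the `PyTorchBenchmark` class that consumes them is defined in `benchmark.py` above:

# --- illustrative sketch, not part of the source file ---
from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["google-bert/bert-base-uncased"],    # example checkpoint
    batch_sizes=[8],
    sequence_lengths=[8, 32, 128],
    training=False,                              # benchmark inference only
)
benchmark = PyTorchBenchmark(args)
results = benchmark.run()                        # prints speed / memory tables and returns a BenchmarkOutput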

.\benchmark\benchmark_args_tf.py

# 设置编码格式为 UTF-8
# 版权声明,版权归 HuggingFace Inc. 团队和 NVIDIA CORPORATION 所有
# 根据 Apache License, Version 2.0 许可证使用本文件,除非符合许可证的条款,否则不得使用本文件
# 获取许可证的副本,请访问 http://www.apache.org/licenses/LICENSE-2.0
# 根据适用法律要求或书面同意,本软件按“原样”分发,无任何明示或暗示的担保或条件
# 有关许可证的详细信息,请参阅许可证文档

# 导入必要的模块和库
from dataclasses import dataclass, field  # 导入 dataclass 类型和 field 函数
from typing import Tuple  # 导入 Tuple 类型

# 从自定义的 utils 模块中导入 cached_property, is_tf_available, logging, requires_backends 函数
from ..utils import cached_property, is_tf_available, logging, requires_backends
# 从 benchmark_args_utils 模块中导入 BenchmarkArguments 类
from .benchmark_args_utils import BenchmarkArguments

# 如果 TensorFlow 可用,导入 TensorFlow 模块
if is_tf_available():
    import tensorflow as tf

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义 TensorFlowBenchmarkArguments 类,继承自 BenchmarkArguments 类
@dataclass
class TensorFlowBenchmarkArguments(BenchmarkArguments):
    # 已弃用的参数列表
    deprecated_args = [
        "no_inference",
        "no_cuda",
        "no_tpu",
        "no_speed",
        "no_memory",
        "no_env_print",
        "no_multi_process",
    ]

    def __init__(self, **kwargs):
        """
        This initializer handles the deprecated arguments for backward compatibility. Once they are fully removed, this method and the corresponding code can be deleted.
        """
        # 遍历已弃用的参数列表
        for deprecated_arg in self.deprecated_args:
            # 如果传入的参数中包含已弃用的参数
            if deprecated_arg in kwargs:
                # 根据约定将参数名处理为正向的名称
                positive_arg = deprecated_arg[3:]
                # 将已弃用参数的值设置为相反的值,并移除原有的已弃用参数
                kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
                # 记录警告日志,提示用户使用正确的参数或标志
                logger.warning(
                    f"{deprecated_arg} is deprecated. Please use --no-{positive_arg} or "
                    f"{positive_arg}={kwargs[positive_arg]}"
                )
        
        # 将 TPU 名称从传入的参数中提取出来,如果不存在则使用默认值
        self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
        # 将设备索引号从传入的参数中提取出来,如果不存在则使用默认值
        self.device_idx = kwargs.pop("device_idx", self.device_idx)
        # 将 eager 模式标志从传入的参数中提取出来,如果不存在则使用默认值
        self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
        # 将使用 XLA 编译的标志从传入的参数中提取出来,如果不存在则使用默认值
        self.use_xla = kwargs.pop("use_xla", self.use_xla)
        
        # 调用父类 BenchmarkArguments 的初始化方法,传入剩余的参数
        super().__init__(**kwargs)

    # TPU 名称,支持使用帮助文档
    tpu_name: str = field(
        default=None,
        metadata={"help": "Name of TPU"},
    )
    
    # 设备索引号,默认为 0,支持使用帮助文档
    device_idx: int = field(
        default=0,
        metadata={"help": "CPU / GPU device index. Defaults to 0."},
    )
    
    # 是否启用 eager 模式的标志,默认为 False,支持使用帮助文档
    eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager model."})
    
    # 是否使用 XLA JIT 编译的标志,默认为 False,支持使用帮助文档
    use_xla: bool = field(
        default=False,
        metadata={
            "help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`."
        },
    )

    @cached_property
    # 设置用于处理 TPU 的函数,返回一个 TPUClusterResolver 对象或 None
    def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
        # 要求当前对象支持 TensorFlow 后端
        requires_backends(self, ["tf"])
        tpu = None
        # 如果已经配置了 TPU
        if self.tpu:
            try:
                # 如果指定了 TPU 名称,使用指定名称创建 TPUClusterResolver 对象
                if self.tpu_name:
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
                # 否则创建默认 TPUClusterResolver 对象
                else:
                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
            except ValueError:
                tpu = None
        return tpu

    # 设置分布策略的缓存属性,返回一个包含策略和 TPUClusterResolver 对象的元组
    @cached_property
    def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
        # 要求当前对象支持 TensorFlow 后端
        requires_backends(self, ["tf"])
        # 如果是 TPU 环境
        if self.is_tpu:
            # 连接到 TPU 集群
            tf.config.experimental_connect_to_cluster(self._setup_tpu)
            # 初始化 TPU 系统
            tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)

            # 创建 TPUStrategy 对象
            strategy = tf.distribute.TPUStrategy(self._setup_tpu)
        else:
            # 当前不允许多 GPU 情况
            if self.is_gpu:
                # TODO: 目前仅支持单 GPU
                # 设置可见的 GPU 设备为指定索引的 GPU
                tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
                # 创建 OneDeviceStrategy 对象,指定设备为指定索引的 GPU
                strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
            else:
                # 禁用 GPU,设置可见的设备为空列表
                tf.config.set_visible_devices([], "GPU")
                # 创建 OneDeviceStrategy 对象,指定设备为指定索引的 CPU
                strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")

        return strategy

    # 返回当前是否配置了 TPU
    @property
    def is_tpu(self) -> bool:
        # 要求当前对象支持 TensorFlow 后端
        requires_backends(self, ["tf"])
        return self._setup_tpu is not None

    # 返回当前的分布策略
    @property
    def strategy(self) -> "tf.distribute.Strategy":
        # 要求当前对象支持 TensorFlow 后端
        requires_backends(self, ["tf"])
        return self._setup_strategy

    # 返回当前可用的 GPU 列表
    @property
    def gpu_list(self):
        # 要求当前对象支持 TensorFlow 后端
        requires_backends(self, ["tf"])
        return tf.config.list_physical_devices("GPU")

    # 返回当前可用的 GPU 数量
    @property
    def n_gpu(self) -> int:
        # 要求当前对象支持 TensorFlow 后端
        requires_backends(self, ["tf"])
        # 如果支持 CUDA,则返回 GPU 列表的长度
        if self.cuda:
            return len(self.gpu_list)
        # 否则返回 0
        return 0

    # 返回当前是否配置了 GPU
    @property
    def is_gpu(self) -> bool:
        return self.n_gpu > 0
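As with the PyTorch variant, a minimal usage sketch (model name and sizes are arbitrary examples; `TensorFlowBenchmark` is defined in `benchmark_tf.py` below):

# --- illustrative sketch, not part of the source file ---
from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments

args = TensorFlowBenchmarkArguments(
    models=["google-bert/bert-base-uncased"],    # example checkpoint
    batch_sizes=[8],
    sequence_lengths=[8, 32],
    eager_mode=False,                            # graph mode; required for use_xla and for training
)
benchmark = TensorFlowBenchmark(args)
results = benchmark.run()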

.\benchmark\benchmark_args_utils.py

# 导入必要的模块和库
import dataclasses  # 导入用于定义数据类的模块
import json  # 导入处理 JSON 数据的模块
import warnings  # 导入警告处理模块
from dataclasses import dataclass, field  # 从 dataclasses 模块导入 dataclass 装饰器和 field 函数
from time import time  # 从 time 模块导入 time 函数
from typing import List  # 导入 List 类型提示

from ..utils import logging  # 导入相对路径的 logging 模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def list_field(default=None, metadata=None):
    # 返回一个数据类 field,用于处理列表类型的字段
    return field(default_factory=lambda: default, metadata=metadata)
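`list_field` wraps the default in a `default_factory` because dataclasses reject mutable defaults such as plain lists; a minimal sketch of the difference (illustrative only, the `Sizes` class is made up):

# --- illustrative sketch, not part of the source file ---
from dataclasses import dataclass, field

@dataclass
class Sizes:
    # batch_sizes: list = [8]                               # would raise: mutable default not allowed
    batch_sizes: list = field(default_factory=lambda: [8])  # same trick as list_field

print(Sizes().batch_sizes)  # [8], and every instance gets its own fresh list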


@dataclass
class BenchmarkArguments:
    """
    BenchMarkArguments are arguments we use in our benchmark scripts **which relate to the training loop itself**.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
    line.
    """

    models: List[str] = list_field(
        default=[],
        metadata={
            "help": (
                "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version"
                " of all available models"
            )
        },
    )

    batch_sizes: List[int] = list_field(
        default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
    )

    sequence_lengths: List[int] = list_field(
        default=[8, 32, 128, 512],
        metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
    )

    inference: bool = field(
        default=True,
        metadata={"help": "Whether to benchmark inference of model. Inference can be disabled via --no-inference."},
    )
    cuda: bool = field(
        default=True,
        metadata={"help": "Whether to run on available cuda devices. Cuda can be disabled via --no-cuda."},
    )
    tpu: bool = field(
        default=True, metadata={"help": "Whether to run on available tpu devices. TPU can be disabled via --no-tpu."}
    )
    fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
    training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
    verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
    speed: bool = field(
        default=True,
        metadata={"help": "Whether to perform speed measurements. Speed measurements can be disabled via --no-speed."},
    )
    memory: bool = field(
        default=True,
        metadata={
            "help": "Whether to perform memory measurements. Memory measurements can be disabled via --no-memory"
        },
    )
    # 设置一个布尔类型的字段,用于指示是否进行内存测量,可以通过 --no-memory 参数禁用内存测量

    trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
    # 设置一个布尔类型的字段,用于指示是否逐行跟踪内存使用情况

    save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
    # 设置一个布尔类型的字段,用于指示是否将结果保存到 CSV 文件中

    log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
    # 设置一个布尔类型的字段,用于指示是否将所有的打印语句保存到日志文件中

    env_print: bool = field(default=False, metadata={"help": "Whether to print environment information"})
    # 设置一个布尔类型的字段,用于指示是否打印环境信息

    multi_process: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use"
                " multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled"
                " for debugging / testing and on TPU."
            )
        },
    )
    # 设置一个布尔类型的字段,用于指示是否使用多进程进行内存和速度测量,建议用于准确的 CPU 和 GPU 内存测量,仅在调试/测试和使用 TPU 时禁用此选项

    inference_time_csv_file: str = field(
        default=f"inference_time_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving time results to csv."},
    )
    # 设置一个字符串类型的字段,指定保存推理时间结果的 CSV 文件名

    inference_memory_csv_file: str = field(
        default=f"inference_memory_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving memory results to csv."},
    )
    # 设置一个字符串类型的字段,指定保存推理内存结果的 CSV 文件名

    train_time_csv_file: str = field(
        default=f"train_time_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving time results to csv for training."},
    )
    # 设置一个字符串类型的字段,指定保存训练时间结果的 CSV 文件名

    train_memory_csv_file: str = field(
        default=f"train_memory_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving memory results to csv for training."},
    )
    # 设置一个字符串类型的字段,指定保存训练内存结果的 CSV 文件名

    env_info_csv_file: str = field(
        default=f"env_info_{round(time())}.csv",
        metadata={"help": "CSV filename used if saving environment information."},
    )
    # 设置一个字符串类型的字段,指定保存环境信息的 CSV 文件名

    log_filename: str = field(
        default=f"log_{round(time())}.csv",
        metadata={"help": "Log filename used if print statements are saved in log."},
    )
    # 设置一个字符串类型的字段,指定保存打印语句的日志文件名

    repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
    # 设置一个整数类型的字段,指定实验运行的次数

    only_pretrain_model: bool = field(
        default=False,
        metadata={
            "help": (
                "Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain"
                " model weights."
            )
        },
    )
    # 设置一个布尔类型的字段,用于指示是否仅加载预训练模型权重,而不加载 config.architectures 中定义的模型结构

    def __post_init__(self):
        warnings.warn(
            f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
            " are deprecated in general and it is advised to use external Benchmarking libraries "
            " to benchmark Transformer models.",
            FutureWarning,
        )
        # 初始化方法,发出警告提示类已过时,建议使用外部基准库对 Transformer 模型进行基准测试

    def to_json_string(self):
        """
        Serializes this instance to a JSON string.
        """
        return json.dumps(dataclasses.asdict(self), indent=2)
        # 将当前实例序列化为 JSON 字符串的方法
    @property
    # Return the list of model names; raise a ValueError if the list is empty
    def model_names(self) -> List[str]:
        # 检查模型列表是否为空
        if len(self.models) <= 0:
            # 如果为空,抛出值错误异常,提醒用户至少提供一个模型名称或模型标识符
            raise ValueError(
                "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
                " google-bert/bert-base-cased` or `args.models = ['google-bert/bert-base-cased']."
            )
        # 返回模型名称列表
        return self.models

    @property
    # 返回布尔值,指示是否进行多进程处理
    def do_multi_processing(self):
        # 如果不使用多进程,则返回 False
        if not self.multi_process:
            return False
        # 如果使用 TPU,则记录信息并返回 False,因为目前不支持在 TPU 上进行多进程处理
        elif self.is_tpu:
            logger.info("Multiprocessing is currently not possible on TPU.")
            return False
        else:
            # 否则返回 True,表示可以进行多进程处理
            return True
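As the class docstring above notes, `BenchmarkArguments` is designed to be fed to `HfArgumentParser`; a minimal command-line sketch (the script name `run_benchmark.py` is a hypothetical example):

# --- illustrative sketch, not part of the source file ---
from transformers import HfArgumentParser
from transformers.benchmark.benchmark_args_utils import BenchmarkArguments

parser = HfArgumentParser(BenchmarkArguments)
# e.g. invoked as: python run_benchmark.py --models google-bert/bert-base-cased --batch_sizes 1 8
args = parser.parse_args_into_dataclasses()[0]
print(args.model_names)        # raises ValueError if --models was not given
print(args.to_json_string())   # full configuration serialized as JSON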

.\benchmark\benchmark_tf.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
    Benchmarking the library on inference and training in TensorFlow.
"""
# This module benchmarks inference and training of the library's models in TensorFlow

import random  # 导入随机数模块
import timeit  # 导入计时模块
from functools import wraps  # 导入 wraps 装饰器
from typing import Callable, Optional  # 导入类型提示

from ..configuration_utils import PretrainedConfig  # 导入预训练配置
from ..models.auto.modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING  # 导入 TensorFlow 模型映射
from ..utils import is_py3nvml_available, is_tf_available, logging  # 导入工具函数和日志模块
from .benchmark_utils import (  # 导入性能基准测试相关工具
    Benchmark,
    Memory,
    MemorySummary,
    measure_peak_memory_cpu,
    start_memory_tracing,
    stop_memory_tracing,
)

# 如果 TensorFlow 可用,则导入 TensorFlow 模块和相关错误类
if is_tf_available():
    import tensorflow as tf
    from tensorflow.python.framework.errors_impl import ResourceExhaustedError

    from .benchmark_args_tf import TensorFlowBenchmarkArguments  # 导入 TensorFlow 的性能基准测试参数

# 如果 py3nvml 可用,则导入 py3nvml 模块
if is_py3nvml_available():
    import py3nvml.py3nvml as nvml

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
    """
    返回一个装饰器函数,根据参数决定以急切模式还是图模式运行 TensorFlow 函数。

    Args:
        do_eager_mode (bool): 是否使用急切执行模式
        use_xla (bool): 是否使用 XLA 加速

    Returns:
        Callable: 装饰器函数,用于在急切模式或图模式下运行给定函数
    """
    def run_func(func):
        @wraps(func)
        def run_in_eager_mode(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        @tf.function(experimental_compile=use_xla)
        def run_in_graph_mode(*args, **kwargs):
            return func(*args, **kwargs)

        if do_eager_mode is True:
            if use_xla is not False:
                raise ValueError(
                    "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
                )
            return run_in_eager_mode
        else:
            return run_in_graph_mode

    return run_func
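A small sketch of how this decorator factory can be used on its own (the toy function below is made up): with `do_eager_mode=False` the call is compiled into a `tf.function`, optionally with XLA, exactly as the benchmark functions further down are.

# --- illustrative sketch, not part of the source file ---
import tensorflow as tf

@run_with_tf_optimizations(do_eager_mode=False, use_xla=True)
def square_sum(x):
    return tf.reduce_sum(x * x)

print(square_sum(tf.constant([1.0, 2.0, 3.0])))  # runs as an XLA-compiled graph function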


def random_input_ids(batch_size: int, sequence_length: int, vocab_size: int) -> ["tf.Tensor"]:
    """
    生成指定形状和范围内随机整数张量作为输入 ID。

    Args:
        batch_size (int): 批量大小
        sequence_length (int): 序列长度
        vocab_size (int): 词汇表大小

    Returns:
        tf.Tensor: 随机整数张量,形状为 (batch_size, sequence_length)
    """
    rng = random.Random()
    values = [rng.randint(0, vocab_size - 1) for i in range(batch_size * sequence_length)]
    return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)


class TensorFlowBenchmark(Benchmark):
    """
    TensorFlow 的性能基准测试类,继承自 Benchmark 类。
    """
    args: TensorFlowBenchmarkArguments  # TensorFlow 的性能基准测试参数
    configs: PretrainedConfig  # 预训练模型的配置
    framework: str = "TensorFlow"  # 框架名称为 TensorFlow

    @property
    def framework_version(self):
        """
        返回当前 TensorFlow 的版本号。

        Returns:
            str: TensorFlow 的版本号字符串
        """
        return tf.__version__  # 返回 TensorFlow 的版本号
    # 计算推理速度的私有方法,返回模型推理速度(每秒推理样本数)
    def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        # 获取设备策略
        strategy = self.args.strategy
        # 如果策略为空,则抛出数值错误异常
        if strategy is None:
            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
        # 准备推理函数
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        # 测量推理函数的速度并返回
        return self._measure_speed(_inference)

    # 计算训练速度的私有方法,返回模型训练速度(每秒训练样本数)
    def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        # 获取设备策略
        strategy = self.args.strategy
        # 如果策略为空,则抛出数值错误异常
        if strategy is None:
            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
        # 准备训练函数
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        # 测量训练函数的速度并返回
        return self._measure_speed(_train)

    # 计算推理内存占用的私有方法,返回模型推理时内存信息
    def _inference_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        # 如果使用 GPU,则设置 GPU 内存增长策略
        if self.args.is_gpu:
            tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
        # 获取设备策略
        strategy = self.args.strategy
        # 如果策略为空,则抛出数值错误异常
        if strategy is None:
            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
        # 准备推理函数
        _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
        # 测量推理函数的内存占用并返回
        return self._measure_memory(_inference)

    # 计算训练内存占用的私有方法,返回模型训练时内存信息
    def _train_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        # 如果使用 GPU,则设置 GPU 内存增长策略
        if self.args.is_gpu:
            tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
        # 获取设备策略
        strategy = self.args.strategy
        # 如果策略为空,则抛出数值错误异常
        if strategy is None:
            raise ValueError("A device strategy has to be initialized before using TensorFlow.")

        # 准备训练函数
        _train = self._prepare_train_func(model_name, batch_size, sequence_length)
        # 测量训练函数的内存占用并返回
        return self._measure_memory(_train)
    # 准备推断函数,用于模型推断
    def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        # 获取指定模型配置
        config = self.config_dict[model_name]

        # 如果启用了混合精度,则抛出未实现错误
        if self.args.fp16:
            raise NotImplementedError("Mixed precision is currently not supported.")

        # 检查配置中是否包含模型类信息
        has_model_class_in_config = (
            hasattr(config, "architectures")
            and isinstance(config.architectures, list)
            and len(config.architectures) > 0
        )
        # 如果不仅仅是预训练模型且配置中有模型类信息,则尝试初始化模型
        if not self.args.only_pretrain_model and has_model_class_in_config:
            try:
                # 构建模型类名,以'TF'开头表示使用TensorFlow模型
                model_class = "TF" + config.architectures[0]
                # 动态导入transformers库中的模型类
                transformers_module = __import__("transformers", fromlist=[model_class])
                model_cls = getattr(transformers_module, model_class)
                # 使用配置初始化模型
                model = model_cls(config)
            except ImportError:
                # 如果导入失败,则抛出导入错误,提示用户设置`--only_pretrain_model`参数测试预训练模型
                raise ImportError(
                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
                    " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
                )
        else:
            # 如果仅仅是预训练模型或配置中没有模型类信息,则使用预定义的映射创建模型
            model = TF_MODEL_MAPPING[config.__class__](config)

        # 对于编码器-解码器模型,vocab_size的保存方式有所不同
        vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
        # 生成随机输入ID,用于模型推断
        input_ids = random_input_ids(batch_size, sequence_length, vocab_size)

        # 定义编码器-解码器模型推断函数,根据是否是编码器-解码器模型选择不同的输入方式
        @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
        def encoder_decoder_forward():
            return model(input_ids, decoder_input_ids=input_ids, training=False)

        # 定义编码器模型推断函数
        @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
        def encoder_forward():
            return model(input_ids, training=False)

        # 根据配置选择推断函数是编码器-解码器推断还是编码器推断
        _inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward

        # 返回选择的推断函数
        return _inference
    # 定义一个私有方法,用于准备训练函数,该函数返回一个无参数的可调用对象
    def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
        # 从配置字典中获取特定模型名称对应的配置
        config = self.config_dict[model_name]

        # 如果参数中设置了 eager_mode 不为 False,抛出数值错误
        if self.args.eager_mode is not False:
            raise ValueError("Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`.")

        # 如果参数中启用了 fp16,抛出未实现错误,暂不支持混合精度训练
        if self.args.fp16:
            raise NotImplementedError("Mixed precision is currently not supported.")

        # 检查配置中是否包含模型类信息
        has_model_class_in_config = (
            hasattr(config, "architectures")  # 检查配置是否包含 architectures 属性
            and isinstance(config.architectures, list)  # architectures 属性是否为列表类型
            and len(config.architectures) > 0  # architectures 列表长度大于 0
        )
        # 如果不仅是预训练模型,并且配置中包含模型类信息,则尝试加载模型类
        if not self.args.only_pretrain_model and has_model_class_in_config:
            try:
                # 构建模型类名称,以 'TF' 开头表示 TensorFlow 模型
                model_class = "TF" + config.architectures[0]
                # 动态导入 transformers 模块中的指定模型类
                transformers_module = __import__("transformers", fromlist=[model_class])
                model_cls = getattr(transformers_module, model_class)
                # 使用模型类和配置创建模型实例
                model = model_cls(config)
            except ImportError:
                # 如果导入失败,抛出导入错误,提醒用户检查模型类是否存在
                raise ImportError(
                    f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
                    " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
                )
        else:
            # 如果仅加载预训练模型或配置中不包含模型类信息,则使用默认的语言模型和配置创建模型
            model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)

        # 对于 encoder-decoder 类型的模型,需要特殊处理词汇表大小的设置
        vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
        # 生成随机的输入 ID,用于模型训练
        input_ids = random_input_ids(batch_size, sequence_length, vocab_size)

        # 定义 encoder-decoder 模型训练函数,根据 eager_mode 和 use_xla 参数优化执行方式
        @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
        def encoder_decoder_train():
            # 计算模型在给定输入下的损失值,并获取损失相对于可训练变量的梯度
            loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
            gradients = tf.gradients(loss, model.trainable_variables)
            return gradients

        # 定义 encoder 模型训练函数,根据 eager_mode 和 use_xla 参数优化执行方式
        @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
        def encoder_train():
            # 计算模型在给定输入下的损失值,并获取损失相对于可训练变量的梯度
            loss = model(input_ids, labels=input_ids, training=True)[0]
            gradients = tf.gradients(loss, model.trainable_variables)
            return gradients

        # 根据模型配置决定返回 encoder-decoder 训练函数还是 encoder 训练函数
        _train = encoder_decoder_train if config.is_encoder_decoder else encoder_train

        return _train
    def _measure_speed(self, func) -> float:
        # 使用给定的策略作用域执行以下代码块
        with self.args.strategy.scope():
            try:
                if self.args.is_tpu or self.args.use_xla:
                    # 如果使用 TPU 或者启用 XLA,则额外运行 5 次以稳定编译过程
                    logger.info("Do inference on TPU. Running model 5 times to stabilize compilation")
                    timeit.repeat(func, repeat=1, number=5)

                # 根据文档建议,使用最小值而非平均值来计算时间
                runtimes = timeit.repeat(
                    func,
                    repeat=self.args.repeat,  # 重复测量次数
                    number=10,  # 每次测量执行的次数
                )

                # 返回最小运行时间的平均值
                return min(runtimes) / 10.0
            except ResourceExhaustedError as e:
                # 如果资源不足错误,则打印相关信息
                self.print_fn(f"Doesn't fit on GPU. {e}")

.\benchmark\benchmark_utils.py

# 导入所需的模块和库
import copy  # 导入 copy 模块,用于对象的复制操作
import csv  # 导入 csv 模块,用于 CSV 文件的读写操作
import linecache  # 导入 linecache 模块,用于按行缓存操作
import os  # 导入 os 模块,提供了与操作系统交互的功能
import platform  # 导入 platform 模块,用于访问平台相关属性
import sys  # 导入 sys 模块,提供了对 Python 解释器的访问
import warnings  # 导入 warnings 模块,用于警告控制
from abc import ABC, abstractmethod  # 从 abc 模块导入 ABC 和 abstractmethod 用于抽象基类定义
from collections import defaultdict, namedtuple  # 导入 defaultdict 和 namedtuple 类型,用于默认值字典和命名元组
from datetime import datetime  # 导入 datetime 类,用于日期时间操作
from multiprocessing import Pipe, Process, Queue  # 导入多进程相关模块,包括 Pipe、Process 和 Queue
from multiprocessing.connection import Connection  # 导入 Connection 类,用于多进程通信
from typing import Callable, Iterable, List, NamedTuple, Optional, Union  # 导入类型提示相关功能

from .. import AutoConfig, PretrainedConfig  # 导入上层目录的 AutoConfig 和 PretrainedConfig 类
from .. import __version__ as version  # 导入版本号
from ..utils import (
    is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available, logging
)  # 从上层目录中导入一些工具函数和变量
from .benchmark_args_utils import BenchmarkArguments  # 从当前目录中导入 BenchmarkArguments 类


if is_torch_available():
    from torch.cuda import empty_cache as torch_empty_cache  # 如果 Torch 可用,导入清空 GPU 缓存函数

if is_tf_available():
    from tensorflow.python.eager import context as tf_context  # 如果 TensorFlow 可用,导入 TensorFlow context

if is_psutil_available():
    import psutil  # 如果 psutil 可用,导入 psutil 模块

if is_py3nvml_available():
    import py3nvml.py3nvml as nvml  # 如果 py3nvml 可用,导入 py3nvml 模块

if platform.system() == "Windows":
    from signal import CTRL_C_EVENT as SIGKILL  # 如果是 Windows 系统,导入 CTRL_C_EVENT 作为 SIGKILL
else:
    from signal import SIGKILL  # 如果是其他系统,导入 SIGKILL 信号

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象,用于日志输出

_is_memory_tracing_enabled = False  # 内存追踪开关,默认关闭

BenchmarkOutput = namedtuple(
    "BenchmarkOutput",
    [
        "time_inference_result",
        "memory_inference_result",
        "time_train_result",
        "memory_train_result",
        "inference_summary",
        "train_summary",
    ],
)  # 定义命名元组 BenchmarkOutput,用于存储基准测试的输出结果


def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: bool) -> Callable[[], None]:
    """
    This function wraps another function into its own separated process. In order to ensure accurate memory
    measurements it is important that the function is executed in a separate process

    Args:
        - `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process
        - `do_multi_processing`: (`bool`) Whether to run function on separate process or not
    """
    # 这个函数将另一个函数包装成自己的独立进程。为了确保精确的内存测量,重要的是函数在独立进程中执行。
    # `func`: 要执行的函数,必须是一个可以在独立进程中运行的通用函数
    # `do_multi_processing`: 是否在独立进程上运行函数的布尔值
    # 定义一个函数,用于在单独的进程中执行给定的函数,并确保内存的正确使用

    def multi_process_func(*args, **kwargs):
        # 定义一个内部函数,用于在单独进程中运行指定的函数,并将结果放入队列
        def wrapper_func(queue: Queue, *args):
            try:
                # 调用传入的函数,获取其结果
                result = func(*args)
            except Exception as e:
                # 捕获异常,记录错误日志并打印异常信息
                logger.error(e)
                print(e)
                # 将结果设置为 "N/A"
                result = "N/A"
            # 将结果放入队列
            queue.put(result)

        # 创建一个队列对象
        queue = Queue()
        # 创建一个进程对象,目标为内部定义的 wrapper_func 函数,传入队列和参数
        p = Process(target=wrapper_func, args=[queue] + list(args))
        # 启动进程
        p.start()
        # 从队列中获取结果
        result = queue.get()
        # 等待进程结束
        p.join()
        # 返回从进程中获取的结果
        return result

    # 如果需要多进程处理
    if do_multi_processing:
        # 记录信息,指示函数将在自己的进程中执行
        logger.info(f"Function {func} is executed in its own process...")
        # 返回多进程处理的函数 multi_process_func
        return multi_process_func
    else:
        # 如果不需要多进程处理,直接返回原始的函数 func
        return func
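A minimal sketch of using this wrapper (the measured function is a made-up placeholder): when `do_multi_processing` is True, the allocation happens in a child process, so it cannot pollute the parent's memory statistics. Because the wrapper spawns a child with a nested target function, the sketch assumes a fork-based platform such as Linux.

# --- illustrative sketch, not part of the source file ---
def allocate_a_lot():
    data = [0] * 10_000_000   # placeholder workload that allocates memory
    return len(data)

wrapped = separate_process_wrapper_fn(allocate_a_lot, do_multi_processing=True)
print(wrapped())              # 10000000, computed in a separate process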
def is_memory_tracing_enabled():
    # 返回全局变量 `_is_memory_tracing_enabled` 的值,表示内存追踪是否启用
    global _is_memory_tracing_enabled
    return _is_memory_tracing_enabled


class Frame(NamedTuple):
    """
    `Frame` 是一个 NamedTuple,用于收集当前帧的状态。`Frame` 有以下字段:

        - 'filename' (string): 当前执行的文件名
        - 'module' (string): 当前执行的模块名
        - 'line_number' (int): 当前执行的行号
        - 'event' (string): 触发追踪的事件(默认为 "line")
        - 'line_text' (string): Python 脚本中行的文本内容
    """

    filename: str
    module: str
    line_number: int
    event: str
    line_text: str


class UsedMemoryState(NamedTuple):
    """
    `UsedMemoryState` 是一个命名元组,具有以下字段:

        - 'frame': 一个 `Frame` 命名元组,存储当前追踪帧的信息(当前文件、当前文件中的位置)
        - 'cpu_memory': 执行该行前的 CPU RSS 内存状态
        - 'gpu_memory': 执行该行前的 GPU 使用内存(所有 GPU 的总和,或者仅限于 `gpus_to_trace` 指定的 GPU)
    """

    frame: Frame
    cpu_memory: int
    gpu_memory: int


class Memory(NamedTuple):
    """
    `Memory` 命名元组只有一个字段 `bytes`,可以通过调用 `__repr__` 方法得到以兆字节为单位的人类可读字符串。

        - `bytes` (integer): 字节数
    """

    bytes: int

    def __repr__(self) -> str:
        return str(bytes_to_mega_bytes(self.bytes))


class MemoryState(NamedTuple):
    """
    `MemoryState` 是一个命名元组,列出了带有以下字段的帧 + CPU/GPU 内存:

        - `frame` (`Frame`): 当前帧(参见上面的定义)
        - `cpu`: 当前帧期间消耗的 CPU 内存,作为 `Memory` 命名元组
        - `gpu`: 当前帧期间消耗的 GPU 内存,作为 `Memory` 命名元组
        - `cpu_gpu`: 当前帧期间消耗的 CPU + GPU 内存,作为 `Memory` 命名元组
    """

    frame: Frame
    cpu: Memory
    gpu: Memory
    cpu_gpu: Memory


class MemorySummary(NamedTuple):
    """
    `MemorySummary` is a NamedTuple that summarizes a memory trace; its fields are described in the comments below.
    """
    # 定义一个名为 `MemorySummary` 的命名元组,包含以下字段:

    # - `sequential`: 从 `memory_trace` 计算而来的 `MemoryState` 命名元组列表,表示每行代码执行前后内存的差值。
    # - `cumulative`: 从 `memory_trace` 计算而来的 `MemoryState` 命名元组列表,表示每行代码的累积内存增加量,
    #   如果某行代码被多次执行,其内存增加量会被多次累加。列表按内存消耗从大到小排序(可能为负数,表示释放内存)。
    # - `current`: 当前内存状态的 `MemoryState` 命名元组列表。
    # - `total`: `Memory` 命名元组,表示完整追踪期间的内存总增加量。如果 `ignore_released_memory` 为 `True`
    #   (默认值),则忽略内存释放(消耗为负数)的行。

    sequential: List[MemoryState]
    cumulative: List[MemoryState]
    current: List[MemoryState]
    total: Memory
MemoryTrace = List[UsedMemoryState]
# 定义了一个类型别名 MemoryTrace,表示一个列表,列表元素是 UsedMemoryState 类型的对象

def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_idx=None) -> int:
    """
    测量给定函数 `function` 的 CPU 内存峰值消耗,运行时间至少 interval 秒,最多 20 * interval 秒。
    此函数受 `memory_profiler` 包中 `memory_usage` 函数的启发:
    https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239

    Args:
        - `function`: (`callable`): 无参数函数,用于测量其内存消耗的函数

        - `interval`: (`float`, `optional`, 默认为 `0.5`): 测量内存使用的时间间隔(秒)

        - `device_idx`: (`int`, `optional`, 默认为 `None`): 要测量 GPU 使用情况的设备 ID

    Returns:
        - `max_memory`: (`int`) 字节单位的内存峰值消耗
    """

    def get_cpu_memory(process_id: int) -> int:
        """
        测量给定 `process_id` 的当前 CPU 内存使用量

        Args:
            - `process_id`: (`int`) 要测量内存的进程 ID

        Returns:
            - `memory`: (`int`) 字节单位的内存消耗
        """
        process = psutil.Process(process_id)
        try:
            # 获取进程内存信息,根据 psutil 版本不同选择不同的方法
            meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
            memory = getattr(process, meminfo_attr)()[0]
        except psutil.AccessDenied:
            raise ValueError("Error with Psutil.")
        return memory

    # Warn if psutil is not installed; without it CPU memory usage cannot be logged
    if not is_psutil_available():
        logger.warning(
            "Psutil not installed, we won't log CPU memory usage. "
            "Install Psutil (pip install psutil) to use CPU memory tracing."
        )
        max_memory = "N/A"
    else:
            # 定义一个继承自 Process 的 MemoryMeasureProcess 类,用于测量进程的内存使用情况
            class MemoryMeasureProcess(Process):

                """
                `MemoryMeasureProcess` inherits from `Process` and overwrites its `run()` method. Used to measure the
                memory usage of a process
                """

                def __init__(self, process_id: int, child_connection: Connection, interval: float):
                    super().__init__()
                    self.process_id = process_id
                    self.interval = interval
                    self.connection = child_connection
                    self.num_measurements = 1
                    self.mem_usage = get_cpu_memory(self.process_id)

                def run(self):
                    # 发送信号给父进程,表示开始测量
                    self.connection.send(0)
                    stop = False
                    while True:
                        # 更新内存使用情况
                        self.mem_usage = max(self.mem_usage, get_cpu_memory(self.process_id))
                        self.num_measurements += 1

                        if stop:
                            break

                        # 检查是否要停止测量
                        stop = self.connection.poll(self.interval)

                    # 将测量结果发送给父进程管道
                    self.connection.send(self.mem_usage)
                    self.connection.send(self.num_measurements)

            while True:
                # 创建子进程与父进程之间的管道
                child_connection, parent_connection = Pipe()

                # 实例化 MemoryMeasureProcess 进程对象
                mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
                mem_process.start()

                # 等待直到收到内存测量的信号
                parent_connection.recv()

                try:
                    # 执行指定的函数
                    function()

                    # 向父进程发送信号,表示执行完毕
                    parent_connection.send(0)

                    # 接收内存使用情况和测量次数
                    max_memory = parent_connection.recv()
                    num_measurements = parent_connection.recv()
                except Exception:
                    # 在一个干净的方式下终止进程
                    parent = psutil.Process(os.getpid())
                    for child in parent.children(recursive=True):
                        os.kill(child.pid, SIGKILL)
                    mem_process.join(0)
                    # 抛出运行时异常,表示进程被终止,有错误发生
                    raise RuntimeError("Process killed. Error in Process")

                # 等待进程至少运行 20 倍的时间间隔或者直到它完成
                mem_process.join(20 * interval)

                # 如果测量次数大于 4 或者间隔小于 1e-6,则跳出循环
                if (num_measurements > 4) or (interval < 1e-6):
                    break

                # 减小时间间隔
                interval /= 10

            # 返回最大内存使用情况
            return max_memory
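A usage sketch (assuming psutil is installed and a fork-based platform, since the monitor process is a class defined inside the function; the workload is a placeholder): the helper returns the peak resident-set size observed while the function runs.

# --- illustrative sketch, not part of the source file ---
def build_big_list():
    data = [0.0] * 20_000_000   # temporarily allocates a few hundred MB
    del data

peak_bytes = measure_peak_memory_cpu(build_big_list, interval=0.5)
print(f"peak CPU memory: {bytes_to_mega_bytes(peak_bytes)} MB")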
# 定义一个函数 `start_memory_tracing`,用于设置内存跟踪,记录模块或子模块每行的 RAM 使用情况。
def start_memory_tracing(
    modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
    modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
    events_to_trace: str = "line",
    gpus_to_trace: Optional[List[int]] = None,
) -> MemoryTrace:
    """
    设置逐行跟踪,记录模块或子模块每行的 RAM 使用情况。详见 `./benchmark.py` 示例。

    Args:
        - `modules_to_trace`: (None, string, list/tuple of string) 如果为 None,则记录所有事件;如果是字符串或字符串列表:仅记录列出的模块/子模块的事件(例如 'fairseq' 或 'transformers.models.gpt2.modeling_gpt2')
        - `modules_not_to_trace`: (None, string, list/tuple of string) 如果为 None,则不避免任何模块;如果是字符串或字符串列表:不记录列出的模块/子模块的事件(例如 'torch')
        - `events_to_trace`: 要记录的事件字符串或事件字符串列表(参见官方 Python 文档的 `sys.settrace` 关于事件的列表),默认为 line
        - `gpus_to_trace`: (可选列表,默认为 None) 要跟踪的 GPU 列表。默认为跟踪所有 GPU

    Return:
        - `memory_trace`: 一个包含每个事件的 `UsedMemoryState` 列表(默认为跟踪脚本的每行)。

            - `UsedMemoryState` 是命名元组,包含以下字段:
                - 'frame': 一个 `Frame` 命名元组(见下文),存储当前追踪帧的信息(当前文件、当前文件中的位置)
                - 'cpu_memory': 执行该行前的 CPU RSS 内存状态
                - 'gpu_memory': 执行该行前的 GPU 使用内存(所有 GPU 的总和或仅对 `gpus_to_trace` 如果提供的 GPU)

    `Frame` 是由 `UsedMemoryState` 使用的命名元组,列出当前帧的状态。`Frame` 具有以下字段:
        - 'filename' (字符串): 当前执行的文件名
        - 'module' (字符串): 当前执行的模块名
        - 'line_number' (整数): 当前执行的行号
        - 'event' (字符串): 触发跟踪的事件(默认为 "line")
        - 'line_text' (字符串): Python 脚本中该行的文本

    """
    # 检查是否安装了 psutil 库
    if is_psutil_available():
        # 获取当前进程的 psutil.Process 对象
        process = psutil.Process(os.getpid())
    else:
        # 如果未安装 psutil,则记录警告信息,并设置 process 为 None
        logger.warning(
            "Psutil not installed, we won't log CPU memory usage. "
            "Install psutil (pip install psutil) to use CPU memory tracing."
        )
        process = None
    # 检查是否可以使用 py3nvml 模块进行 GPU 监控
    if is_py3nvml_available():
        try:
            # 初始化 nvml 库
            nvml.nvmlInit()
            # 如果没有指定具体要追踪的 GPU 列表,则获取所有 GPU 设备的索引列表
            devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
            # 关闭 nvml 库
            nvml.nvmlShutdown()
        # 捕获可能出现的 OSError 或 nvml.NVMLError 异常
        except (OSError, nvml.NVMLError):
            # 输出警告信息,指出初始化与 GPU 的通信时出现错误,因此无法进行 GPU 内存追踪
            logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
            # 禁用 GPU 内存追踪功能
            log_gpu = False
        else:
            # 如果没有异常,则根据条件决定是否记录 GPU 内存使用情况
            log_gpu = is_torch_available() or is_tf_available()
    else:
        # 如果 py3nvml 模块不可用,则输出警告信息,提示用户安装该模块以启用 GPU 内存追踪功能
        logger.warning(
            "py3nvml not installed, we won't log GPU memory usage. "
            "Install py3nvml (pip install py3nvml) to use GPU memory tracing."
        )
        # 禁用 GPU 内存追踪功能
        log_gpu = False

    # 初始化内存追踪列表
    memory_trace = []
    def traceit(frame, event, args):
        """
        定义一个追踪函数,在模块或子模块的每行执行之前执行,记录分配的内存到带有调试信息的列表中
        """
        global _is_memory_tracing_enabled

        # 如果内存追踪未启用,则直接返回 traceit 函数自身
        if not _is_memory_tracing_enabled:
            return traceit

        # 过滤事件类型
        if events_to_trace is not None:
            if isinstance(events_to_trace, str) and event != events_to_trace:
                return traceit
            elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
                return traceit

        # 如果当前 frame 的全局变量中不存在 "__name__",则返回 traceit 函数自身
        if "__name__" not in frame.f_globals:
            return traceit

        # 获取模块名
        name = frame.f_globals["__name__"]
        # 如果模块名不是字符串类型,则返回 traceit 函数自身
        if not isinstance(name, str):
            return traceit
        else:
            # 过滤要追踪的模块白名单
            if modules_to_trace is not None:
                if isinstance(modules_to_trace, str) and modules_to_trace not in name:
                    return traceit
                elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
                    return traceit

            # 过滤不需要追踪的模块黑名单
            if modules_not_to_trace is not None:
                if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
                    return traceit
                elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
                    return traceit

        # 记录当前追踪状态(文件名、文件中的行号等)
        lineno = frame.f_lineno
        filename = frame.f_globals["__file__"]
        # 如果文件名以 ".pyc" 或 ".pyo" 结尾,则去除最后一个字符
        if filename.endswith(".pyc") or filename.endswith(".pyo"):
            filename = filename[:-1]
        # 获取当前行的代码内容,并去除末尾的换行符
        line = linecache.getline(filename, lineno).rstrip()
        # 创建一个 Frame 对象来保存追踪状态信息
        traced_state = Frame(filename, name, lineno, event, line)

        # 记录当前内存状态(进程的 RSS 内存),并计算与先前内存状态的差异
        cpu_mem = 0
        if process is not None:
            mem = process.memory_info()
            cpu_mem = mem.rss

        gpu_mem = 0
        if log_gpu:
            # 清除 GPU 缓存
            if is_torch_available():
                torch_empty_cache()
            if is_tf_available():
                tf_context.context()._clear_caches()  # 参见 https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802

            # 统计所有 GPU 的已使用内存
            nvml.nvmlInit()

            for i in devices:
                handle = nvml.nvmlDeviceGetHandleByIndex(i)
                meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                gpu_mem += meminfo.used

            nvml.nvmlShutdown()

        # 创建一个 UsedMemoryState 对象,记录当前的追踪状态、CPU 内存和 GPU 内存
        mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
        # 将当前内存状态添加到 memory_trace 列表中
        memory_trace.append(mem_state)

        # 返回 traceit 函数自身,以便在每行执行前再次调用
        return traceit

    # 将 traceit 函数设置为系统的追踪函数
    sys.settrace(traceit)
    # 设置全局变量 _is_memory_tracing_enabled 为 True,表示启用内存追踪功能
    global _is_memory_tracing_enabled
    _is_memory_tracing_enabled = True
    
    # Return the memory_trace list, which will be filled in by `traceit` as the traced code executes
    return memory_trace
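The tracing machinery rests entirely on `sys.settrace`: the registered callback fires before every traced line of Python and can inspect the frame, which is exactly what `traceit` does. A stripped-down sketch of the same mechanism (illustrative only; `tiny_tracer` and `demo` are made-up names):

# --- illustrative sketch, not part of the source file ---
import sys

def tiny_tracer(frame, event, arg):
    if event == "line":
        print(frame.f_globals.get("__name__"), frame.f_lineno)
    return tiny_tracer          # keep tracing line events inside this scope

def demo():
    x = 1
    y = x + 1
    return y

sys.settrace(tiny_tracer)
demo()
sys.settrace(None)              # always unregister the tracer when done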
# 停止内存追踪并清理相关设置,如果提供了内存追踪,则返回内存追踪的摘要信息。
def stop_memory_tracing(
    memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
) -> Optional[MemorySummary]:
    """
    停止内存追踪并返回内存追踪的摘要信息。

    Args:
        `memory_trace` (optional output of start_memory_tracing, default: None):
            要转换为摘要的内存追踪
        `ignore_released_memory` (boolean, default: None):
            如果为 True,则仅计算内存增加量以获取总内存

    Return:
        - 如果 `memory_trace` 为 None,则返回 None
        - 否则返回一个 `MemorySummary` 命名元组,包含以下字段:

            - `sequential`:从提供的 `memory_trace` 计算而来的 `MemoryState` 命名元组列表,通过减去每行执行后的内存从而计算出来。
            - `cumulative`:每行的累积内存增加量的 `MemoryState` 命名元组列表,如果一行被多次执行,则累加其内存增加量。列表按内存消耗最大到最小排序(如果内存释放则可能为负数),如果 `ignore_released_memory` 为 True(默认)则忽略释放内存的行。
            - `total`:完整追踪期间的总内存增加量,作为 `Memory` 命名元组。

    `Memory` 命名元组包含以下字段:

        - `byte` (integer): 字节数
        - `string` (string): 人类可读的字符串表示 (例如:"3.5MB")

    `Frame` 是命名元组,用于列出当前帧状态,包含以下字段:

        - 'filename' (string): 当前执行的文件名
        - 'module' (string): 当前执行的模块名
        - 'line_number' (int): 当前执行的行号
        - 'event' (string): 触发追踪的事件(默认为 "line")
        - 'line_text' (string): Python 脚本中行的文本

    `MemoryState` 是命名元组,列出了帧 + CPU/GPU 内存,包含以下字段:

        - `frame` (`Frame`): 当前帧 (参见上文)
        - `cpu`: 当前帧期间消耗的 CPU 内存,作为 `Memory` 命名元组
        - `gpu`: 当前帧期间消耗的 GPU 内存,作为 `Memory` 命名元组
        - `cpu_gpu`: 当前帧期间消耗的 CPU + GPU 内存,作为 `Memory` 命名元组
    """
    global _is_memory_tracing_enabled
    # 禁用内存追踪标志
    _is_memory_tracing_enabled = False
    # 如果内存跟踪不为None且长度大于1,则执行以下操作
    if memory_trace is not None and len(memory_trace) > 1:
        # 初始化存储内存变化的列表和当前内存状态的列表
        memory_diff_trace = []
        memory_curr_trace = []

        # 使用默认字典创建累积内存字典,每个键值对的默认值为[0, 0, 0]
        cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])

        # 遍历内存跟踪列表中每对相邻的帧及其内存状态
        for (
            (frame, cpu_mem, gpu_mem),
            (next_frame, next_cpu_mem, next_gpu_mem),
        ) in zip(memory_trace[:-1], memory_trace[1:]):
            # 计算 CPU 内存增量和 GPU 内存增量
            cpu_mem_inc = next_cpu_mem - cpu_mem
            gpu_mem_inc = next_gpu_mem - gpu_mem
            cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc
            
            # 将帧及其内存增量封装成 MemoryState 对象,加入内存差异追踪列表
            memory_diff_trace.append(
                MemoryState(
                    frame=frame,
                    cpu=Memory(cpu_mem_inc),
                    gpu=Memory(gpu_mem_inc),
                    cpu_gpu=Memory(cpu_gpu_mem_inc),
                )
            )

            # 将帧及其下一个内存状态封装成 MemoryState 对象,加入当前内存追踪列表
            memory_curr_trace.append(
                MemoryState(
                    frame=frame,
                    cpu=Memory(next_cpu_mem),
                    gpu=Memory(next_gpu_mem),
                    cpu_gpu=Memory(next_cpu_mem + next_gpu_mem),
                )
            )

            # 更新累积内存字典中当前帧的累积内存增量
            cumulative_memory_dict[frame][0] += cpu_mem_inc
            cumulative_memory_dict[frame][1] += gpu_mem_inc
            cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc

        # 按照 CPU + GPU 内存增量的总和降序排序累积内存字典
        cumulative_memory = sorted(
            cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
        )

        # 将排序后的累积内存字典转换为 MemoryState 对象列表
        cumulative_memory = [
            MemoryState(
                frame=frame,
                cpu=Memory(cpu_mem_inc),
                gpu=Memory(gpu_mem_inc),
                cpu_gpu=Memory(cpu_gpu_mem_inc),
            )
            for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
        ]

        # 按照当前内存追踪列表中的 CPU + GPU 内存字节数降序排序
        memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)

        # 如果忽略已释放的内存,则计算非负数内存的总和;否则计算所有内存的总和
        if ignore_released_memory:
            total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
        else:
            total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)

        # 将总内存字节数转换为 Memory 对象
        total_memory = Memory(total_memory)

        # 返回内存摘要对象,包括顺序内存追踪、累积内存、当前内存追踪和总内存
        return MemorySummary(
            sequential=memory_diff_trace,
            cumulative=cumulative_memory,
            current=memory_curr_trace,
            total=total_memory,
        )

    # 如果内存跟踪为None或长度不大于1,则返回None
    return None
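Putting the two helpers together, the typical pattern (also used by `_measure_memory` earlier) looks like the sketch below; loading a pretrained model is just an example workload, and the checkpoint name mirrors the one used in the error message above.

# --- illustrative sketch, not part of the source file ---
from transformers import AutoModel

trace = start_memory_tracing("transformers")       # only trace lines inside the transformers package
model = AutoModel.from_pretrained("google-bert/bert-base-cased")
summary = stop_memory_tracing(trace)

print("total memory increase:", summary.total)     # Memory namedtuple, repr in MB
for state in summary.cumulative[:5]:               # five most memory-hungry lines
    print(state.frame.filename, state.frame.line_number, state.cpu_gpu)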
# 定义一个函数,用于将字节数转换为兆字节数
def bytes_to_mega_bytes(memory_amount: int) -> int:
    """Utility to convert a number of bytes (int) into a number of mega bytes (int)"""
    return memory_amount >> 20
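The right shift by 20 is an integer division by 2**20, so the result is in (binary) megabytes, rounded down; for example:

# --- illustrative sketch, not part of the source file ---
assert bytes_to_mega_bytes(1 << 20) == 1        # 1 048 576 bytes -> 1 MB
assert bytes_to_mega_bytes(1_572_864) == 1      # 1.5 MB rounds down to 1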


# 抽象基类 Benchmark,用于比较模型在 Transformers 中的内存和时间性能
class Benchmark(ABC):
    """
    Benchmarks is a simple but feature-complete benchmarking script to compare memory and time performance of models in
    Transformers.
    """

    # 类属性
    args: BenchmarkArguments
    configs: PretrainedConfig
    framework: str

    def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
        # 初始化方法
        self.args = args
        # 如果未提供配置,则根据 args 中的模型名称动态创建配置字典
        if configs is None:
            self.config_dict = {
                model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
            }
        else:
            self.config_dict = dict(zip(self.args.model_names, configs))

        # 发出未来警告,提示该类已过时
        warnings.warn(
            f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
            " are deprecated in general and it is advised to use external Benchmarking libraries "
            " to benchmark Transformer models.",
            FutureWarning,
        )

        # 如果要测量内存,并且未启用多进程,则发出警告
        if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == 0:
            logger.warning(
                "Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The"
                " flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
            )

        # 初始化打印函数和其他环境信息
        self._print_fn = None
        self._framework_version = None
        self._environment_info = None

    @property
    def print_fn(self):
        # 打印函数的属性访问器
        if self._print_fn is None:
            if self.args.log_print:

                # 如果需要记录打印信息,则创建一个打印并写入日志的函数
                def print_and_log(*args):
                    with open(self.args.log_filename, "a") as log_file:
                        log_file.write("".join(args) + "\n")
                    print(*args)

                self._print_fn = print_and_log
            else:
                # 否则直接使用 print 函数
                self._print_fn = print
        return self._print_fn

    @property
    @abstractmethod
    def framework_version(self):
        # 框架版本的抽象属性
        pass

    @abstractmethod
    def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        # 抽象方法,用于计算推理速度
        pass

    @abstractmethod
    def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        # 抽象方法,用于计算训练速度
        pass

    @abstractmethod
    def _inference_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        # 抽象方法,用于计算推理内存消耗
        pass

    @abstractmethod
    def _train_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        # 抽象方法,用于计算训练内存消耗
        pass

    def inference_speed(self, *args, **kwargs) -> float:
        # 推理速度方法,调用分离进程的包装函数执行推理速度计算
        return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs)
    # 定义一个方法,用于获取训练速度,返回一个浮点数
    def train_speed(self, *args, **kwargs) -> float:
        # 调用 separate_process_wrapper_fn 函数,用于包装 self._train_speed 方法,根据 self.args.do_multi_processing 参数决定是否多进程处理
        return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs)

    # 定义一个方法,用于推断内存占用,返回一个元组,包含 Memory 对象和可选的 MemorySummary 对象
    def inference_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
        # 调用 separate_process_wrapper_fn 函数,用于包装 self._inference_memory 方法,根据 self.args.do_multi_processing 参数决定是否多进程处理
        return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs)

    # 定义一个方法,用于获取训练内存占用,返回一个元组,包含 Memory 对象和可选的 MemorySummary 对象
    def train_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
        # 调用 separate_process_wrapper_fn 函数,用于包装 self._train_memory 方法,根据 self.args.do_multi_processing 参数决定是否多进程处理
        return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs)

    @property
    # 返回当前环境信息字典,如果还未初始化则进行初始化
    def environment_info(self):
        if self._environment_info is None:
            # 初始化空字典用于存储环境信息
            info = {}
            # 添加Transformers版本信息到环境信息字典中
            info["transformers_version"] = version
            # 添加框架名称到环境信息字典中
            info["framework"] = self.framework
            # 如果框架是PyTorch,添加是否使用TorchScript到环境信息字典中
            if self.framework == "PyTorch":
                info["use_torchscript"] = self.args.torchscript
            # 如果框架是TensorFlow,添加是否使用Eager Mode和是否使用XLA到环境信息字典中
            if self.framework == "TensorFlow":
                info["eager_mode"] = self.args.eager_mode
                info["use_xla"] = self.args.use_xla
            # 添加框架版本信息到环境信息字典中
            info["framework_version"] = self.framework_version
            # 添加Python版本信息到环境信息字典中
            info["python_version"] = platform.python_version()
            # 添加系统平台信息到环境信息字典中
            info["system"] = platform.system()
            # 添加CPU处理器信息到环境信息字典中
            info["cpu"] = platform.processor()
            # 添加系统架构信息到环境信息字典中
            info["architecture"] = platform.architecture()[0]
            # 添加当前日期到环境信息字典中
            info["date"] = datetime.date(datetime.now())
            # 添加当前时间到环境信息字典中
            info["time"] = datetime.time(datetime.now())
            # 添加是否使用FP16到环境信息字典中
            info["fp16"] = self.args.fp16
            # 添加是否使用多进程处理到环境信息字典中
            info["use_multiprocessing"] = self.args.do_multi_processing
            # 添加是否仅预训练模型到环境信息字典中
            info["only_pretrain_model"] = self.args.only_pretrain_model

            # 如果可以使用psutil库,添加CPU内存信息(单位MB)到环境信息字典中
            if is_psutil_available():
                info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
            else:
                # 如果psutil库不可用,记录警告信息并将CPU内存信息标记为不可用
                logger.warning(
                    "Psutil not installed, we won't log available CPU memory. "
                    "Install psutil (pip install psutil) to log available CPU memory."
                )
                info["cpu_ram_mb"] = "N/A"

            # 添加是否使用GPU到环境信息字典中
            info["use_gpu"] = self.args.is_gpu
            # 如果使用GPU,添加GPU数量信息到环境信息字典中
            if self.args.is_gpu:
                info["num_gpus"] = 1  # TODO(PVP) Currently only single GPU is supported
                # 如果可以使用py3nvml库,记录GPU相关信息到环境信息字典中
                if is_py3nvml_available():
                    nvml.nvmlInit()
                    handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
                    info["gpu"] = nvml.nvmlDeviceGetName(handle)
                    info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total)
                    info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
                    info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle)
                    nvml.nvmlShutdown()
                else:
                    # 如果py3nvml库不可用,记录警告信息并将GPU相关信息标记为不可用
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    info["gpu"] = "N/A"
                    info["gpu_ram_mb"] = "N/A"
                    info["gpu_power_watts"] = "N/A"
                    info["gpu_performance_state"] = "N/A"

            # 添加是否使用TPU到环境信息字典中
            info["use_tpu"] = self.args.is_tpu
            # TODO(PVP): 查看是否可以添加更多关于TPU的信息
            # 参考: https://github.com/pytorch/xla/issues/2180

            # 将完整的环境信息字典保存到实例变量中
            self._environment_info = info
        
        # 返回存储的环境信息字典
        return self._environment_info
    def print_results(self, result_dict, type_label):
        # 打印结果表头,包括模型名称、批量大小、序列长度和类型标签
        self.print_fn(80 * "-")
        self.print_fn(
            "Model Name".center(30) + "Batch Size".center(15) + "Seq Length".center(15) + type_label.center(15)
        )
        self.print_fn(80 * "-")
        # 遍历每个模型名称
        for model_name in self.args.model_names:
            # 遍历结果字典中模型名称对应的批量大小列表
            for batch_size in result_dict[model_name]["bs"]:
                # 遍历结果字典中模型名称对应的序列长度列表
                for sequence_length in result_dict[model_name]["ss"]:
                    # 获取结果字典中模型名称对应的结果数据
                    result = result_dict[model_name]["result"][batch_size][sequence_length]
                    # 如果结果是浮点数,进行格式化处理,保留三位小数或显示 "< 0.001"
                    if isinstance(result, float):
                        result = round(1000 * result) / 1000
                        result = "< 0.001" if result == 0.0 else str(result)
                    else:
                        result = str(result)
                    # 打印模型名称、批量大小、序列长度和结果数据
                    self.print_fn(
                        model_name[:30].center(30) + str(batch_size).center(15),
                        str(sequence_length).center(15),
                        result.center(15),
                    )
        # 打印结果表尾
        self.print_fn(80 * "-")

    def print_memory_trace_statistics(self, summary: MemorySummary):
        # 打印逐行内存消耗的摘要信息
        self.print_fn(
            "\nLine by line memory consumption:\n"
            + "\n".join(
                f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
                for state in summary.sequential
            )
        )
        # 打印具有最高内存消耗的行摘要信息
        self.print_fn(
            "\nLines with top memory consumption:\n"
            + "\n".join(
                f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
                for state in summary.cumulative[:6]
            )
        )
        # 打印具有最低内存消耗的行摘要信息
        self.print_fn(
            "\nLines with lowest memory consumption:\n"
            + "\n".join(
                f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
                for state in summary.cumulative[-6:]
            )
        )
        # 打印总内存增加量摘要信息
        self.print_fn(f"\nTotal memory increase: {summary.total}")
    def save_to_csv(self, result_dict, filename):
        # 如果未设置保存到 CSV 标志,则直接返回,不执行保存操作
        if not self.args.save_to_csv:
            return
        # 打印提示信息,表示正在保存结果到 CSV 文件
        self.print_fn("Saving results to csv.")
        # 打开指定文件名的 CSV 文件,以写入模式
        with open(filename, mode="w") as csv_file:
            # 如果模型名称列表为空,抛出数值错误异常,提示至少需要定义一个模型
            if len(self.args.model_names) <= 0:
                raise ValueError(f"At least 1 model should be defined, but got {self.model_names}")

            # 定义 CSV 文件的列名
            fieldnames = ["model", "batch_size", "sequence_length"]
            # 创建 CSV 写入器对象,指定列名和额外的 "result" 列
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
            # 写入 CSV 文件的表头行
            writer.writeheader()

            # 遍历每个模型名称
            for model_name in self.args.model_names:
                # 获取当前模型在结果字典中的结果
                result_dict_model = result_dict[model_name]["result"]
                # 遍历每个批量大小(batch_size)
                for bs in result_dict_model:
                    # 遍历每个序列长度(sequence_length)
                    for ss in result_dict_model[bs]:
                        # 获取当前模型在给定批量大小和序列长度下的结果
                        result_model = result_dict_model[bs][ss]
                        # 将结果写入 CSV 文件,格式化结果值为字符串,保留小数点后四位(如果是浮点数)
                        writer.writerow(
                            {
                                "model": model_name,
                                "batch_size": bs,
                                "sequence_length": ss,
                                "result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
                                    result_model
                                ),
                            }
                        )
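
作为补充,下面的小例子(文件名与数据均为假设,非源码)演示了如何按上述列结构读取生成的结果 CSV:

import csv

# 假设 save_to_csv 已将推理时间结果写入 inference_time.csv,表头为 model、batch_size、sequence_length、result
with open("inference_time.csv", newline="") as f:
    for row in csv.DictReader(f):
        print(row["model"], row["batch_size"], row["sequence_length"], row["result"])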

.\benchmark\__init__.py

# 该文件本身为空,仅用于把 benchmark 目录标记为一个 Python 包,不包含任何代码

.\cache_utils.py

# 导入必要的模块和类
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

# 导入日志模块中的日志记录器
from .configuration_utils import PretrainedConfig
from .utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义一个数据类,表示缓存的抽象基类
@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    # 更新缓存,存储新的键和值的状态到特定层的缓存中
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    # 获取缓存状态的序列长度
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # 获取缓存状态的最大序列长度
    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    # 根据新输入的序列长度返回可用的缓存长度
    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
        #   length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    # 已弃用警告:`seen_tokens` 属性将在 v4.41 中移除,请使用 `cache_position` 模型输入代替
    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
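
以 get_usable_length 的逻辑为例(数值仅作演示):若某个带窗口限制的缓存 get_max_length() 返回 1024,当前层已缓存 1000 个 token,而新输入长度为 50,则 1000 + 50 > 1024,可用长度为 1024 - 50 = 974,意味着最早的 26 个 token 需要被淘汰;若缓存没有长度上限(get_max_length() 返回 None),可用长度就是已缓存的 1000。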


class DynamicCache(Cache):
    """
    Concrete subclass of Cache representing a dynamic cache.
    """
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    # 初始化函数,设置空的缓存列表和初始的 tokens 计数
    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []  # 用于存储每个层的 Key 状态的列表
        self.value_cache: List[torch.Tensor] = []  # 用于存储每个层的 Value 状态的列表
        self._seen_tokens = 0  # 在 `generate` 方法中用于记录缓存已见过的 tokens 数量的计数器

    # 支持通过索引访问 `past_key_value`,例如 `past_key_value[0][0].shape[2]` 获取序列长度
    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    # 支持通过迭代访问 `past_key_value`,例如 `for x in past_key_value:` 迭代访问键和值
    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    # 返回当前缓存的层数,支持 `len(past_key_value)` 的操作,对应模型中的层数
    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    # 更新缓存中特定层的 Key 和 Value 状态
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        # 如果 layer_idx 为 0,则更新已见过的 token 数量
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        # 如果 key_cache 的长度小于等于 layer_idx,则将 key_states 和 value_states 添加到 cache 中
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # 否则,在已有的 cache 中更新 layer_idx 对应的 key_states 和 value_states
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        # 返回更新后的 key_states 和 value_states
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # 如果 key_cache 的长度小于等于 layer_idx,则返回 0
        if len(self.key_cache) <= layer_idx:
            return 0
        # 否则返回 key_cache 中 layer_idx 对应的 tensor 的第二维度的大小
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        # 返回 None,因为 DynamicCache 类型的缓存没有最大长度限制
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        # 遍历每个层的缓存,根据 beam_idx 重新排序 key_cache 和 value_cache
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
        # 将 DynamicCache 实例转换成遗留缓存格式的等价表示,并返回为元组的形式
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        # 创建一个空的 DynamicCache 对象
        cache = cls()
        # 如果传入了过去的键-值对数据
        if past_key_values is not None:
            # 遍历过去的键-值对数据的每一层
            for layer_idx in range(len(past_key_values)):
                # 分别获取键状态和值状态
                key_states, value_states = past_key_values[layer_idx]
                # 将键状态和值状态更新到缓存中的指定层
                cache.update(key_states, value_states, layer_idx)
        # 返回转换后的 DynamicCache 对象
        return cache
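
下面是一个最小的使用示例(张量形状为随意选取的示意值,非源码),演示 DynamicCache 的增量更新以及与旧格式缓存的互转:

import torch

cache = DynamicCache()
# 形状为 [batch_size, num_heads, seq_len, head_dim]
k, v = torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)
cache.update(k, v, layer_idx=0)   # 第一次调用:直接把该层的 Key/Value 追加进缓存
cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx=0)   # 再次调用:在 seq_len 维度上拼接
print(cache.get_seq_length(0))    # 4
legacy = cache.to_legacy_cache()  # ((key, value),) 形式的元组
print(DynamicCache.from_legacy_cache(legacy).get_seq_length(0))   # 4
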
# 定义一个名为 `SinkCache` 的类,继承自 `Cache` 类,实现了一个缓存机制,根据 [Attention Sinks paper](https://arxiv.org/abs/2309.17453) 描述的内容,
# 允许模型在超出其上下文窗口长度的情况下生成内容,同时保持对话的流畅性。当丢弃过去的标记时,模型将失去依赖于被丢弃上下文的标记生成能力。

class SinkCache(Cache):
    """
    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), that allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        # 初始化空列表,用于存储每一层的 Key 状态和 Value 状态的张量
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # 设定上下文窗口长度和沉降标记的数量
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # 缓存余弦和正弦值的字典
        self.cos_sin_cache = {}
        # 记录缓存已见标记的数量,在 `generate` 方法中使用,用于统计缓存处理的标记数
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        # 将张量 `x` 在最后一个维度上分成两半,进行半旋转操作
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        # 对 Key 状态应用旋转位置嵌入,使用余弦和正弦值进行加权
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_cache:
            # 临时提升到 float32 类型以提高精度
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # 计算用于向前和向后旋转到序列中前一位置所需的余弦和正弦值
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            # 缓存计算结果,使用张量的数据类型,并扩展维度
            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # 如果 self.key_cache 的长度小于等于 layer_idx,返回 0
        if len(self.key_cache) <= layer_idx:
            return 0
        # 返回 self.key_cache[layer_idx] 张量的倒数第二个维度的长度
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        # 返回 window_length 属性的值,即缓存状态的最大序列长度
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Updates the cache with new key and value states for a specific layer."""
        # 简化处理:该层尚无缓存时追加,否则在序列维度上拼接(原始实现还包含沉降 token 保留、窗口截断与 key 重旋转等逻辑,此处从略)
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        # 针对 beam search 重新排序缓存状态
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            # 根据 beam_idx 重新排序 key_cache
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            # 根据 beam_idx 重新排序 value_cache
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
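
作为对上文 _rotate_half 辅助函数的直观说明(输入为假设的示例张量,非源码):

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(SinkCache._rotate_half(x))   # tensor([-3., -4.,  1.,  2.]),即把后一半取负后移到前面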
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
            required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        # 设置最大批处理大小
        self.max_batch_size = max_batch_size
        # 设置最大缓存长度,如果未指定则使用配置文件中的最大位置嵌入数
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # 计算头部维度,如果配置中定义了自定义头部维度,则使用;否则根据隐藏层大小和注意力头数计算
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        # 设置数据类型,默认为 torch.float32
        self.dtype = dtype if dtype is not None else torch.float32
        # 设置键值头的数量,如果未指定则使用配置文件中的注意力头数
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        # 初始化键和值的缓存张量,形状为 (最大批处理大小, 键值头数, 最大缓存长度, 头部维度)
        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
                to know how much of the cache it should overwrite.

        Return:
            A tuple containing the updated key and value states.
        """

        new_cache_positions = cache_kwargs.get("cache_position")  # 获取缓存位置参数
        k_out = self.key_cache  # 获取当前的键缓存
        v_out = self.value_cache  # 获取当前的值缓存

        k_out[:, :, new_cache_positions] = key_states  # 更新键缓存的指定位置的状态
        v_out[:, :, new_cache_positions] = value_states  # 更新值缓存的指定位置的状态

        return k_out, v_out  # 返回更新后的键和值缓存

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC
        """
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
        return (self.key_cache[0, 0].any(dim=-1)).sum()  # 计算缓存中非零值的数量,用于表示序列长度

    def get_max_length(self) -> Optional[int]:
        """
        Returns the maximum sequence length of the cached states.
        """
        return self.max_cache_len  # 返回缓存中的最大序列长度

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """
        Reorders the cache for beam search, given the selected beam indices.
        """
        device = self.key_cache.device  # 获取当前键缓存所在的设备
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))  # 根据beam索引重新排序键缓存
        device = self.value_cache.device  # 获取当前值缓存所在的设备
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))  # 根据beam索引重新排序值缓存

    def to_legacy_cache(self):
        """
        Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it
        """
        return None  # 返回空值,用于保持向后兼容
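
下面给出一个最小的使用示意(配置中的具体数值为假设值,非源码),展示如何通过 cache_kwargs 中的 cache_position 指定要写入的缓存位置:

import torch

config = PretrainedConfig(
    max_position_embeddings=128,
    hidden_size=64,
    num_attention_heads=4,
    num_key_value_heads=4,
)
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

seq_len = 3
k = torch.randn(1, 4, seq_len, 16)   # head_dim = hidden_size // num_attention_heads = 16
v = torch.randn(1, 4, seq_len, 16)
# 必须用张量来索引要写入的位置,避免在设备上产生额外拷贝
k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(seq_len)})
print(cache.get_seq_length())        # tensor(3),即已写入的位置数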

.\commands\add_new_model.py

# 导入必要的模块
import json
import os
import shutil
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List

# 导入自定义的日志模块
from ..utils import logging
# 导入基础命令类
from . import BaseTransformersCLICommand

# 尝试导入cookiecutter模块,检查是否可用
try:
    from cookiecutter.main import cookiecutter
    _has_cookiecutter = True
except ImportError:
    _has_cookiecutter = False

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 工厂函数,创建添加新模型命令实例
def add_new_model_command_factory(args: Namespace):
    return AddNewModelCommand(args.testing, args.testing_file, path=args.path)

# 添加新模型命令类,继承自基础命令类
class AddNewModelCommand(BaseTransformersCLICommand):

    # 静态方法,用于注册子命令
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # 添加"add-new-model"子命令及其参数
        add_new_model_parser = parser.add_parser("add-new-model")
        add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
        add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
        add_new_model_parser.add_argument(
            "--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
        )
        # 设置默认的命令处理函数为add_new_model_command_factory
        add_new_model_parser.set_defaults(func=add_new_model_command_factory)

    # 初始化方法,设置命令的属性
    def __init__(self, testing: bool, testing_file: str, path=None, *args):
        self._testing = testing
        self._testing_file = testing_file
        self._path = path

.\commands\add_new_model_like.py

# 导入 difflib 模块,用于生成文本差异的比较结果
import difflib
# 导入 json 模块,用于处理 JSON 数据的编解码
import json
# 导入 os 模块,提供了与操作系统交互的功能
import os
# 导入 re 模块,用于支持正则表达式的操作
import re
# 从 argparse 模块中导入 ArgumentParser 和 Namespace 类,用于解析命令行参数
from argparse import ArgumentParser, Namespace
# 从 dataclasses 模块中导入 dataclass 装饰器,用于简化定义数据类
from dataclasses import dataclass
# 从 datetime 模块中导入 date 类,用于处理日期信息
from datetime import date
# 从 itertools 模块中导入 chain 函数,用于将多个迭代器连接在一起
from itertools import chain
# 从 pathlib 模块中导入 Path 类,用于处理文件路径
from pathlib import Path
# 从 typing 模块中导入各种类型提示,用于静态类型检查
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union

# 导入 yaml 模块,用于处理 YAML 格式的数据
import yaml

# 从 ..models 中导入 auto 模块,可能为自动生成的模型模块
from ..models import auto as auto_module
# 从 ..models.auto.configuration_auto 中导入 model_type_to_module_name 函数,用于获取模型类型对应的模块名称
from ..models.auto.configuration_auto import model_type_to_module_name
# 从 ..utils 中导入 is_flax_available、is_tf_available、is_torch_available、logging 函数和类
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
# 从 当前目录 的 BaseTransformersCLICommand 模块中导入全部内容
from . import BaseTransformersCLICommand

# 使用 logging 模块获取当前模块的 logger 对象,用于记录日志
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 获取当前年份
CURRENT_YEAR = date.today().year
# 获取 Transformers 模块所在的路径
TRANSFORMERS_PATH = Path(__file__).parent.parent
# 获取代码库的根路径
REPO_PATH = TRANSFORMERS_PATH.parent.parent

@dataclass
class ModelPatterns:
    """
    Holds the basic information about a new model for the add-new-model-like command.
    """
    # 这是一个数据类,用于存储用于 add-new-model-like 命令的新模型的基本信息
    # 以下为 ModelPatterns 数据类文档字符串中的参数说明,描述初始化时可传入的各个字段
    Args:
        model_name (`str`): 模型名称。
        checkpoint (`str`): 用于文档示例的检查点。
        model_type (`str`, *optional*):
            模型类型,内部库中使用的标识符,如 `bert` 或 `xlm-roberta`。默认为 `model_name` 的小写形式,空格用短横线(-)替换。
        model_lower_cased (`str`, *optional*):
            模型名称的小写形式,用于模块名称或函数名称。默认为 `model_name` 的小写形式,空格和短横线都替换为下划线。
        model_camel_cased (`str`, *optional*):
            模型名称的驼峰式命名形式,用于类名。默认为 `model_name` 的驼峰式命名(考虑空格和短横线都作为单词分隔符)。
        model_upper_cased (`str`, *optional*):
            模型名称的大写形式,用于常量名称。默认为 `model_name` 的大写形式,空格和短横线都替换为下划线。
        config_class (`str`, *optional*):
            与此模型关联的配置类。默认为 `"{model_camel_cased}Config"`。
        tokenizer_class (`str`, *optional*):
            与此模型关联的分词器类(对于不使用分词器的模型,请将其保留为 `None`)。
        image_processor_class (`str`, *optional*):
            与此模型关联的图像处理器类(对于不使用图像处理器的模型,请将其保留为 `None`)。
        feature_extractor_class (`str`, *optional*):
            与此模型关联的特征提取器类(对于不使用特征提取器的模型,请将其保留为 `None`)。
        processor_class (`str`, *optional*):
            与此模型关联的处理器类(对于不使用处理器的模型,请将其保留为 `None`)。
    # 在对象初始化完成后执行的方法,用于设置默认属性
    def __post_init__(self):
        # 如果未指定模型类型,则根据模型名称生成一个小写的类型名称
        if self.model_type is None:
            self.model_type = self.model_name.lower().replace(" ", "-")
        
        # 如果未指定小写模型名称,则根据模型名称生成一个小写且用下划线替换空格和破折号的名称
        if self.model_lower_cased is None:
            self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_")
        
        # 如果未指定驼峰式模型名称,则按照一定规则生成驼峰式的模型名称
        if self.model_camel_cased is None:
            # 将模型名称按照空格和破折号拆分成单词列表
            words = self.model_name.split(" ")
            words = list(chain(*[w.split("-") for w in words]))
            # 将每个单词的首字母大写,其余字母小写
            words = [w[0].upper() + w[1:] for w in words]
            self.model_camel_cased = "".join(words)
        
        # 如果未指定大写模型名称,则生成一个大写且用下划线替换空格和破折号的名称
        if self.model_upper_cased is None:
            self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_")
        
        # 如果未指定配置类名称,则根据驼峰式模型名称生成一个默认的配置类名称
        if self.config_class is None:
            self.config_class = f"{self.model_camel_cased}Config"
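
结合上面的 __post_init__ 逻辑,下面的小例子(模型名与检查点均为假设值,非源码)演示各个名称的派生结果:

p = ModelPatterns("New Bert", checkpoint="new-bert-base")
print(p.model_type)          # "new-bert"
print(p.model_lower_cased)   # "new_bert"
print(p.model_camel_cased)   # "NewBert"
print(p.model_upper_cased)   # "NEW_BERT"
print(p.config_class)        # "NewBertConfig"
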
ATTRIBUTE_TO_PLACEHOLDER = {
    "config_class": "[CONFIG_CLASS]",  # 属性到占位符的映射字典,用于标记配置类
    "tokenizer_class": "[TOKENIZER_CLASS]",  # 标记标记器类
    "image_processor_class": "[IMAGE_PROCESSOR_CLASS]",  # 标记图像处理器类
    "feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",  # 标记特征提取器类
    "processor_class": "[PROCESSOR_CLASS]",  # 标记处理器类
    "checkpoint": "[CHECKPOINT]",  # 标记检查点
    "model_type": "[MODEL_TYPE]",  # 标记模型类型
    "model_upper_cased": "[MODEL_UPPER_CASED]",  # 标记大写模型名称
    "model_camel_cased": "[MODEL_CAMELCASED]",  # 标记驼峰式模型名称
    "model_lower_cased": "[MODEL_LOWER_CASED]",  # 标记小写模型名称
    "model_name": "[MODEL_NAME]",  # 标记模型名称
}


def is_empty_line(line: str) -> bool:
    """
    Determines whether a line is empty or not.
    判断一行是否为空行。
    """
    return len(line) == 0 or line.isspace()


def find_indent(line: str) -> int:
    """
    Returns the number of spaces that start a line indent.
    返回一行开头的空格数,即缩进量。
    """
    search = re.search(r"^(\s*)(?:\S|$)", line)
    if search is None:
        return 0
    return len(search.groups()[0])
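
这两个小工具函数的行为可以用几个例子快速说明(非源码):

print(is_empty_line("   "))                    # True,只含空白字符
print(find_indent("    def forward(self):"))   # 4
print(find_indent("class Foo:"))               # 0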


def parse_module_content(content: str) -> List[str]:
    """
    Parse the content of a module in the list of objects it defines.

    Args:
        content (`str`): The content to parse
        要解析的模块内容。

    Returns:
        `List[str]`: The list of objects defined in the module.
        返回模块定义的对象列表。
    """
    objects = []
    current_object = []
    lines = content.split("\n")
    end_markers = [")", "]", "}", '"""']  # 结束标记列表

    for line in lines:
        is_valid_object = len(current_object) > 0
        if is_valid_object and len(current_object) == 1:
            is_valid_object = not current_object[0].startswith("# Copied from")
        if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
            if line in end_markers:
                current_object.append(line)
                objects.append("\n".join(current_object))
                current_object = []
            else:
                objects.append("\n".join(current_object))
                current_object = [line]
        else:
            current_object.append(line)

    if len(current_object) > 0:
        objects.append("\n".join(current_object))

    return objects
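
下面用一个很小的模块字符串演示 parse_module_content 的拆分效果(非源码):

sample = """import os

CONSTANT = 1


def foo():
    return CONSTANT
"""
objs = parse_module_content(sample)
print(len(objs))   # 3:分别对应 import 语句、常量定义和 foo 函数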


def extract_block(content: str, indent_level: int = 0) -> str:
    """
    Return the first block in `content` with the indent level `indent_level`.

    The first line in `content` should be indented at `indent_level` level, otherwise an error will be thrown.

    This method will immediately stop the search when a (non-empty) line with indent level less than `indent_level` is
    encountered.

    Args:
        content (`str`): The content to parse
        indent_level (`int`, *optional*, default to 0): The indent level of the blocks to search for

    Returns:
        `str`: The first block in `content` with the indent level `indent_level`.
    返回在`content`中具有缩进级别`indent_level`的第一个块。

    Raises:
        ValueError: If the content does not start with the specified indent level.
        如果内容不以指定的缩进级别开头,则引发 ValueError 异常。
    """
    current_object = []
    lines = content.split("\n")
    # 结束标记列表,用于判断对象结尾的可能字符
    end_markers = [")", "]", "}", '"""']

    # 遍历每一行代码
    for idx, line in enumerate(lines):
        # 如果是第一行且缩进级别大于0,且不是空行,并且第一行的缩进级别与指定的缩进级别不符合,则抛出数值错误
        if idx == 0 and indent_level > 0 and not is_empty_line(line) and find_indent(line) != indent_level:
            raise ValueError(
                f"When `indent_level > 0`, the first line in `content` should have indent level {indent_level}. Got "
                f"{find_indent(line)} instead."
            )

        # 如果当前行的缩进级别小于指定的缩进级别,并且不是空行,则退出循环
        if find_indent(line) < indent_level and not is_empty_line(line):
            break

        # 判断是否为对象的结尾
        is_valid_object = len(current_object) > 0
        if (
            not is_empty_line(line)                          # 不是空行
            and not line.endswith(":")                       # 不是以冒号结尾
            and find_indent(line) == indent_level            # 缩进级别与指定的缩进级别相同
            and is_valid_object                              # 当前对象非空
        ):
            # 如果当前行的左边去除空白后在结束标记列表中,则将该行添加到当前对象中
            if line.lstrip() in end_markers:
                current_object.append(line)
            # 返回当前对象的字符串表示形式
            return "\n".join(current_object)
        else:
            # 将当前行添加到当前对象中
            current_object.append(line)

    # 添加最后一个对象
    if len(current_object) > 0:
        return "\n".join(current_object)
def add_content_to_text(
    text: str,
    content: str,
    add_after: Optional[Union[str, Pattern]] = None,
    add_before: Optional[Union[str, Pattern]] = None,
    exact_match: bool = False,
) -> str:
    """
    A utility to add some content inside a given text.

    Args:
       text (`str`): The text in which we want to insert some content.
       content (`str`): The content to add.
       add_after (`str` or `Pattern`):
           The pattern to test on a line of `text`, the new content is added after the first instance matching it.
       add_before (`str` or `Pattern`):
           The pattern to test on a line of `text`, the new content is added before the first instance matching it.
       exact_match (`bool`, *optional*, defaults to `False`):
           A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
           otherwise, if `add_after`/`add_before` is present in the line.

    <Tip warning={true}>

    The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.

    </Tip>

    Returns:
        `str`: The text with the new content added if a match was found.
    """
    # 检查是否同时提供了 `add_after` 和 `add_before` 参数
    if add_after is None and add_before is None:
        raise ValueError("You need to pass either `add_after` or `add_before`")
    if add_after is not None and add_before is not None:
        raise ValueError("You can't pass both `add_after` or `add_before`")
    
    # 根据参数设置要匹配的模式
    pattern = add_after if add_before is None else add_before

    def this_is_the_line(line):
        # 检查当前行是否符合模式
        if isinstance(pattern, Pattern):
            return pattern.search(line) is not None
        elif exact_match:
            return pattern == line
        else:
            return pattern in line

    new_lines = []
    # 遍历文本的每一行
    for line in text.split("\n"):
        # 如果当前行符合条件
        if this_is_the_line(line):
            # 根据参数决定添加内容的位置
            if add_before is not None:
                new_lines.append(content)
            new_lines.append(line)
            if add_after is not None:
                new_lines.append(content)
        else:
            # 如果不符合条件,直接将当前行添加到新的文本列表中
            new_lines.append(line)

    # 将新的文本列表合并为一个字符串并返回
    return "\n".join(new_lines)


def add_content_to_file(
    file_name: Union[str, os.PathLike],
    content: str,
    add_after: Optional[Union[str, Pattern]] = None,
    add_before: Optional[Union[str, Pattern]] = None,
    exact_match: bool = False,
):
    """
    A utility to add some content inside a given file.
    
    <Tip warning={true}>

    The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.

    </Tip>
    """
    # 打开指定文件以读取其内容,文件名由参数 `file_name` 指定,使用 UTF-8 编码
    with open(file_name, "r", encoding="utf-8") as f:
        # 将文件全部内容读取到 `old_content` 变量中
        old_content = f.read()
    
    # 调用函数 `add_content_to_text`,将 `content` 添加到 `old_content` 中的指定位置
    new_content = add_content_to_text(
        old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
    )
    
    # 以写入模式打开文件 `file_name`,使用 UTF-8 编码
    with open(file_name, "w", encoding="utf-8") as f:
        # 将处理过的 `new_content` 写入文件
        f.write(new_content)
def replace_model_patterns(
    text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
) -> Tuple[str, str]:
    """
    Replace all patterns present in a given text.

    Args:
        text (`str`): The text to treat.
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.

    Returns:
        `Tuple(str, str)`: A tuple of with the treated text and the replacement actually done in it.
    """
    # 顺序至关重要,因为我们将按照此顺序检查和替换。例如,配置可能包含驼峰命名,但将在之前处理。
    attributes_to_check = ["config_class"]

    # 添加相关的预处理类
    for attr in ["tokenizer_class", "image_processor_class", "feature_extractor_class", "processor_class"]:
        # 如果旧模型和新模型都有这个属性,则添加到检查列表中
        if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
            attributes_to_check.append(attr)

    # 特殊情况:checkpoint 和 model_type
    if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
        attributes_to_check.append("checkpoint")
    if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
        attributes_to_check.append("model_type")
    else:
        # 在文本中用正则表达式替换旧模型类型为占位符"[MODEL_TYPE]"
        text = re.sub(
            rf'(\s*)model_type = "{old_model_patterns.model_type}"',
            r'\1model_type = "[MODEL_TYPE]"',
            text,
        )

    # 特殊情况:当旧模型的大写驼峰名称与小写驼峰名称相同时(例如对于GPT2),但新模型不同时,需要特殊处理
    if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
        old_model_value = old_model_patterns.model_upper_cased
        # 如果在文本中找到了旧模型大写驼峰名称的特定格式,用新的大写驼峰占位符替换
        if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
            text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
    else:
        attributes_to_check.append("model_upper_cased")

    # 添加其他需要检查的属性
    attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])

    # 替换每个属性为其占位符
    for attr in attributes_to_check:
        text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])

    # 最后,用新值替换占位符
    replacements = []
    for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
        # 如果文本中包含占位符,将其替换为新模型对应的属性值
        if placeholder in text:
            replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
            text = text.replace(placeholder, getattr(new_model_patterns, attr))

    # 如果存在两个不一致的替换,则不返回任何值(例如:GPT2->GPT_NEW 和 GPT2->GPTNew)
    # 将 replacements 列表中的第一个元素(旧值)抽取出来形成列表 old_replacement_values
    old_replacement_values = [old for old, new in replacements]
    # 检查 old_replacement_values 是否有重复的元素,如果有,则返回原始文本和空字符串
    if len(set(old_replacement_values)) != len(old_replacement_values):
        return text, ""
    
    # 简化 replacements 列表中的元素,并重新赋值给 replacements
    replacements = simplify_replacements(replacements)
    # 将简化后的 replacements 列表转换为形如 "old->new" 的字符串列表
    replacements = [f"{old}->{new}" for old, new in replacements]
    # 返回原始文本和用逗号连接的替换字符串列表
    return text, ",".join(replacements)
# 将给定的替换模式列表简化,确保没有不必要的模式。
def simplify_replacements(replacements):
    # 如果替换列表长度小于等于1,无需简化,直接返回原列表
    if len(replacements) <= 1:
        return replacements

    # 按照替换模式的长度排序,因为较短的模式可能会被较长的模式"隐含"
    replacements.sort(key=lambda x: len(x[0]))

    idx = 0
    # 遍历替换列表中的每一个模式
    while idx < len(replacements):
        old, new = replacements[idx]
        j = idx + 1
        # 再次遍历当前模式之后的所有模式
        while j < len(replacements):
            old_2, new_2 = replacements[j]
            # 如果当前模式可以被之后的模式"隐含",则移除之后的模式
            if old_2.replace(old, new) == new_2:
                replacements.pop(j)
            else:
                j += 1
        idx += 1

    return replacements
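
用一个例子说明"隐含"替换的化简(非源码):

print(simplify_replacements([("Bert", "NewBert"), ("BertConfig", "NewBertConfig")]))
# [('Bert', 'NewBert')]:第二条替换可以由第一条推出,因此被去掉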


# 返回指定模块文件对应的模块名称
def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
    full_module_path = Path(module_file).absolute()
    module_parts = full_module_path.with_suffix("").parts

    idx = len(module_parts) - 1
    # 从文件路径的末尾开始查找第一个名为"transformers"的部分
    while idx >= 0 and module_parts[idx] != "transformers":
        idx -= 1
    # 如果未找到"transformers",抛出数值错误
    if idx < 0:
        raise ValueError(f"{module_file} is not a transformers module.")

    return ".".join(module_parts[idx:])


# 特殊模式映射字典,将特定字符串替换为相应的类别名称
SPECIAL_PATTERNS = {
    "_CHECKPOINT_FOR_DOC =": "checkpoint",
    "_CONFIG_FOR_DOC =": "config_class",
    "_TOKENIZER_FOR_DOC =": "tokenizer_class",
    "_IMAGE_PROCESSOR_FOR_DOC =": "image_processor_class",
    "_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
    "_PROCESSOR_FOR_DOC =": "processor_class",
}


# 正则表达式对象,用于匹配类和函数的定义
_re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)


# 从对象中移除指定的属性或方法
def remove_attributes(obj, target_attr):
    lines = obj.split(os.linesep)

    target_idx = None
    for idx, line in enumerate(lines):
        # 查找赋值语句
        if line.lstrip().startswith(f"{target_attr} = "):
            target_idx = idx
            break
        # 查找函数或方法的定义
        elif line.lstrip().startswith(f"def {target_attr}("):
            target_idx = idx
            break

    # 如果未找到目标属性或方法,直接返回原始对象
    if target_idx is None:
        return obj

    line = lines[target_idx]
    indent_level = find_indent(line)
    # 前向传递以找到块的结束位置(包括空行)
    parsed = extract_block("\n".join(lines[target_idx:]), indent_level)
    # 计算解析后的文本以换行符分割后的行数
    num_lines = len(parsed.split("\n"))
    # 将目标索引处后面的 num_lines 行设为 None,表示删除这些行
    for idx in range(num_lines):
        lines[target_idx + idx] = None

    # 逆向遍历以找到注释或装饰器的行
    for idx in range(target_idx - 1, -1, -1):
        line = lines[idx]
        # 如果行以 '#' 或 '@' 开头,并且缩进等级与目标相同,则将该行设为 None
        if (line.lstrip().startswith("#") or line.lstrip().startswith("@")) and find_indent(line) == indent_level:
            lines[idx] = None
        else:
            # 如果不满足条件,退出循环
            break

    # 将列表中非 None 的行连接起来,使用操作系统的换行符分隔
    new_obj = os.linesep.join([x for x in lines if x is not None])

    # 返回处理后的新文本对象
    return new_obj
    """
    Create a new module from an existing one and adapting all function and classes names from old patterns to new ones.

    Args:
        module_file (`str` or `os.PathLike`): Path to the module to duplicate.
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.
        dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
        add_copied_from (`bool`, *optional*, defaults to `True`):
            Whether or not to add `# Copied from` statements in the duplicated module.
    """
    # If `dest_file` is not provided, generate it based on `module_file` and replace old model pattern with new one
    if dest_file is None:
        dest_file = str(module_file).replace(
            old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
        )

    # Open the existing module file for reading
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Update the year in any copyright statements to the current year
    content = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)

    # Parse the module content into individual objects (functions, classes, etc.)
    objects = parse_module_content(content)

    # Loop through each object in the module content
    new_objects = []
    for obj in objects:
        # Handle special case for `PRETRAINED_CONFIG_ARCHIVE_MAP` assignment
        if "PRETRAINED_CONFIG_ARCHIVE_MAP = {" in obj:
            # docstyle-ignore
            # Replace with a new entry specific to the new model patterns
            obj = (
                f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP = "
                + "{"
                + f"""
    "{new_model_patterns.checkpoint}": "https://huggingface.co/{new_model_patterns.checkpoint}/resolve/main/config.json",
"""
                + "}\n"
            )
            new_objects.append(obj)
            continue
        # Handle special case for `PRETRAINED_MODEL_ARCHIVE_LIST` assignment
        elif "PRETRAINED_MODEL_ARCHIVE_LIST = [" in obj:
            if obj.startswith("TF_"):
                prefix = "TF_"
            elif obj.startswith("FLAX_"):
                prefix = "FLAX_"
            else:
                prefix = ""
            # docstyle-ignore
            # Replace with a new list including the new model checkpoint and a reference URL
            obj = f"""{prefix}{new_model_patterns.model_upper_cased}_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "{new_model_patterns.checkpoint}",
    # See all {new_model_patterns.model_name} models at https://huggingface.co/models?filter={new_model_patterns.model_type}
]
"""
        # Collect the updated object into the list of new objects
        new_objects.append(obj)
def filter_framework_files(
    files: List[Union[str, os.PathLike]], frameworks: Optional[List[str]] = None
) -> List[Union[str, os.PathLike]]:
    """
    Filter a list of files to only keep the ones corresponding to a list of frameworks.

    Args:
        files (`List[Union[str, os.PathLike]]`): The list of files to filter.
        frameworks (`List[str]`, *optional*): The list of allowed frameworks.

    Returns:
        `List[Union[str, os.PathLike]]`: The list of filtered files.
    """
    # 如果未提供frameworks参数,则使用默认的框架列表
    if frameworks is None:
        frameworks = get_default_frameworks()

    # 创建一个字典来存储每个框架对应的文件
    framework_to_file = {}
    # 创建一个空列表来存储不属于任何框架的文件
    others = []
    # 遍历每个文件
    for f in files:
        # 将文件路径分割成组成文件名的部分
        parts = Path(f).name.split("_")
        # 如果文件名中不包含"modeling",将其添加到others列表中并跳过
        if "modeling" not in parts:
            others.append(f)
            continue
        # 根据文件名中的关键词判断框架类型,并将文件路径添加到相应框架的条目中
        if "tf" in parts:
            framework_to_file["tf"] = f
        elif "flax" in parts:
            framework_to_file["flax"] = f
        else:
            framework_to_file["pt"] = f

    # 返回符合给定框架列表的文件路径列表,以及不属于任何框架的文件路径列表
    return [framework_to_file[f] for f in frameworks if f in framework_to_file] + others
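
一个简单的过滤示例(文件名为假设值,非源码):

files = ["modeling_bert.py", "modeling_tf_bert.py", "modeling_flax_bert.py", "tokenization_bert.py"]
print(filter_framework_files(files, frameworks=["pt"]))
# ['modeling_bert.py', 'tokenization_bert.py']:只保留 PyTorch 的建模文件,非建模文件始终保留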


def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, Union[Path, List[Path]]]:
    """
    Retrieves all the files associated to a model.

    Args:
        model_type (`str`): A valid model type (like "bert" or "gpt2")
        frameworks (`List[str]`, *optional*):
            If passed, will only keep the model files corresponding to the passed frameworks.

    Returns:
        `Dict[str, Union[Path, List[Path]]]`: A dictionary with the following keys:
        - **doc_file** -- The documentation file for the model.
        - **model_files** -- All the files in the model module.
        - **module_name** -- The name of the module corresponding to the model type.
        - **test_files** -- The test files for the model.
    """
    # Convert model type to its corresponding module name
    module_name = model_type_to_module_name(model_type)

    # Define the path to the model module within TRANSFORMERS_PATH
    model_module = TRANSFORMERS_PATH / "models" / module_name
    # List all Python files within the model module
    model_files = list(model_module.glob("*.py"))
    # Filter model files based on specified frameworks, if provided
    model_files = filter_framework_files(model_files, frameworks=frameworks)

    # Define the path to the documentation file for the model
    doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{model_type}.md"

    # Basic pattern for test files related to the model module
    test_files = [
        f"test_modeling_{module_name}.py",
        f"test_modeling_tf_{module_name}.py",
        f"test_modeling_flax_{module_name}.py",
        f"test_tokenization_{module_name}.py",
        f"test_image_processing_{module_name}.py",
        f"test_feature_extraction_{module_name}.py",
        f"test_processor_{module_name}.py",
    ]
    # Filter test files based on specified frameworks, if provided
    test_files = filter_framework_files(test_files, frameworks=frameworks)
    # Construct full paths for test files within the tests/models/module_name directory
    test_files = [REPO_PATH / "tests" / "models" / module_name / f for f in test_files]
    # Filter out non-existing test files
    test_files = [f for f in test_files if f.exists()]

    # Return a dictionary containing paths to relevant files and module name
    return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
# 编译正则表达式,用于匹配文档字符串中的_CHECKPOINT_FOR_DOC赋值语句
_re_checkpoint_for_doc = re.compile(r"^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE)

# 查找给定模型类型的文档字符串中使用的模型检查点
def find_base_model_checkpoint(
    model_type: str, model_files: Optional[Dict[str, Union[Path, List[Path]]]] = None
) -> str:
    """
    Finds the model checkpoint used in the docstrings for a given model.

    Args:
        model_type (`str`): A valid model type (like "bert" or "gpt2")
        model_files (`Dict[str, Union[Path, List[Path]]`, *optional*):
            The files associated to `model_type`. Can be passed to speed up the function, otherwise will be computed.

    Returns:
        `str`: The checkpoint used.
    """
    # If the model files were not provided, retrieve them
    if model_files is None:
        model_files = get_model_files(model_type)

    # Get the list of files for this model
    module_files = model_files["model_files"]

    # Iterate over the model files
    for fname in module_files:
        # Skip files that are not modeling files
        if "modeling" not in str(fname):
            continue

        # Open the file and read its content
        with open(fname, "r", encoding="utf-8") as f:
            content = f.read()
            # Look for a `_CHECKPOINT_FOR_DOC` assignment in the file
            if _re_checkpoint_for_doc.search(content) is not None:
                # Extract the checkpoint value and strip surrounding quotes
                checkpoint = _re_checkpoint_for_doc.search(content).groups()[0]
                checkpoint = checkpoint.replace('"', "")
                checkpoint = checkpoint.replace("'", "")
                return checkpoint

    # If no `_CHECKPOINT_FOR_DOC` assignment was found, return an empty string as the default
    # TODO: Maybe add a fallback when `_CHECKPOINT_FOR_DOC` is not found in any of the model files
    return ""


# Return the list of frameworks (PyTorch, TensorFlow, Flax) installed in the current environment
def get_default_frameworks():
    """
    Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
    """
    frameworks = []
    if is_torch_available():  # PyTorch is available, add "pt" to the list
        frameworks.append("pt")
    if is_tf_available():  # TensorFlow is available, add "tf" to the list
        frameworks.append("tf")
    if is_flax_available():  # Flax is available, add "flax" to the list
        frameworks.append("flax")
    return frameworks


# Regex matching the names of the model mapping constants in the auto modules
_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")

# Retrieve the model classes associated with a given model type for the requested frameworks
def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, List[str]]:
    """
    Retrieve the model classes associated to a given model.

    Args:
        model_type (`str`): A valid model type (like "bert" or "gpt2")
        frameworks (`List[str]`, *optional*):
            The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict
            the classes returned.

    Returns:
        `Dict[str, List[str]]`: A dictionary with one key per framework and the list of model classes associated to
        that framework as values.
    """
    # If no framework list was provided, use the default frameworks
    if frameworks is None:
        frameworks = get_default_frameworks()

    # Map each framework to its auto modeling module (or None when the framework is not installed)
    modules = {
        "pt": auto_module.modeling_auto if is_torch_available() else None,
        "tf": auto_module.modeling_tf_auto if is_tf_available() else None,
        "flax": auto_module.modeling_flax_auto if is_flax_available() else None,
    }

    # Dictionary of model classes per framework
    model_classes = {}
    # Iterate over the requested frameworks
    for framework in frameworks:
        # Collect the model classes found for this framework
        new_model_classes = []
        # If the framework is not installed, raise an error
        if modules[framework] is None:
            raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
        # List all attributes of the auto module that look like model mappings
        model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
        # Iterate over the model mapping names found for this framework
        for model_mapping_name in model_mappings:
            # Get the mapping object from its name
            model_mapping = getattr(modules[framework], model_mapping_name)
            # Check whether the given model type appears in this mapping
            if model_type in model_mapping:
                # Add the corresponding model class to the list
                new_model_classes.append(model_mapping[model_type])

        # If at least one class was found for this framework
        if len(new_model_classes) > 0:
            # Remove duplicates and store the result for this framework
            model_classes[framework] = list(set(new_model_classes))

    # Return the dictionary of model classes per framework
    return model_classes
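
As an illustration (the exact classes depend on which frameworks are installed and on the current contents of the auto mappings):

# Indicative output only
classes = retrieve_model_classes("bert", frameworks=["pt"])
print(sorted(classes["pt"])[:3])
# e.g. ['BertForMaskedLM', 'BertForSequenceClassification', 'BertModel']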
    """
    Retrieves all the information from a given model_type.

    Args:
        model_type (`str`): A valid model type (like "bert" or "gpt2")
        frameworks (`List[str]`, *optional*):
            If passed, will only keep the info corresponding to the passed frameworks.

    Returns:
        `Dict`: A dictionary with the following keys:
        - **frameworks** (`List[str]`): The list of frameworks that back this model type.
        - **model_classes** (`Dict[str, List[str]]`): The model classes implemented for that model type.
        - **model_files** (`Dict[str, Union[Path, List[Path]]]`): The files associated with that model type.
        - **model_patterns** (`ModelPatterns`): The various patterns for the model.
    """
    # Check if the provided model_type exists in the MODEL_NAMES_MAPPING
    if model_type not in auto_module.MODEL_NAMES_MAPPING:
        raise ValueError(f"{model_type} is not a valid model type.")

    # Retrieve the actual model name from the mapping
    model_name = auto_module.MODEL_NAMES_MAPPING[model_type]

    # Retrieve the configuration class name for the given model type
    config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type]

    # Retrieve the archive map if available for the given model type
    archive_map = auto_module.configuration_auto.CONFIG_ARCHIVE_MAP_MAPPING_NAMES.get(model_type, None)

    # Retrieve the tokenizer classes if available for the given model type
    if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
        tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type]
        tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
    else:
        tokenizer_class = None

    # Retrieve the image processor class if available for the given model type
    image_processor_class = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)

    # Retrieve the feature extractor class if available for the given model type
    feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)

    # Retrieve the processor class if available for the given model type
    processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)

    # Retrieve the files associated with the given model type
    model_files = get_model_files(model_type, frameworks=frameworks)

    # Create a camel-cased version of the config class name without "Config"
    model_camel_cased = config_class.replace("Config", "")

    # Initialize an empty list to store available frameworks
    available_frameworks = []

    # Iterate through the model files and identify the frameworks they belong to
    for fname in model_files["model_files"]:
        if "modeling_tf" in str(fname):
            available_frameworks.append("tf")
        elif "modeling_flax" in str(fname):
            available_frameworks.append("flax")
        elif "modeling" in str(fname):
            available_frameworks.append("pt")

    # If frameworks parameter is None, retrieve default frameworks
    if frameworks is None:
        frameworks = get_default_frameworks()

    # Filter frameworks to include only those available in the model files
    frameworks = [f for f in frameworks if f in available_frameworks]

    # Retrieve model classes based on the model type and selected frameworks
    model_classes = retrieve_model_classes(model_type, frameworks=frameworks)

    # Retrieve model upper-cased name from the constant name of the pretrained archive map, if available
    if archive_map is None:
        model_upper_cased = model_camel_cased.upper()
    # Otherwise, derive the upper-cased name from the name of the pretrained archive map constant
    else:
        # Split the constant name on underscores and initialize the index
        parts = archive_map.split("_")
        idx = 0
        # Advance until "PRETRAINED" is found or the end of the parts is reached
        while idx < len(parts) and parts[idx] != "PRETRAINED":
            idx += 1
        # If "PRETRAINED" was found
        if idx < len(parts):
            # Join everything before "PRETRAINED" to obtain model_upper_cased
            model_upper_cased = "_".join(parts[:idx])
        else:
            # If "PRETRAINED" was not found, fall back to the upper-cased camel-cased name
            model_upper_cased = model_camel_cased.upper()

    # Build a ModelPatterns object holding the naming patterns and classes for the model
    model_patterns = ModelPatterns(
        model_name,
        # Find the base checkpoint of the model from the docstrings
        checkpoint=find_base_model_checkpoint(model_type, model_files=model_files),
        model_type=model_type,
        model_camel_cased=model_camel_cased,
        model_lower_cased=model_files["module_name"],
        model_upper_cased=model_upper_cased,
        config_class=config_class,
        tokenizer_class=tokenizer_class,
        image_processor_class=image_processor_class,
        feature_extractor_class=feature_extractor_class,
        processor_class=processor_class,
    )

    # Return a dictionary with all the information gathered for this model type
    return {
        "frameworks": frameworks,
        "model_classes": model_classes,
        "model_files": model_files,
        "model_patterns": model_patterns,
    }
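
Putting it together, a hedged sketch of what the returned dictionary looks like for an existing model (all printed values are indicative only):

info = retrieve_info_for_model("bert", frameworks=["pt"])
print(info["frameworks"])                         # ['pt']
print(info["model_patterns"].model_camel_cased)   # Bert
print(info["model_patterns"].config_class)        # BertConfig
print(info["model_files"]["module_name"])         # bert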


# Clean up a model subpackage __init__.py so it only references the kept frameworks and, optionally,
# drops the tokenizer/processor imports (signature inferred from the call in create_new_model_like below)
def clean_frameworks_in_init(init_file, frameworks: Optional[List[str]] = None, keep_processing: bool = True):
    # Open the __init__.py file to process
    with open(init_file, "r", encoding="utf-8") as f:
        # 读取整个文件内容
        content = f.read()

    # 将文件内容按行分割成列表
    lines = content.split("\n")
    # 初始化一个空列表,用于存储处理后的新行
    new_lines = []
    # 初始化索引变量,用于迭代处理每一行
    idx = 0
    # 循环处理每一行代码,直到处理完所有行
    while idx < len(lines):
        # 在 try-except-else 块中处理条件导入
        if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
            # 移除前面的 `try:` 语句
            new_lines.pop()
            idx += 1
            # 找到下一个 `else:` 之前的空行或者非空行
            while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
                idx += 1
            idx += 1
            # 确定缩进级别
            indent = find_indent(lines[idx])
            # 继续处理直到缩进小于当前缩进或者是空行
            while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
                idx += 1
        # 移除来自 utils 的导入
        elif re_is_xxx_available.search(lines[idx]) is not None:
            line = lines[idx]
            # 替换需要移除的 framework 导入
            for framework in to_remove:
                line = line.replace(f", is_{framework}_available", "")
                line = line.replace(f"is_{framework}_available, ", "")
                line = line.replace(f"is_{framework}_available,", "")
                line = line.replace(f"is_{framework}_available", "")

            # 如果替换后的行不为空,则添加到新行列表中
            if len(line.strip()) > 0:
                new_lines.append(line)
            idx += 1
        # 否则保留该行,除非是关于 tokenizer 导入且不需要保留的情况
        elif keep_processing or (
            re.search(r'^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
            and re.search(r"^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
            is None
        ):
            new_lines.append(lines[idx])
            idx += 1
        else:
            idx += 1

    # 将处理后的新行写入到指定的初始化文件中
    with open(init_file, "w", encoding="utf-8") as f:
        f.write("\n".join(new_lines))

# Add the new model to the main __init__.py of the Transformers library
# (signature inferred from the call in create_new_model_like below)
def add_model_to_main_init(
    old_model_patterns: ModelPatterns,
    new_model_patterns: ModelPatterns,
    frameworks: Optional[List[str]] = None,
    with_processing: bool = True,
):
    # Open the main __init__.py of the Transformers library for reading, using UTF-8 encoding
    with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
        # Read the whole file content into `content`
        content = f.read()

    # Split the file content into a list of lines
    lines = content.split("\n")
    # Index used to iterate over the lines
    idx = 0
    # List that will hold the processed lines
    new_lines = []
    # Framework of the block currently being processed (None outside framework-specific blocks)
    framework = None
    # 当前行号小于文本行数时,继续循环处理文本行
    while idx < len(lines):
        # 新的框架标志置为 False
        new_framework = False
        # 如果当前行不是空行且缩进为0,则将框架设置为 None
        if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
            framework = None
        # 如果当前行左侧去除空白后以特定字符串开头,则确定框架类型为 "pt",并设置新框架标志为 True
        elif lines[idx].lstrip().startswith("if not is_torch_available"):
            framework = "pt"
            new_framework = True
        # 如果当前行左侧去除空白后以特定字符串开头,则确定框架类型为 "tf",并设置新框架标志为 True
        elif lines[idx].lstrip().startswith("if not is_tf_available"):
            framework = "tf"
            new_framework = True
        # 如果当前行左侧去除空白后以特定字符串开头,则确定框架类型为 "flax",并设置新框架标志为 True
        elif lines[idx].lstrip().startswith("if not is_flax_available"):
            framework = "flax"
            new_framework = True
    
        # 如果是新框架,则需要跳过直到 else: 块以找到导入位置
        if new_framework:
            while lines[idx].strip() != "else:":
                new_lines.append(lines[idx])
                idx += 1
    
        # 如果框架不是所需的框架且框架列表不为空且当前框架不在列表中,则跳过当前行
        if framework is not None and frameworks is not None and framework not in frameworks:
            new_lines.append(lines[idx])
            idx += 1
        # 如果当前行包含旧模型模式的模型引用,则收集整个代码块
        elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
            block = [lines[idx]]
            indent = find_indent(lines[idx])
            idx += 1
            # 收集整个缩进块
            while find_indent(lines[idx]) > indent:
                block.append(lines[idx])
                idx += 1
            # 如果当前行的内容是特定列表中的一员,则也添加到块中
            if lines[idx].strip() in [")", "]", "],"]:
                block.append(lines[idx])
                idx += 1
            block = "\n".join(block)
            new_lines.append(block)
    
            add_block = True
            # 如果不需要处理,则只保留非空的处理类
            if not with_processing:
                processing_classes = [
                    old_model_patterns.tokenizer_class,
                    old_model_patterns.image_processor_class,
                    old_model_patterns.feature_extractor_class,
                    old_model_patterns.processor_class,
                ]
                processing_classes = [c for c in processing_classes if c is not None]
                # 遍历处理类列表,将其从块中移除
                for processing_class in processing_classes:
                    block = block.replace(f' "{processing_class}",', "")
                    block = block.replace(f', "{processing_class}"', "")
                    block = block.replace(f" {processing_class},", "")
                    block = block.replace(f", {processing_class}", "")
                    # 如果块中仍包含处理类,则不添加此块
                    if processing_class in block:
                        add_block = False
            # 如果需要添加块,则将替换后的模型模式块添加到新行列表中
            if add_block:
                new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
        else:
            # 否则,将当前行直接添加到新行列表中
            new_lines.append(lines[idx])
            idx += 1
    
    # 将新行列表写入到 "__init__.py" 文件中
    with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
        f.write("\n".join(new_lines))
# Add the tokenizer of the new model to the relevant mappings of the auto module
def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
    """
    Add a tokenizer to the relevant mappings in the auto module.

    Args:
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.
    """
    # Nothing to do if either the old or the new model has no tokenizer class
    if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
        return

    # 打开自动模块中的tokenization_auto.py文件,以utf-8编码读取其内容
    with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
        content = f.read()

    # 将文件内容按行分割为列表
    lines = content.split("\n")
    idx = 0
    # 首先定位到TOKENIZER_MAPPING_NAMES块
    while not lines[idx].startswith("    TOKENIZER_MAPPING_NAMES = OrderedDict("):
        idx += 1
    idx += 1

    # 定位到TOKENIZER_MAPPING块的结尾
    while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
        # 如果tokenizer块在一行上定义,则以"),结束"
        if lines[idx].endswith(","):
            block = lines[idx]
        # 否则,tokenizer块跨多行,直到找到"),结束"
        else:
            block = []
            while not lines[idx].startswith("            ),"):
                block.append(lines[idx])
                idx += 1
            block = "\n".join(block)
        idx += 1

        # 如果在该块中找到了旧模型类型和标记器类,则找到了旧模型的tokenizer块
        if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
            break

    # 将旧模型类型和标记器类替换为新模型类型和标记器类
    new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
    new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)

    # 构建新的文件内容行列表,包括更新后的tokenizer块
    new_lines = lines[:idx] + [new_block] + lines[idx:]

    # 将更新后的文件内容写回tokenization_auto.py文件
    with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
        f.write("\n".join(new_lines))


AUTO_CLASSES_PATTERNS = {
    "configuration_auto.py": [
        '        ("{model_type}", "{model_name}"),',
        '        ("{model_type}", "{config_class}"),',
        '        ("{model_type}", "{pretrained_archive_map}"),',
    ],
    "feature_extraction_auto.py": ['        ("{model_type}", "{feature_extractor_class}"),'],
    "image_processing_auto.py": ['        ("{model_type}", "{image_processor_class}"),'],
    "modeling_auto.py": ['        ("{model_type}", "{any_pt_class}"),'],
    "modeling_tf_auto.py": ['        ("{model_type}", "{any_tf_class}"),'],
    "modeling_flax_auto.py": ['        ("{model_type}", "{any_flax_class}"),'],
    "processing_auto.py": ['        ("{model_type}", "{processor_class}"),'],
}
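
Each entry above is a template line for one auto-mapping file; formatting one of them shows the exact line that would be inserted (the model and class names here are illustrative):

pattern = AUTO_CLASSES_PATTERNS["configuration_auto.py"][1]
print(pattern.format(model_type="new-bert", config_class="NewBertConfig"))
# Output:         ("new-bert", "NewBertConfig"),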


def add_model_to_auto_classes(
    old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: Dict[str, List[str]]
):
    """
    Add a model to the relevant mappings in the auto module.
    
    Args:
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.
        model_classes (`Dict[str, List[str]]`): A dictionary mapping auto module filenames to lists of model class names.
    """
    Args:
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.
        model_classes (`Dict[str, List[str]]`): A dictionary framework to list of model classes implemented.
    """
    # 调用函数将旧模型模式中的所有分词器插入到新模型模式的自动模块中
    insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
# Template used to generate the overview section of the new model's documentation page
DOC_OVERVIEW_TEMPLATE = """## Overview

The {model_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).

"""


def duplicate_doc_file(
    doc_file: Union[str, os.PathLike],
    old_model_patterns: ModelPatterns,
    new_model_patterns: ModelPatterns,
    dest_file: Optional[Union[str, os.PathLike]] = None,
    frameworks: Optional[List[str]] = None,
):
    """
    Duplicate a documentation file and adapt it for a new model.

    Args:
        doc_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.
        dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
            Will default to a file named `{new_model_patterns.model_type}.md` in the same folder as `doc_file`.
        frameworks (`List[str]`, *optional*):
            If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
    """
    # 读取原始文档文件内容
    with open(doc_file, "r", encoding="utf-8") as f:
        content = f.read()

    # 更新版权信息为当前年份
    content = re.sub(r"<!--\s*Copyright (\d+)\s", f"<!--Copyright {CURRENT_YEAR} ", content)
    
    # 如果未提供特定框架列表,则使用默认框架列表
    if frameworks is None:
        frameworks = get_default_frameworks()
    
    # 如果未提供目标文件路径,则默认为与原文档文件同目录下,新模型类型命名的文件
    if dest_file is None:
        dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.md"

    # 解析文档内容为块。每个块对应一个部分/标题
    lines = content.split("\n")
    blocks = []
    current_block = []

    for line in lines:
        if line.startswith("#"):
            blocks.append("\n".join(current_block))
            current_block = [line]
        else:
            current_block.append(line)
    blocks.append("\n".join(current_block))

    new_blocks = []
    in_classes = False
    # 遍历输入的文本块列表
    for block in blocks:
        # 检查是否以版权声明开始,如果不是则添加到新的文本块列表中
        if not block.startswith("#"):
            new_blocks.append(block)
        # 检查是否为主标题,如果是则替换为新模型名称的标题
        elif re.search(r"^#\s+\S+", block) is not None:
            new_blocks.append(f"# {new_model_patterns.model_name}\n")
        # 检查是否进入类定义部分,根据旧模型配置类来确定
        elif not in_classes and old_model_patterns.config_class in block.split("\n")[0]:
            # 标记已进入类定义部分,并添加文档概述模板及替换后的模型配置块
            in_classes = True
            new_blocks.append(DOC_OVERVIEW_TEMPLATE.format(model_name=new_model_patterns.model_name))
            new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
            new_blocks.append(new_block)
        # 处理在类定义部分的情况
        elif in_classes:
            in_classes = True
            # 获取当前文本块的标题,并提取类名
            block_title = block.split("\n")[0]
            block_class = re.search(r"^#+\s+(\S.*)$", block_title).groups()[0]
            new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)

            # 根据类名条件性地添加新的文本块
            if "Tokenizer" in block_class:
                # 仅在需要时添加标记器类
                if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
                    new_blocks.append(new_block)
            elif "ImageProcessor" in block_class:
                # 仅在需要时添加图像处理器类
                if old_model_patterns.image_processor_class != new_model_patterns.image_processor_class:
                    new_blocks.append(new_block)
            elif "FeatureExtractor" in block_class:
                # 仅在需要时添加特征提取器类
                if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
                    new_blocks.append(new_block)
            elif "Processor" in block_class:
                # 仅在需要时添加处理器类
                if old_model_patterns.processor_class != new_model_patterns.processor_class:
                    new_blocks.append(new_block)
            elif block_class.startswith("Flax"):
                # 仅在所选框架中包含 Flax 模型时添加
                if "flax" in frameworks:
                    new_blocks.append(new_block)
            elif block_class.startswith("TF"):
                # 仅在所选框架中包含 TF 模型时添加
                if "tf" in frameworks:
                    new_blocks.append(new_block)
            elif len(block_class.split(" ")) == 1:
                # 仅在所选框架中包含 PyTorch 模型时添加
                if "pt" in frameworks:
                    new_blocks.append(new_block)
            else:
                new_blocks.append(new_block)

    # 将新的文本块列表写入目标文件
    with open(dest_file, "w", encoding="utf-8") as f:
        f.write("\n".join(new_blocks))
# Insert an entry for the new model in the documentation table of contents, in the same section as the old model
def insert_model_in_doc_toc(old_model_patterns, new_model_patterns):
    """
    Insert the new model in the doc TOC, in the same section as the old model.

    Args:
        old_model_patterns (`ModelPatterns`): The patterns for the old model.
        new_model_patterns (`ModelPatterns`): The patterns for the new model.
    """
    # 指定文档目录文件路径
    toc_file = REPO_PATH / "docs" / "source" / "en" / "_toctree.yml"
    # 打开并加载 YAML 格式的目录文件内容
    with open(toc_file, "r", encoding="utf8") as f:
        content = yaml.safe_load(f)

    # 定位到 API 文档的索引
    api_idx = 0
    while content[api_idx]["title"] != "API":
        api_idx += 1
    # 获取 API 文档下的各个部分
    api_doc = content[api_idx]["sections"]

    # 定位到 Models 部分的索引
    model_idx = 0
    while api_doc[model_idx]["title"] != "Models":
        model_idx += 1
    # 获取 Models 部分下的各个小节
    model_doc = api_doc[model_idx]["sections"]

    # 在目录中查找基础模型的位置
    old_model_type = old_model_patterns.model_type
    section_idx = 0
    while section_idx < len(model_doc):
        # 获取当前小节中的本地目录项列表
        sections = [entry["local"] for entry in model_doc[section_idx]["sections"]]
        # 如果旧模型的目录项在当前小节中,则跳出循环
        if f"model_doc/{old_model_type}" in sections:
            break
        section_idx += 1

    # 如果未找到旧模型的目录项,则输出警告信息并返回
    if section_idx == len(model_doc):
        old_model = old_model_patterns.model_name
        new_model = new_model_patterns.model_name
        print(f"Did not find {old_model} in the table of content, so you will need to add {new_model} manually.")
        return

    # 准备新模型的目录项信息
    toc_entry = {"local": f"model_doc/{new_model_patterns.model_type}", "title": new_model_patterns.model_name}
    # 将新模型的目录项添加到找到的旧模型所在的小节中
    model_doc[section_idx]["sections"].append(toc_entry)
    # 根据标题排序小节中的目录项
    model_doc[section_idx]["sections"] = sorted(model_doc[section_idx]["sections"], key=lambda s: s["title"].lower())
    # 更新 API 文档中的 Models 部分
    api_doc[model_idx]["sections"] = model_doc
    # 更新整体内容中的 API 文档
    content[api_idx]["sections"] = api_doc

    # 将更新后的内容重新写入目录文件
    with open(toc_file, "w", encoding="utf-8") as f:
        f.write(yaml.dump(content, allow_unicode=True))
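
For orientation, here is a minimal sketch of the `_toctree.yml` structure the function above navigates, expressed as the Python object `yaml.safe_load` would return; the titles and entries are illustrative:

content = [
    {"title": "Get started", "sections": [{"local": "index", "title": "Transformers"}]},
    {
        "title": "API",
        "sections": [
            {
                "title": "Models",
                "sections": [
                    {
                        "title": "Text models",
                        "sections": [{"local": "model_doc/bert", "title": "BERT"}],
                    },
                ],
            },
        ],
    },
]
# The new entry is appended to the section that already contains the old model,
# then that section is re-sorted by title.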


# Create a new model module like an existing one, duplicating and adapting all the relevant files
# (signature inferred from the call in AddNewModelLikeCommand.run below)
def create_new_model_like(
    model_type: str,
    new_model_patterns: ModelPatterns,
    add_copied_from: bool = True,
    frameworks: Optional[List[str]] = None,
    old_checkpoint: Optional[str] = None,
):
    # Retrieve all the information for the given model type: model files, patterns, classes, ...
    model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
    
    # 从模型信息中获取模型文件列表和旧模型模式
    model_files = model_info["model_files"]
    old_model_patterns = model_info["model_patterns"]
    
    # 如果有提供旧的检查点,则更新旧模型模式的检查点属性
    if old_checkpoint is not None:
        old_model_patterns.checkpoint = old_checkpoint
    
    # 检查旧模型模式的检查点属性是否为空,如果是则引发 ValueError
    if len(old_model_patterns.checkpoint) == 0:
        raise ValueError(
            "The old model checkpoint could not be recovered from the model type. Please pass it to the "
            "`old_checkpoint` argument."
        )
    
    # 初始化保持旧处理方式的标志为 True
    keep_old_processing = True
    
    # 检查特定处理属性(如图像处理类、特征提取器类、处理器类、分词器类)是否与新模型模式相同,若有不同则将标志设为 False
    for processing_attr in ["image_processor_class", "feature_extractor_class", "processor_class", "tokenizer_class"]:
        if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
            keep_old_processing = False
    
    # 从模型信息中获取模型类别
    model_classes = model_info["model_classes"]
    
    # 1. 创建新模型的模块
    old_module_name = model_files["module_name"]
    module_folder = TRANSFORMERS_PATH / "models" / new_model_patterns.model_lower_cased
    
    # 确保模块文件夹存在,如果不存在则创建
    os.makedirs(module_folder, exist_ok=True)
    
    # 根据保持旧处理方式的标志筛选要适应的文件列表
    files_to_adapt = model_files["model_files"]
    if keep_old_processing:
        files_to_adapt = [
            f
            for f in files_to_adapt
            if "tokenization" not in str(f)
            and "processing" not in str(f)
            and "feature_extraction" not in str(f)
            and "image_processing" not in str(f)
        ]
    
    # 再次确保模块文件夹存在,如果不存在则创建
    os.makedirs(module_folder, exist_ok=True)
    
    # 遍历要适应的文件列表,生成新的模块文件名并复制到目标位置
    for module_file in files_to_adapt:
        new_module_name = module_file.name.replace(
            old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
        )
        dest_file = module_folder / new_module_name
        duplicate_module(
            module_file,
            old_model_patterns,
            new_model_patterns,
            dest_file=dest_file,
            add_copied_from=add_copied_from and "modeling" in new_module_name,
        )
    
    # 清理模块的初始化文件,根据保持旧处理方式的标志更新处理类别
    clean_frameworks_in_init(
        module_folder / "__init__.py", frameworks=frameworks, keep_processing=not keep_old_processing
    )
    
    # 2. 将新模型添加到模型包的初始化文件和主初始化文件中
    add_content_to_file(
        TRANSFORMERS_PATH / "models" / "__init__.py",
        f"    {new_model_patterns.model_lower_cased},",
        add_after=f"    {old_module_name},",
        exact_match=True,
    )
    add_model_to_main_init(
        old_model_patterns, new_model_patterns, frameworks=frameworks, with_processing=not keep_old_processing
    )
    
    # 3. 添加测试文件
    files_to_adapt = model_files["test_files"]
    if keep_old_processing:
        files_to_adapt = [
            f
            for f in files_to_adapt
            if "tokenization" not in str(f)
            and "processor" not in str(f)
            and "feature_extraction" not in str(f)
            and "image_processing" not in str(f)
        ]
    # 定义一个函数,用于禁用与指定文件相关的特定功能测试
    def disable_fx_test(filename: Path) -> bool:
        # 打开文件并读取其内容
        with open(filename) as fp:
            content = fp.read()
        # 使用正则表达式替换文件内容中的特定文本
        new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
        # 将修改后的内容写回到文件中
        with open(filename, "w") as fp:
            fp.write(new_content)
        # 返回值指示是否有内容被修改过
        return content != new_content

    # 初始化一个标志,用于追踪是否禁用了功能测试
    disabled_fx_test = False

    # 创建测试文件夹,如果不存在则创建
    tests_folder = REPO_PATH / "tests" / "models" / new_model_patterns.model_lower_cased
    os.makedirs(tests_folder, exist_ok=True)

    # 创建一个空的 __init__.py 文件
    with open(tests_folder / "__init__.py", "w"):
        pass

    # 遍历需要调整的文件列表
    for test_file in files_to_adapt:
        # 根据模式替换文件名中的旧模型名称为新模型名称
        new_test_file_name = test_file.name.replace(
            old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
        )
        # 构建目标文件的路径
        dest_file = test_file.parent.parent / new_model_patterns.model_lower_cased / new_test_file_name
        # 复制指定的测试文件到目标位置,并禁用功能测试
        duplicate_module(
            test_file,
            old_model_patterns,
            new_model_patterns,
            dest_file=dest_file,
            add_copied_from=False,
            attrs_to_remove=["pipeline_model_mapping", "is_pipeline_test_to_skip"],
        )
        # 更新功能测试禁用状态
        disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)

    # 如果有功能测试被禁用,则输出提示信息
    if disabled_fx_test:
        print(
            "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works"
            " for your new model."
        )

    # 将新模型添加到自动类中
    add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)

    # 添加文档文件
    doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{old_model_patterns.model_type}.md"
    duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
    # 在文档目录中插入新模型
    insert_model_in_doc_toc(old_model_patterns, new_model_patterns)

    # 如果旧模型类型与其检查点名称相同,输出警告信息
    if old_model_patterns.model_type == old_model_patterns.checkpoint:
        print(
            "The model you picked has the same name for the model type and the checkpoint name "
            f"({old_model_patterns.model_type}). As a result, it's possible some places where the new checkpoint "
            f"should be, you have {new_model_patterns.model_type} instead. You should search for all instances of "
            f"{new_model_patterns.model_type} in the new files and check they're not badly used as checkpoints."
        )
    # 如果旧模型名称(小写形式)与其检查点名称相同,输出警告信息
    elif old_model_patterns.model_lower_cased == old_model_patterns.checkpoint:
        print(
            "The model you picked has the same name for the model type and the checkpoint name "
            f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
            f"checkpoint should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
            f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
            "used as checkpoints."
        )
    # 检查旧模型模式的类型是否为小写,并且新模型模式的类型不是小写时
    if (
        old_model_patterns.model_type == old_model_patterns.model_lower_cased
        and new_model_patterns.model_type != new_model_patterns.model_lower_cased
    ):
        # 输出警告信息,说明选择的模型类型和小写模型名称相同,可能导致新模型类型在某些地方被误用为小写模型名称
        print(
            "The model you picked has the same name for the model type and the lowercased model name "
            f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
            f"model type should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
            f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
            "used as the model type."
        )

    # 如果不保留旧的处理逻辑并且旧模型模式的分词器类不为空时
    if not keep_old_processing and old_model_patterns.tokenizer_class is not None:
        # 输出提示信息,指出需要手动修复新分词器文件开头的常量问题。如果新模型有一个快速分词器,还需手动将转换器添加到 `convert_slow_tokenizer.py` 的 `SLOW_TO_FAST_CONVERTERS` 常量中
        print(
            "The constants at the start of the new tokenizer file created needs to be manually fixed. If your new "
            "model has a tokenizer fast, you will also need to manually add the converter in the "
            "`SLOW_TO_FAST_CONVERTERS` constant of `convert_slow_tokenizer.py`."
        )
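
A hedged usage sketch of calling this function directly; every name below is illustrative, and since the call rewrites files under REPO_PATH it should only be run against a scratch checkout:

new_patterns = ModelPatterns("NewBERT", "my-org/new-bert-base", model_type="new-bert")

create_new_model_like(
    model_type="bert",
    new_model_patterns=new_patterns,
    add_copied_from=True,
    frameworks=["pt"],
    old_checkpoint=None,
)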
def add_new_model_like_command_factory(args: Namespace):
    # Build and return an AddNewModelLikeCommand from the config file and repo path given in the CLI arguments
    return AddNewModelLikeCommand(config_file=args.config_file, path_to_repo=args.path_to_repo)


class AddNewModelLikeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # 注册子命令 "add-new-model-like" 到指定的 ArgumentParser 对象
        add_new_model_like_parser = parser.add_parser("add-new-model-like")
        add_new_model_like_parser.add_argument(
            "--config_file", type=str, help="A file with all the information for this model creation."
        )
        add_new_model_like_parser.add_argument(
            "--path_to_repo", type=str, help="When not using an editable install, the path to the Transformers repo."
        )
        # 设置默认的函数处理程序为 add_new_model_like_command_factory 函数
        add_new_model_like_parser.set_defaults(func=add_new_model_like_command_factory)

    def __init__(self, config_file=None, path_to_repo=None, *args):
        if config_file is not None:
            # 如果配置文件不为 None,从配置文件中加载配置信息
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
            # 初始化对象的各个属性
            self.old_model_type = config["old_model_type"]
            self.model_patterns = ModelPatterns(**config["new_model_patterns"])
            self.add_copied_from = config.get("add_copied_from", True)
            self.frameworks = config.get("frameworks", get_default_frameworks())
            self.old_checkpoint = config.get("old_checkpoint", None)
        else:
            # 如果配置文件为 None,调用 get_user_input() 函数获取用户输入的属性值
            (
                self.old_model_type,
                self.model_patterns,
                self.add_copied_from,
                self.frameworks,
                self.old_checkpoint,
            ) = get_user_input()

        self.path_to_repo = path_to_repo

    def run(self):
        if self.path_to_repo is not None:
            # 如果仓库路径不为 None,则设定全局变量 TRANSFORMERS_PATH 和 REPO_PATH
            global TRANSFORMERS_PATH
            global REPO_PATH

            REPO_PATH = Path(self.path_to_repo)
            TRANSFORMERS_PATH = REPO_PATH / "src" / "transformers"

        # 调用 create_new_model_like 函数创建新模型
        create_new_model_like(
            model_type=self.old_model_type,
            new_model_patterns=self.model_patterns,
            add_copied_from=self.add_copied_from,
            frameworks=self.frameworks,
            old_checkpoint=self.old_checkpoint,
        )
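
Based on the keys read in `__init__` above, a minimal config file for `--config_file` could look like the following; all values are placeholders, and `new_model_patterns` is assumed to contain valid keyword arguments for `ModelPatterns` (the field names follow the constructor calls seen earlier in this file):

import json

config = {
    "old_model_type": "bert",
    "new_model_patterns": {
        "model_name": "NewBERT",
        "checkpoint": "my-org/new-bert-base",
    },
    "add_copied_from": True,
    "frameworks": ["pt"],
    "old_checkpoint": None,
}
with open("new_bert_config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

# Then, assuming the standard `transformers-cli` entry point:
#     transformers-cli add-new-model-like --config_file new_bert_config.json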


def get_user_field(
    question: str,
    default_value: Optional[str] = None,
    is_valid_answer: Optional[Callable] = None,
    convert_to: Optional[Callable] = None,
    fallback_message: Optional[str] = None,
) -> Any:
    """
    A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
    answer.
    """
    # 简单的用户输入获取函数,带有一些可选的参数和验证功能
    # 如果问题字符串不以空格结尾,添加一个空格
    if not question.endswith(" "):
        question = question + " "
    # 如果提供了默认值,将默认值添加到问题的末尾
    if default_value is not None:
        question = f"{question} [{default_value}] "

    # 初始化有效答案为 False,用于循环直到得到有效答案
    valid_answer = False
    while not valid_answer:
        # 提示用户输入问题,并获取用户输入的答案
        answer = input(question)

        # 如果提供了默认值且用户未输入任何内容,则使用默认值
        if default_value is not None and len(answer) == 0:
            answer = default_value

        # 如果提供了自定义的答案验证函数 is_valid_answer
        if is_valid_answer is not None:
            valid_answer = is_valid_answer(answer)
        # 如果提供了转换函数 convert_to
        elif convert_to is not None:
            try:
                # 尝试将答案转换为指定类型
                answer = convert_to(answer)
                valid_answer = True
            except Exception:
                # 如果转换失败,则标记答案为无效,继续循环
                valid_answer = False
        else:
            # 如果没有提供 is_valid_answer 或 convert_to,直接标记答案为有效
            valid_answer = True

        # 如果答案无效,则打印回退消息
        if not valid_answer:
            print(fallback_message)

    # 返回经过验证和可能转换的答案
    return answer
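
For example (the answer is read from standard input, and the question text is illustrative):

# Keeps asking until the answer can be parsed as an integer
num_layers = get_user_field(
    "How many hidden layers should the model have?",
    default_value="12",
    convert_to=int,
    fallback_message="Please enter an integer.",
)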
# Convert a string to a boolean
def convert_to_bool(x: str) -> bool:
    """
    Converts a string to a bool.
    """
    # Accepted truthy values
    if x.lower() in ["1", "y", "yes", "true"]:
        return True
    # Accepted falsy values
    if x.lower() in ["0", "n", "no", "false"]:
        return False
    # Anything else cannot be converted to a bool
    raise ValueError(f"{x} is not a value that can be converted to a bool.")


# Ask the user for all the inputs needed to add the new model
def get_user_input():
    """
    Ask the user for the necessary inputs to add the new model.
    """
    # 获取模型类型列表
    model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())

    # 获取旧模型类型
    valid_model_type = False
    while not valid_model_type:
        # 提示用户输入要复制的模型类型
        old_model_type = input(
            "What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
        )
        # 检查用户输入是否在模型类型列表中
        if old_model_type in model_types:
            valid_model_type = True
        else:
            # 如果输入不在列表中,提示用户并尝试提供建议
            print(f"{old_model_type} is not a valid model type.")
            near_choices = difflib.get_close_matches(old_model_type, model_types)
            if len(near_choices) >= 1:
                if len(near_choices) > 1:
                    near_choices = " or ".join(near_choices)
                print(f"Did you mean {near_choices}?")

    # 获取旧模型的详细信息
    old_model_info = retrieve_info_for_model(old_model_type)
    old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
    old_image_processor_class = old_model_info["model_patterns"].image_processor_class
    old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
    old_processor_class = old_model_info["model_patterns"].processor_class
    old_frameworks = old_model_info["frameworks"]

    # 如果旧模型没有检查点信息,要求用户输入基础检查点的名称
    old_checkpoint = None
    if len(old_model_info["model_patterns"].checkpoint) == 0:
        old_checkpoint = get_user_field(
            "We couldn't find the name of the base checkpoint for that model, please enter it here."
        )

    # 获取新模型的名称
    model_name = get_user_field(
        "What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
    )
    # 创建默认模型模式对象
    default_patterns = ModelPatterns(model_name, model_name)

    # 获取用户输入的模型标识符
    model_type = get_user_field(
        "What identifier would you like to use for the `model_type` of this model? ",
        default_value=default_patterns.model_type,
    )
    # 获取用户输入的模型模块名(小写)
    model_lower_cased = get_user_field(
        "What lowercase name would you like to use for the module (folder) of this model? ",
        default_value=default_patterns.model_lower_cased,
    )
    # 获取用户输入的模型类的前缀(驼峰命名)
    model_camel_cased = get_user_field(
        "What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
        default_value=default_patterns.model_camel_cased,
    )
    # 获取用户输入的模型常量的前缀(大写)
    model_upper_cased = get_user_field(
        "What prefix (upper-cased) would you like to use for the constants relative to this model? ",
        default_value=default_patterns.model_upper_cased,
    )
    # 获取用户输入的配置类名称
    config_class = get_user_field(
        "What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
    )
    # 调用 get_user_field 函数获取用户输入,用于指定新模型的检查点标识符
    checkpoint = get_user_field(
        "Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/FacebookAI/roberta-base): "
    )

    # 创建旧处理类列表,仅包含非空元素
    old_processing_classes = [
        c
        for c in [old_image_processor_class, old_feature_extractor_class, old_tokenizer_class, old_processor_class]
        if c is not None
    ]
    # 将列表转换为逗号分隔的字符串
    old_processing_classes = ", ".join(old_processing_classes)
    # 获取用户输入,确认新模型是否使用与旧模型相同的处理类
    keep_processing = get_user_field(
        f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
        convert_to=convert_to_bool,
        fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
    )
    # 根据用户的选择,确定新模型的处理类
    if keep_processing:
        image_processor_class = old_image_processor_class
        feature_extractor_class = old_feature_extractor_class
        processor_class = old_processor_class
        tokenizer_class = old_tokenizer_class
    else:
        # 如果不使用与旧模型相同的处理类,则根据需要获取各种处理类的新名称
        if old_tokenizer_class is not None:
            tokenizer_class = get_user_field(
                "What will be the name of the tokenizer class for this model? ",
                default_value=f"{model_camel_cased}Tokenizer",
            )
        else:
            tokenizer_class = None
        if old_image_processor_class is not None:
            image_processor_class = get_user_field(
                "What will be the name of the image processor class for this model? ",
                default_value=f"{model_camel_cased}ImageProcessor",
            )
        else:
            image_processor_class = None
        if old_feature_extractor_class is not None:
            feature_extractor_class = get_user_field(
                "What will be the name of the feature extractor class for this model? ",
                default_value=f"{model_camel_cased}FeatureExtractor",
            )
        else:
            feature_extractor_class = None
        if old_processor_class is not None:
            processor_class = get_user_field(
                "What will be the name of the processor class for this model? ",
                default_value=f"{model_camel_cased}Processor",
            )
        else:
            processor_class = None

    # 创建 ModelPatterns 对象,用于保存新模型的相关属性
    model_patterns = ModelPatterns(
        model_name,
        checkpoint,
        model_type=model_type,
        model_lower_cased=model_lower_cased,
        model_camel_cased=model_camel_cased,
        model_upper_cased=model_upper_cased,
        config_class=config_class,
        tokenizer_class=tokenizer_class,
        image_processor_class=image_processor_class,
        feature_extractor_class=feature_extractor_class,
        processor_class=processor_class,
    )

    # 获取用户输入,确定在创建新建模型文件时是否添加 # Copied from 注释
    add_copied_from = get_user_field(
        "Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
        convert_to=convert_to_bool,
        default_value="yes",
        fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
    )
    # 调用函数获取用户字段,询问是否在所有旧模型类型的框架中添加新模型的版本
    # 用户字段包括确认消息、类型转换函数、默认值和回退消息
    all_frameworks = get_user_field(
        "Should we add a version of your new model in all the frameworks implemented by"
        f" {old_model_type} ({old_frameworks}) (yes/no)? ",
        convert_to=convert_to_bool,  # 将用户输入转换为布尔类型的函数
        default_value="yes",  # 默认值为 "yes"
        fallback_message="Please answer yes/no, y/n, true/false or 1/0.",  # 如果用户输入不合法时的提示消息
    )
    
    # 如果用户选择在所有框架中添加新模型版本
    if all_frameworks:
        frameworks = None  # 框架列表设为 None
    else:
        # 否则,获取用户字段,请求用户输入要使用的框架列表
        frameworks = get_user_field(
            "Please enter the list of framworks you want (pt, tf, flax) separated by spaces",
            # 检查用户输入是否有效,要求所有输入项必须是 ["pt", "tf", "flax"] 中的一种
            is_valid_answer=lambda x: all(p in ["pt", "tf", "flax"] for p in x.split(" ")),
        )
        frameworks = list(set(frameworks.split(" ")))  # 将输入的框架列表转换为集合去重后再转为列表
    
    # 返回元组包含旧模型类型、模型模式、复制来源、框架列表和旧的检查点
    return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)