Transformers-源码解析-一百三十九-

Transformers 源码解析(一百三十九)

.\utils\fx.py

# 导入 Python 内置模块和第三方库
import builtins
import collections
import functools
import inspect
import math
import operator
import os
import random
import warnings
# 导入类型提示相关的模块和类
from typing import Any, Callable, Dict, List, Optional, Type, Union

# 导入 PyTorch 库
import torch
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx.proxy import ParameterProxy

# 导入 Transformers 相关模块和类
from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_IMAGE_MAPPING_NAMES,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import (
    ENV_VARS_TRUE_VALUES,
    TORCH_FX_REQUIRED_VERSION,
    get_torch_version,
    is_peft_available,
    is_torch_fx_available,
)

# 如果 peft 可用,则导入 PeftModel 类
if is_peft_available():
    from peft import PeftModel

# 获取 logger 对象
logger = logging.get_logger(__name__)

# 检查是否在调试模式下
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES


def _generate_supported_model_class_names(
    model_name: Type[PretrainedConfig],
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    # 定义任务映射字典,将任务名称映射到模型名称字典
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
        "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
        "image-feature-extraction": MODEL_FOR_IMAGE_MAPPING_NAMES,
    }
    
    # 如果 supported_tasks 参数为 None,则使用所有任务的键作为支持的任务列表
    if supported_tasks is None:
        supported_tasks = task_mapping.keys()
    
    # 如果 supported_tasks 是字符串,则转换为包含该字符串的列表
    if isinstance(supported_tasks, str):
        supported_tasks = [supported_tasks]
    
    # 初始化空列表,用于存储模型类名称
    model_class_names = []
    
    # 遍历每个支持的任务
    for task in supported_tasks:
        # 获取任务对应的模型名称的类名,如果找不到则设为 None
        class_name = task_mapping[task].get(model_name, None)
        # 如果找到了类名,则将其添加到模型类名称列表中
        if class_name:
            model_class_names.append(class_name)
    
    # 返回所有找到的模型类名称列表
    return model_class_names
# 正常支持的模型名称和任务列表,用于模型选择和加载
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "altclip",              # 替代版本的CLIP模型
    "albert",               # ALBERT模型
    "bart",                 # BART模型
    "bert",                 # BERT模型
    "blenderbot",           # BlenderBot模型
    "blenderbot-small",     # 小型BlenderBot模型
    "bloom",                # Bloom模型
    "clip",                 # CLIP模型
    "convnext",             # ConvNext模型
    "deberta",              # DeBERTa模型
    "deberta-v2",           # DeBERTa-v2模型
    "dinov2",               # DINOv2模型
    "distilbert",           # DistilBERT模型
    "donut-swin",           # Donut-Swin模型
    "electra",              # Electra模型
    "gpt2",                 # GPT-2模型
    "gpt_neo",              # GPT-Neo模型
    "gptj",                 # GPT-J模型
    "hubert",               # Hubert模型
    "layoutlm",             # LayoutLM模型
    "llama",                # LLaMA模型
    "cohere",               # Cohere模型
    "lxmert",               # LXMERT模型
    "m2m_100",              # M2M-100模型
    "marian",               # Marian模型
    "mbart",                # mBART模型
    "megatron-bert",        # Megatron-BERT模型
    "mobilebert",           # MobileBERT模型
    "mt5",                  # MT5模型
    "nezha",                # NeZha模型
    "opt",                  # Opt模型
    "pegasus",              # Pegasus模型
    "plbart",               # PLBART模型
    "resnet",               # ResNet模型
    "roberta",              # RoBERTa模型
    "segformer",            # Segformer模型
    "speech_to_text",       # 语音转文本模型
    "speech_to_text_2",     # 语音转文本模型的另一版本
    "swin",                 # Swin模型
    "t5",                   # T5模型
    "trocr",                # TrOCR模型
    "vit",                  # ViT模型
    "xglm",                 # XGLM模型
    "wav2vec2",             # Wav2Vec 2.0模型
    #    "xlnet",             # 暂时未支持XLNet模型
]

# 支持KV缓存的特殊模型列表
_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]

# 初始化空的正常支持模型列表
_REGULAR_SUPPORTED_MODELS = []

# 遍历正常支持的模型名称和任务列表
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    # 如果列表项是字典,则生成支持的模型类名并扩展到正常支持模型列表
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))

# 特殊支持的模型列表,包含特定类名的模型
_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",                    # CLIP文本模型
    "CLIPTextModelWithProjection",      # 带投影的CLIP文本模型
    "CLIPVisionModel",                  # CLIP视觉模型
    "CLIPVisionModelWithProjection",    # 带投影的CLIP视觉模型
    "AltCLIPTextModel",                 # 替代版本的CLIP文本模型
    "AltCLIPVisionModel",               # 替代版本的CLIP视觉模型
    "GitVisionModel",                   # Git视觉模型
    "GPT2DoubleHeadsModel",             # GPT-2双头模型
    "Speech2Text2Decoder",              # 语音转文本2解码器
    "TrOCRDecoder",                     # TrOCR解码器
    "PeftModelForCausalLM",             # 用于因果语言建模的Peft模型
    "PeftModelForSeq2SeqLM",            # 用于序列到序列语言建模的Peft模型
    # TODO: 添加对它们的支持,这应该很容易做到(存在小的阻碍问题)。
    # XLNetForQuestionAnswering,       # 问答XLNet模型,暂未支持
]

# 所有支持的模型列表,由正常支持和特殊支持模型列表组成,按字母顺序排序并去重
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))


def torch_nn_embedding(self, input):
    # 创建一个与输入形状相同,但最后一个维度与权重张量相同的空张量,设备为"meta",类型与权重相同
    return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)


def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    # 创建一个与输入形状相同,但最后一个维度与权重张量相同的空张量,设备为"meta",类型与权重相同
    return torch.empty(*input.shape, weight.shape[-1], device="meta", dtype=weight.dtype)


def torch_nn_layernorm(self, input):
    # 返回输入本身,表示不应用 LayerNorm 操作
    return input


def torch_nn_groupnorm(self, input):
    # 返回输入本身,表示不应用 GroupNorm 操作
    return input


def torch_nn_linear(self, input):
    # 创建一个与输入形状的前几维相同,但最后一维是输出特征数的空张量,设备为"meta"
    return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")


def torch_relu(x):
    # 返回输入本身,表示不应用 ReLU 激活函数
    return x


def torch_nn_relu(self, x):
    # 返回输入本身,表示不应用 ReLU 激活函数
    return x


def torch_nn_functional_relu(x, inplace=False):
    # 如果不支持原地操作,则抛出异常
    if not inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x


def torch_where(condition, x, y):
    # 使用加法模拟 torch.where 的行为,返回条件、x 和 y 的广播张量
    return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta")


def torch_abs(input, *, out=None):
    # 如果要求原地操作,则抛出异常
    if out is not None:
        raise ValueError("Don't support in-place abs for MetaTensor analysis")
    return input


def torch_arange(*args, **kwargs):
    # 计算输入参数的数量
    n = len(args)
    step = 1
    # 检查参数 n 的数量,根据不同情况设置起始、结束和步长
    if n == 1:
        # 当只有一个参数时,设定起始为0,结束为第一个参数值
        start = 0
        end = args[0]
    elif n == 2:
        # 当有两个参数时,分别设置起始和结束为两个参数的值
        start, end = args
    else:
        # 当有三个参数时,设置起始、结束和步长分别为三个参数的值
        start, end, step = args
    
    # 检查起始、结束和步长是否为浮点数,若是则转换为整数
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        # 修正:应该是 end 而不是 start
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    
    # 获取关键字参数中的步长,若未提供则使用之前设置的步长值
    step = kwargs.get("step", step)
    # 获取关键字参数中的数据类型,用于创建 tensor
    dtype = kwargs.get("dtype")
    # 返回一个空的 torch tensor,形状为((end - start) // step),指定数据类型和设备为 "meta"
    return torch.empty((end - start) // step, dtype=dtype, device="meta")
# 创建一个函数 torch_full,用于生成一个指定维度和值的张量
def torch_full(*args, **kwargs):
    # 将位置参数转换为列表
    args = list(args)
    # 如果第二个参数是 torch.Tensor 类型且设备是 "meta",则将其设为 1(任意值)
    if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
        args[1] = 1  # 任意值。
    # 复制关键字参数到新的字典,去除 "device" 参数
    kwargs_without_device = dict(kwargs)
    kwargs_without_device.pop("device", None)
    # 调用 torch.full 函数,生成张量并返回
    return torch.full(*args, **kwargs_without_device)


# 创建一个函数 torch_cat,用于沿指定维度对张量进行拼接
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    # 如果 dim 和 axis 都未指定,则将 dim 设置为 0
    if dim is None and axis is None:
        dim = 0
    # 如果 dim 未指定但 axis 指定了,则将 dim 设置为 axis
    if dim is None and axis is not None:
        dim = axis
    # 如果 dim 是负数,将其转换为正数索引
    if dim < 0:
        dim = tensors[0].dim() + dim
    # 获取所有张量的形状
    shapes = [t.shape for t in tensors]
    # 获取第一个张量的形状
    shape = list(shapes[0])
    # 计算拼接后张量在指定维度上的总长度
    concatenated_dim = sum(shape[dim] for shape in shapes)
    # 构建最终的张量形状
    final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
    # 返回一个新的空张量,形状为 final_shape,设备为 "meta"
    return torch.empty(final_shape, device="meta")


# 创建一个函数 torch_stack,用于沿新的维度对张量序列进行堆叠
def torch_stack(tensors, dim=None, axis=None, *, out=None):
    # 如果 dim 和 axis 都未指定,则将 dim 设置为 0
    if dim is None and axis is None:
        dim = 0
    # 如果 dim 未指定但 axis 指定了,则将 dim 设置为 axis
    if dim is None and axis is not None:
        dim = axis
    # 如果 dim 是负数,将其转换为正数索引
    if dim < 0:
        dim = tensors[0].dim() + 1 + dim
    # 获取第一个张量的形状
    shape = list(tensors[0].shape)
    # 在指定维度上插入新的维度,长度为张量序列的长度
    shape.insert(dim, len(tensors))
    # 返回一个新的空张量,形状为 shape,设备为 "meta"
    return torch.empty(shape, device="meta")


# 创建一个函数 torch_add,用于对两个张量进行加法操作
def torch_add(input, other, *, alpha=1, out=None):
    # 如果 input 不是 torch.Tensor 类型,则返回一个与 other 相同形状的空张量,设备为 "meta"
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    # 如果 other 不是 torch.Tensor 类型,则返回一个与 input 相同形状的空张量,设备为 "meta"
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    # 计算两个张量的最大维度
    max_length = max(input.dim(), other.dim())
    # 将 input 和 other 扩展为相同的维度
    input_shape = list(input.shape) + [1] * (max_length - input.dim())
    other_shape = list(other.shape) + [1] * (max_length - other.dim())
    shape = []
    for i in range(max_length):
        shape.append(max(input_shape[i], other_shape[i]))
    # 返回一个新的空张量,形状为 shape,设备为 "meta"
    return torch.empty(shape, device="meta")


# 创建一个函数 torch_mul,用于对两个张量进行乘法操作,实际上调用了 torch_add 函数
def torch_mul(input, other, *, out=None):
    return torch_add(input, other, out=out)


# 创建一个函数 torch_tensor_mul,用于对两个张量进行乘法操作,实际上调用了 torch_mul 函数
def torch_tensor_mul(self, other):
    return torch_mul(self, other)


# 创建一个函数 torch_matmul,用于对两个张量进行矩阵乘法操作
def torch_matmul(input, other, *, out=None):
    # 获取 input 和 other 的维度
    d1 = input.dim()
    d2 = other.dim()
    shape = None
    # 根据不同的维度情况进行判断和处理
    if d1 == 1 and d2 == 1:
        shape = None
    elif d1 == 2 and d2 == 2:
        shape = (input.size(0), other.size(1))
    elif d1 == 1 and d2 == 2:
        shape = (other.size(1),)
    elif d1 == 2 and d1 == 1:  # 应为 d2 == 1
        shape = (input.size(0),)
    else:
        max_length = max(input.dim(), other.dim())
        shape1 = list(input.shape)
        shape2 = list(other.shape)
        if d1 == 1:
            shape1 = [1] + shape1
        if d2 == 1:
            shape2.append(1)
        shape1 = [-1] * (max_length - d1) + list(input.shape)
        shape2 = [-1] * (max_length - d2) + list(other.shape)
        shape = []
        for i in range(max_length):
            shape.append(max(shape1[i], shape2[i]))
        shape[-2] = shape1[-2]
        shape[-1] = shape2[-1]
        if d1 == 1:
            shape.pop(-2)
        if d2 == 1:
            shape.pop(-1)
    # 如果 shape 为 None,则返回一个标量张量 0.0,设备为 "meta"
    if shape is None:
        return torch.tensor(0.0, device="meta")
    # 返回一个新的空张量,形状为 shape,设备为 "meta"
    return torch.empty(*shape, device="meta")
def torch_bmm(input, mat2, *, out=None):
    # 如果指定了输出张量out,抛出值错误异常,不支持原地操作
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    # 获取输入张量input的批大小、行数n、列数m
    batch_size, n, m = input.shape
    # 获取mat2张量的最后两个维度的大小,即第二个维度的行数和第三个维度的列数
    _, _, p = mat2.shape
    # 返回一个空的元数据设备张量,形状为(batch_size, n, p)
    return torch.empty(batch_size, n, p, device="meta")


def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    # 如果指定了输出张量out,抛出值错误异常,不支持原地操作
    if out is not None:
        raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
    # 调用torch_bmm函数计算batch1和batch2的批次矩阵乘积,返回结果
    return torch_bmm(batch1, batch2)


def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    # 调用torch_baddbmm函数计算self张量和batch1、batch2的批次矩阵乘积,返回结果
    return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)


def torch_einsum(equation, *operands):
    # TODO: infer shape without performing the computation, this might be quite hard.
    # 创建与操作数具有相同形状的空张量列表,设备为CPU
    concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
    # 执行爱因斯坦求和符号运算,返回结果张量并设备为元数据设备
    return torch.einsum(equation, *concrete_operands).to("meta")


def torch_tensor_repeat(self, *sizes):
    # 获取self张量的形状列表
    shape = list(self.shape)
    # 根据sizes参数修改形状列表中的每个维度大小
    for i, x in enumerate(sizes):
        shape[i] *= x
    # 返回一个空的元数据设备张量,形状由修改后的形状列表确定
    return torch.empty(shape, device="meta")


def torch_repeat_interleave(*args, dim=None, output_size=None):
    # 获取参数数量
    num_args = len(args)
    # 如果参数数量为1
    if num_args == 1:
        # 如果output_size不为None,则创建一个形状为[output_size]的列表
        shape = [output_size if output_size is not None else args[0].sum()]
    else:
        # 否则创建一个形状与第一个参数args[0]相同的列表
        shape = list(args[0].shape)
        # 如果未指定维度dim
        if dim is None:
            # 如果参数数量大于2,则将dim设置为args[2]
            if num_args > 2:
                dim = args[2]
            else:
                # 否则将shape变为包含总和的列表,并将dim设置为0
                shape = [sum(shape)]
                dim = 0
        # 获取重复次数repeats
        repeats = args[1]
        # 如果repeats是整数或者元素数量为1
        if isinstance(repeats, int) or torch.numel(repeats) == 1:
            # 将shape[dim]乘以repeats的整数值
            shape[dim] *= int(repeats)
        else:
            # 否则将shape[dim]设置为output_size或者repeats的总和
            shape[dim] = output_size if output_size is not None else repeats.sum()
    # 返回一个空的元数据设备张量,形状由shape确定
    return torch.empty(*shape, device="meta")


def torch_index_select(input, dim, index, *, out=None):
    # 获取input张量的形状列表
    shape = list(input.shape)
    # 修改形状列表中指定维度dim的大小为索引index的长度
    shape[dim] = len(index)
    # 返回一个空的元数据设备张量,形状由修改后的形状列表确定
    return torch.empty(*shape, device="meta")


def torch_tensor_index_select(self, dim, index):
    # 调用torch_index_select函数从self张量中选择指定维度dim的索引index,返回结果
    return torch_index_select(self, dim, index)


def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
    # 获取input张量的形状列表
    shape = list(input.shape)
    # 修改形状列表中指定维度dim的大小为索引index指定维度的大小
    shape[dim] = index.shape[dim]
    # 返回一个空的元数据设备张量,形状由修改后的形状列表确定
    return torch.empty(*shape, device="meta")


def torch_tensor_gather(self, dim, index):
    # 调用torch_gather函数从self张量中收集指定维度dim的索引index,返回结果
    return torch_gather(self, dim, index)


def torch_roll(input, shifts, dims=None):
    # 返回未修改的输入张量input
    return input


def torch_flip(input, dims):
    # 返回未修改的输入张量input
    return input


def torch_tensor_flip(self, dims):
    # 返回未修改的self张量
    return self


def torch_nn_conv1d(self, input):
    # 获取输入input的最后一个维度的大小
    l_in = input.shape[-1]
    # 初始化形状为None
    shape = None
    # 获取卷积的填充方式
    padding = self.padding
    # 如果填充为"valid",则将padding设为(0, 0)
    if padding == "valid":
        padding = (0, 0)
    # 如果填充为"same"
    if padding == "same":
        # 将shape设置为输入input的形状列表
        shape = list(input.shape)
    # 如果shape仍为None
    if shape is None:
        # 将shape设置为输入input的形状列表
        shape = list(input.shape)
        # 计算输出的长度l_out,并向下取整
        l_out = math.floor(
            (l_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
        )
        # 修改shape中倒数第二个维度的大小为输出通道数self.out_channels
        shape[-1] = l_out
    # 修改shape中倒数第三个维度的大小为输出通道数self.out_channels
    shape[-2] = self.out_channels
    # 返回一个空的元数据设备张量,形状由修改后的shape确定
    return torch.empty(shape, device="meta")
# 定义一个类方法用于进行二维卷积操作
def torch_nn_conv2d(self, input):
    # 获取输入张量的高度和宽度
    h_in, w_in = input.shape[-2:]
    # 初始化形状变量为 None
    shape = None
    # 获取填充参数
    padding = self.padding
    # 如果填充方式是 "valid",则将 padding 设置为 (0, 0)
    if padding == "valid":
        padding = (0, 0)
    # 如果填充方式是 "same",则复制输入张量的形状
    if padding == "same":
        shape = list(input.shape)
    # 如果形状仍为 None,则根据卷积参数计算输出形状
    if shape is None:
        shape = list(input.shape)
        h_out = math.floor(
            (h_in + 2 * padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
        )
        w_out = math.floor(
            (w_in + 2 * padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
        )
        shape[-2:] = [h_out, w_out]
    # 设置输出张量的通道数维度为 self.out_channels
    shape[-3] = self.out_channels
    # 返回一个空的元数据张量,其形状由计算得出
    return torch.empty(shape, device="meta")


# 定义一个函数用于从输入张量中挤压维度为1的维度
def torch_squeeze(input, dim=None):
    # 获取输入张量的形状
    shape = list(input.shape)
    # 如果指定了维度 dim
    if dim is not None:
        # 将负数索引转换为正数索引
        if dim < 0:
            dim = input.dim() + dim
        # 如果指定维度的大小为1,则从形状中删除该维度
        if shape[dim] == 1:
            shape.pop(dim)
    else:
        # 如果未指定维度 dim,则遍历形状,删除所有大小为1的维度
        new_shape = []
        for dim_value in shape:
            if dim_value == 1:
                continue
            new_shape.append(dim_value)
        shape = new_shape
    # 返回一个空的元数据张量,其形状为经过挤压操作后的形状
    return torch.empty(shape, device="meta")


# 定义一个类方法,用于对类的实例进行挤压操作
def torch_tensor_squeeze(self, dim=None):
    # 调用全局的挤压函数 torch_squeeze 对类实例进行操作
    return torch_squeeze(self, dim)


# 定义一个函数用于在指定维度 dim 上对输入张量进行展开操作
def torch_unsqueeze(input, dim):
    # 获取输入张量的形状
    shape = list(input.shape)
    # 将负数索引转换为正数索引
    if dim < 0:
        dim = input.dim() + 1 + dim
    # 在指定维度 dim 处插入一个大小为1的维度
    shape.insert(dim, 1)
    # 返回一个空的元数据张量,其形状为经过展开操作后的形状
    return torch.empty(shape, device="meta")


# 定义一个类方法,用于对类的实例进行展开操作
def torch_tensor_unsqueeze(self, dim):
    # 调用全局的展开函数 torch_unsqueeze 对类实例进行操作
    return torch_unsqueeze(self, dim)


# 定义一个函数用于计算输入张量的连续唯一值,并保持顺序
def torch_unique_consecutive(input, **kwargs):
    # 调用 PyTorch 的 torch.unique_consecutive 函数对输入张量进行操作
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    # 如果输出是张量,则将其转移到设备 "meta"
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    else:
        # 否则,将其元组化,并使用映射函数将其中的张量转移到设备 "meta"
        return tuple(map(output, lambda x: x.to("meta")))


# 定义一个函数用于为输入张量的每个元素创建一个指定长度的独热编码
def torch_nn_functional_one_hot(tensor, num_classes=-1):
    # 如果未指定 num_classes,则抛出错误,不支持自动推断 num_classes
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    # 计算输出张量的形状,将最后一个维度扩展为 num_classes
    shape = list(tensor.shape) + [num_classes]
    # 返回一个空的元数据张量,其形状为计算得出的形状
    return torch.empty(shape, device="meta")


# 定义一个函数用于实现缩放点积注意力机制
def torch_nn_functional_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    # 获取查询张量的目标长度和值张量的头部维度
    target_length = query.shape[-2]
    head_dim = value.shape[-1]
    # 返回一个空的元数据张量,其形状为 (query.shape[:-2], target_length, head_dim)
    return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")


# 定义一个类方法用于计算均方误差损失
def torch_nn_mseloss(self, input, target):
    # 如果损失函数的减少方式为 "none",则输出的形状与目标形状相同
    if self.reduction == "none":
        shape = target.shape
    else:
        # 否则,输出的形状为 (1,)
        shape = (1,)
    # 返回一个空的元数据张量,其形状为计算得出的形状
    return torch.empty(shape, device="meta")


# 定义一个类方法用于计算交叉熵损失
def torch_nn_crossentropyloss(self, input, target):
    # 如果损失函数的减少方式为 "none",则输出的形状与目标形状相同
    if self.reduction == "none":
        shape = target.shape
    else:
        # 否则,输出的形状为 (1,)
        shape = (1,)
    # 返回一个空的元数据张量,其形状为计算得出的形状
    return torch.empty(shape, device="meta")


# 定义一个类方法用于计算带有 logits 的二元交叉熵损失
def torch_nn_bcewithlogitsloss(self, input, target):
    # 如果损失函数的减少方式为 "none",则输出的形状与目标形状相同
    if self.reduction == "none":
        shape = target.shape
    else:
        # 否则,输出的形状为 (1,)
        shape = (1,)
    # 返回一个空的元数据张量,其形状为计算得出的形状
    return torch.empty(shape, device="meta")


# 定义一个函数操作符,用于获取张量 a 中的 b 元素
def operator_getitem(a, b):
    # 省略函数实现,没有实际执行体
    pass
    # 定义函数 to_concrete,用于将输入的张量 t 转换为具体的张量
    def to_concrete(t):
        # 如果 t 是 torch.Tensor 类型
        if isinstance(t, torch.Tensor):
            # 创建一个与 t 相同形状的全为1的张量,存储在 CPU 上
            concrete = torch.ones_like(t, device="cpu")
            # 如果 concrete 的数据类型是浮点数或者整数32位,则将其转换为64位整数
            if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
                concrete = concrete.to(torch.int64)
            return concrete
        # 如果 t 不是 torch.Tensor 类型,直接返回 t
        return t

    # 如果 a 是 torch.Tensor 类型
    if isinstance(a, torch.Tensor):
        # TODO: 推断形状而不执行计算。
        # 如果 b 是元组类型,对元组中的每个元素应用 to_concrete 函数
        if isinstance(b, tuple):
            b = tuple(map(to_concrete, b))
        else:
            # 否则,将 b 应用 to_concrete 函数
            b = to_concrete(b)
        # 返回使用 a 形状创建的空张量的子集,转换为 "meta" 类型
        return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
    # 如果 a 不是 torch.Tensor 类型,返回 a 的子集
    return operator.getitem(a, b)
# 定义一个字典,用于存储手动指定的函数覆盖映射,将特定的 Torch 函数映射到自定义函数
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.nn.Embedding: torch_nn_embedding,  # 将 torch.nn.Embedding 映射到 torch_nn_embedding 函数
    torch.nn.functional.embedding: torch_nn_functional_embedding,  # 将 torch.nn.functional.embedding 映射到 torch_nn_functional_embedding 函数
    torch.nn.LayerNorm: torch_nn_layernorm,  # 将 torch.nn.LayerNorm 映射到 torch_nn_layernorm 函数
    torch.nn.GroupNorm: torch_nn_groupnorm,  # 将 torch.nn.GroupNorm 映射到 torch_nn_groupnorm 函数
    torch.nn.Linear: torch_nn_linear,  # 将 torch.nn.Linear 映射到 torch_nn_linear 函数
    torch.relu: torch_relu,  # 将 torch.relu 映射到 torch_relu 函数
    torch.nn.functional.relu: torch_nn_functional_relu,  # 将 torch.nn.functional.relu 映射到 torch_nn_functional_relu 函数
    torch.nn.ReLU: torch_nn_relu,  # 将 torch.nn.ReLU 映射到 torch_nn_relu 函数
    torch.where: torch_where,  # 将 torch.where 映射到 torch_where 函数
    torch.abs: torch_abs,  # 将 torch.abs 映射到 torch_abs 函数
    torch.arange: torch_arange,  # 将 torch.arange 映射到 torch_arange 函数
    torch.full: torch_full,  # 将 torch.full 映射到 torch_full 函数
    torch.cat: torch_cat,  # 将 torch.cat 映射到 torch_cat 函数
    torch.stack: torch_stack,  # 将 torch.stack 映射到 torch_stack 函数
    torch.add: torch_add,  # 将 torch.add 映射到 torch_add 函数
    torch.mul: torch_mul,  # 将 torch.mul 映射到 torch_mul 函数
    torch.Tensor.mul: torch_tensor_mul,  # 将 torch.Tensor.mul 映射到 torch_tensor_mul 函数
    torch.matmul: torch_matmul,  # 将 torch.matmul 映射到 torch_matmul 函数
    torch.bmm: torch_bmm,  # 将 torch.bmm 映射到 torch_bmm 函数
    torch.baddbmm: torch_baddbmm,  # 将 torch.baddbmm 映射到 torch_baddbmm 函数
    torch.Tensor.baddbmm: torch_tensor_baddbmm,  # 将 torch.Tensor.baddbmm 映射到 torch_tensor_baddbmm 函数
    torch.einsum: torch_einsum,  # 将 torch.einsum 映射到 torch_einsum 函数
    torch.Tensor.repeat: torch_tensor_repeat,  # 将 torch.Tensor.repeat 映射到 torch_tensor_repeat 函数
    torch.repeat_interleave: torch_repeat_interleave,  # 将 torch.repeat_interleave 映射到 torch_repeat_interleave 函数
    torch.roll: torch_roll,  # 将 torch.roll 映射到 torch_roll 函数
    torch.flip: torch_flip,  # 将 torch.flip 映射到 torch_flip 函数
    torch.Tensor.flip: torch_tensor_flip,  # 将 torch.Tensor.flip 映射到 torch_tensor_flip 函数
    torch.index_select: torch_index_select,  # 将 torch.index_select 映射到 torch_index_select 函数
    torch.Tensor.index_select: torch_tensor_index_select,  # 将 torch.Tensor.index_select 映射到 torch_tensor_index_select 函数
    torch.gather: torch_gather,  # 将 torch.gather 映射到 torch_gather 函数
    torch.Tensor.gather: torch_tensor_gather,  # 将 torch.Tensor.gather 映射到 torch_tensor_gather 函数
    torch.nn.Conv1d: torch_nn_conv1d,  # 将 torch.nn.Conv1d 映射到 torch_nn_conv1d 函数
    torch.nn.Conv2d: torch_nn_conv2d,  # 将 torch.nn.Conv2d 映射到 torch_nn_conv2d 函数
    torch.squeeze: torch_squeeze,  # 将 torch.squeeze 映射到 torch_squeeze 函数
    torch.Tensor.squeeze: torch_tensor_squeeze,  # 将 torch.Tensor.squeeze 映射到 torch_tensor_squeeze 函数
    torch.unsqueeze: torch_unsqueeze,  # 将 torch.unsqueeze 映射到 torch_unsqueeze 函数
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,  # 将 torch.Tensor.unsqueeze 映射到 torch_tensor_unsqueeze 函数
    torch.unique_consecutive: torch_unique_consecutive,  # 将 torch.unique_consecutive 映射到 torch_unique_consecutive 函数
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,  # 将 torch.nn.functional.one_hot 映射到 torch_nn_functional_one_hot 函数
    torch.nn.MSELoss: torch_nn_mseloss,  # 将 torch.nn.MSELoss 映射到 torch_nn_mseloss 函数
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,  # 将 torch.nn.CrossEntropyLoss 映射到 torch_nn_crossentropyloss 函数
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,  # 将 torch.nn.BCEWithLogitsLoss 映射到 torch_nn_bcewithlogitsloss 函数
    operator.getitem: operator_getitem,  # 将 operator.getitem 映射到 operator_getitem 函数
}

# 如果 Torch 的版本大于等于 2.0,将 torch.nn.functional.scaled_dot_product_attention 映射到 torch_nn_functional_scaled_dot_product_attention 函数
if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[
        torch.nn.functional.scaled_dot_product_attention
    ] = torch_nn_functional_scaled_dot_product_attention


class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.
    """

    def install_metadata(self, metadata):
        self._metadata = metadata

    @property
    def shape(self):
        # 使用追踪器创建一个代理对象,调用方法为 "size",参数为自身 (self),返回创建的代理对象
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def device(self):
        # 用于跟踪设备使用情况的 Hack。在元张量传播期间,将这些值替换为常量 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        # 如果存在 _metadata 属性且不为 None,则返回 _metadata 的长度,否则调用父类的 __len__ 方法返回长度
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        # 如果存在 _metadata 属性且不为 None,则返回 _metadata,否则调用父类的 __bool__ 方法返回布尔值
        if hasattr(self, "_metadata") and self._metadata is not None:
            return self._metadata
        return super().__bool__()
    # 当访问不存在的属性时被调用,这里检查属性名称是否为"_metadata"
    def __getattr__(self, k):
        # 如果属性名称为"_metadata",直接返回其属性值
        if k == "_metadata":
            return self.__getattribute__(k)
        # 如果不是"_metadata",创建并返回一个HFAttribute对象,用于处理属性访问
        # 注意:如果这是一个方法调用,它还未添加到图形中,我们会优化为方法调用
        return HFAttribute(self, k)

    # 当使用self[key] = value时调用,创建一个"call_function"代理对象来追踪此操作
    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    # 当使用key in self语句检查成员资格时调用
    def __contains__(self, key):
        # 检查是否存在"_metadata"属性且不为None,如果是,则检查key是否在"_metadata"中
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        # 否则,委托给超类的__contains__方法来处理key的成员资格检查
        return super().__contains__(key)
class HFAttribute(HFProxy):
    # HFAttribute 类继承自 HFProxy 类
    def __init__(self, root, attr: str):
        # 初始化方法,接受 root 和 attr 参数
        self.root = root
        self.attr = attr
        self.tracer = root.tracer  # 将 root 的 tracer 赋给实例变量 tracer
        self._node = None  # 初始化 _node 为 None

        # 如果 root 对象有 _metadata 属性,则安装对应 attr 的元数据
        if hasattr(self.root, "_metadata"):
            self.install_metadata(getattr(self.root._metadata, attr))

    @property
    def node(self):
        # node 属性,延迟加载节点,大多数情况下只有方法调用,不依赖于 getitem 调用
        if self._node is None:
            self._node = self.tracer.create_proxy("call_function", builtins.getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        # 实例可调用,创建代理对象,调用方法为 call_method
        return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)


class MetaDeviceAttribute(HFAttribute):
    # MetaDeviceAttribute 类继承自 HFAttribute 类
    pass


def _proxies_to_metas(v):
    """Returns the underlying metadata for HFProxies, and behaves like the identity for the others."""
    # 将 HFProxies 的基础元数据返回,并对其他对象行为像是返回自身
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if isinstance(v, torch.fx.Proxy):
        # 对于 torch.fx.Proxy 类型的对象,确保其有元数据,否则引发 RuntimeError
        if not (isinstance(v, HFProxy) and hasattr(v, "_metadata")):
            raise RuntimeError(f"No metadata was found for {v}")
        return v._metadata
    return v


def _gen_constructor_wrapper(target):
    # 生成构造函数的包装器,用于包装目标函数 target
    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            # 检查参数中是否有 Proxy 对象
            if isinstance(v, Proxy):
                nonlocal proxy
                proxy = v

        # 对 args 和 kwargs 进行映射,检查是否有 Proxy 对象存在
        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        # 如果存在 Proxy 对象,则通过 tracer 创建代理对象,调用方式为 call_function
        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target


def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
    # 生成指定范围内的随机整数,可以排除 forbidden_values 中的值
    if forbidden_values is None:
        forbidden_values = []
    value = random.randint(low, high)
    while value in forbidden_values:
        value = random.randint(low, high)
    return value


class HFTracer(Tracer):
    """
    Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
    regular PyTorch torch.fx.Proxy.
    """
    # HFTracer 类,用于符号化跟踪库中的模型,使用 HFProxy 而非常规的 PyTorch torch.fx.Proxy

    # 用于代理访问缓冲区值的功能标志
    proxy_buffer_attributes: bool = True
    allow_insert_stateless_mods: bool = True
    _TORCH_METHODS_TO_PATCH = [
        "arange",
        "zeros",
        "ones",
        "full",
        "full_like",
        "eye",
        "empty",
        "tensor",
        "clamp",
        "finfo",
    ]
    # 支持的架构类型,包括 PreTrainedModel 和 PeftModel(如果可用)
    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
    # 初始化方法,用于设置自动包装的模块和函数
    def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
        # 调用父类的初始化方法
        super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

        # 检查是否有可用的 Torch FX 版本
        if not is_torch_fx_available():
            # 如果没有可用的 Torch FX 版本,则抛出 ImportError 异常
            raise ImportError(
                f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
                f"{TORCH_FX_REQUIRED_VERSION} is supported."
            )

    # 生成虚拟输入的方法
    def _generate_dummy_input(
        self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
    ):
        # 已被 PyTorch 1.13 替换为 .getattr 方法

    # 用于获取模块属性的方法
    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        # 如果禁用了模块获取属性,则直接返回属性值
        if getattr(self, "_disable_module_getattr", False):
            return attr_val
        else:
            # 内部函数,用于获取属性的代理
            def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
                for n, p in collection_to_search:
                    if attr_val is p:
                        # 如果尚未缓存属性的代理,则创建新的代理
                        if n not in parameter_proxy_cache:
                            kwargs = {}
                            # 如果支持参数形状常量,则创建参数代理
                            if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                                kwargs["proxy_factory_fn"] = (
                                    None
                                    if not self.param_shapes_constant
                                    else lambda node: ParameterProxy(self, node, n, attr_val)
                                )
                            # 使用代理工厂函数创建代理
                            val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                            parameter_proxy_cache[n] = val_proxy
                        return parameter_proxy_cache[n]
                return None

            # 如果属性值是 torch.nn.Parameter 类型,则尝试获取参数代理
            if isinstance(attr_val, torch.nn.Parameter):
                maybe_parameter_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_parameters(), parameter_proxy_cache
                )
                if maybe_parameter_proxy is not None:
                    return maybe_parameter_proxy

            # 如果启用了代理缓冲属性,并且属性值是 torch.Tensor 类型,则尝试获取缓冲区代理
            if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
                maybe_buffer_proxy = maybe_get_proxy_for_attr(
                    attr_val, self.root.named_buffers(), parameter_proxy_cache
                )
                if maybe_buffer_proxy is not None:
                    return maybe_buffer_proxy

            # 如果未找到代理,直接返回属性值
            return attr_val

    # PyTorch 1.13+ 版本所需的方法,用于获取属性
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    # 调用模块的方法,设置原始前向方法并调用父类方法
    def call_module(self, m, forward, args, kwargs):
        self.orig_forward = forward
        return super().call_module(m, forward, args, kwargs)

    # 返回 HFProxy 实例的方法
    def proxy(self, node):
        return HFProxy(node, self)
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
        dummy_inputs: Optional[Dict[str, Any]] = None,
        complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
    ):
        """
        Trace method for tracing through the module hierarchy starting from `root`.

        Args:
            root (Union[torch.nn.Module, Callable[..., Any]]): The root module or callable to start tracing from.
            concrete_args (Optional[Dict[str, Any]]): Concrete arguments for the traced function.
            dummy_inputs (Optional[Dict[str, Any]]): Dummy inputs for the traced function.
            complete_concrete_args_with_inputs_not_in_dummy_inputs (bool):
                Flag indicating whether to complete concrete arguments with inputs not in dummy inputs.

        Returns:
            None
        """

    def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
        """
        Check if the module's instantiation depends on Proxies.

        Args:
            mod (nn.Module): The module to check.

        Returns:
            bool: True if the module was instantiated with Proxies, otherwise False.
        """
        return any(isinstance(attr, Proxy) for attr in mod.__dict__.values())

    def _insert_module_as_submodule(self, mod: nn.Module) -> str:
        """
        Try to insert a module that was not declared as a submodule.

        Args:
            mod (nn.Module): The module to insert.

        Returns:
            str: Path where the module was inserted as a submodule, or an empty string if insertion failed.
        """

        # If one of the module attributes is a Proxy, its instantiation is input-dependent.
        if self._stateless_mod_instanciation_depends_on_proxies(mod):
            return ""

        idx = 0
        mod_name = mod.__class__.__name__.lower()
        path = f"{mod_name}_{idx}"
        already_inserted = False

        # Check if the module is already inserted at the computed path
        while hasattr(self.root, path):
            if getattr(self.root, path) is mod:
                already_inserted = True
                break
            path = f"{mod_name}_{idx}"
            idx += 1

        # Insert the module if it's not already present
        if not already_inserted:
            self.root.add_module(path, mod)
        return path

    def path_of_module(self, mod: nn.Module) -> str:
        """
        Find the qualified name of `mod` in the Module hierarchy of `root`.

        Args:
            mod (nn.Module): The module to retrieve the qualified name for.

        Returns:
            str: Qualified path of the module in the Module hierarchy of `root`.
        """

        try:
            return super().path_of_module(mod)
        except NameError as e:
            # Handle case where `mod` is not directly found in `root`'s modules
            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
                path = self._insert_module_as_submodule(mod)
                return path
            raise e

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        Check if a module is a leaf module in the module hierarchy.

        Args:
            m (torch.nn.Module): The module to check.
            module_qualified_name (str): Qualified name of the module in the hierarchy.

        Returns:
            bool: True if the module is a leaf module, otherwise False.
        """

        # Check if module instantiation depends on Proxies and delegate to superclass method
        return (not self._stateless_mod_instanciation_depends_on_proxies(m)) and super().is_leaf_module(
            m, module_qualified_name
        )

    @compatibility(is_backward_compatible=True)
    # Decorator indicating backward compatibility
    def keys(self, obj: "Proxy") -> Any:
        """Called when a proxy object has the keys() method called.
        当代理对象调用keys()方法时调用此函数。
        This is what happens when ** is called on a proxy.
        当代理对象上调用**运算符时会发生这种情况。
        This should return an iterator if ** is supposed to work in
        your custom tracer.
        如果希望在自定义的追踪器中**运算符正常工作,此方法应返回一个迭代器。
        """
        # Create an HFAttribute object for the 'keys' attribute of the proxy object
        attribute = HFAttribute(obj, "keys")()
        # Check if the target of the proxy object is '**kwargs'
        if obj.node.target == "**kwargs":
            # Return the metadata of the attribute if the target is '**kwargs'
            return attribute._metadata
        # Otherwise, return the attribute itself
        return attribute
# 获取模型的 forward 方法的参数签名
sig = inspect.signature(model.forward)

# 检查输入参数列表是否都在模型的参数签名中
if not (set(input_names) <= set(sig.parameters.keys())):
    # 如果有未在参数签名中的输入参数,生成格式化的错误信息并抛出 ValueError 异常
    formatted_input_names = input_names[0] if len(input_names) == 1 else ", ".join(input_names)
    formatted_allowed_input_names = ", ".join(sig.parameters.keys())
    raise ValueError(
        f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
        f" {formatted_allowed_input_names}"
    )

# 返回模型 forward 方法中除了输入参数之外的参数名和默认值构成的字典
return {p.name: p.default for p in sig.parameters.values() if p.name not in input_names}

.\utils\generic.py

# 版权声明
# 版权声明和许可证信息

# 导入模块
import inspect  # 导入inspect模块,用于解析源文件
import tempfile  # 临时文件模块
from collections import OrderedDict, UserDict  # 导入OrderedDict和UserDict类
from collections.abc import MutableMapping  # 导入MutableMapping抽象基类
from contextlib import ExitStack, contextmanager  # 导入ExitStack和contextmanager上下文管理器
from dataclasses import fields, is_dataclass  # 导入fields和is_dataclass函数
from enum import Enum  # 导入枚举类Enum
from functools import partial  # 导入partial函数
from typing import Any, ContextManager, Iterable, List, Tuple  # 导入类型提示相关的模块和类

import numpy as np  # 导入numpy模块并重命名为np
from packaging import version  # 导入version类

from .import_utils import (  # 导入import_utils中的指定函数和类
    get_torch_version,
    is_flax_available,
    is_mlx_available,
    is_tf_available,
    is_torch_available,
    is_torch_fx_proxy,
)

# 如果Flax可用,则导入jax.numpy as jnp
if is_flax_available():
    import jax.numpy as jnp

# 自定义缓存属性装饰器
class cached_property(property):
    """
    Descriptor that mimics @property but caches output in member variable.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """
    def __get__(self, obj, objtype=None):
        # 获取属性的值,并对其进行缓存
        if obj is None:
            return self
        if self.fget is None:
            raise AttributeError("unreadable attribute")
        attr = "__cached_" + self.fget.__name__
        cached = getattr(obj, attr, None)
        if cached is None:
            cached = self.fget(obj)
            setattr(obj, attr, cached)
        return cached


# 从distutils.util模块中引入strtobool函数,用于将表示真假的字符串转换为1或0
def strtobool(val):
    """Convert a string representation of truth to true (1) or false (0).

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
    Raises ValueError if 'val' is anything else.
    """
    val = val.lower()
    if val in {"y", "yes", "t", "true", "on", "1"}:
        return 1
    if val in {"n", "no", "f", "false", "off", "0"}:
        return 0
    raise ValueError(f"invalid truth value {val!r}")


# 从对象的repr中推断出其所属框架的函数
def infer_framework_from_repr(x):
    """
    Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
    frameworks in a smart order, without the need to import the frameworks).
    """
    representation = str(type(x))
    if representation.startswith("<class 'torch."):
        return "pt"
    elif representation.startswith("<class 'tensorflow."):
        return "tf"
    elif representation.startswith("<class 'jax"):
        return "jax"
    elif representation.startswith("<class 'numpy."):
        return "np"
    # 如果表示字符串以 "<class 'mlx." 开头,则返回字符串 "mlx"
    elif representation.startswith("<class 'mlx."):
        return "mlx"
# 返回一个按顺序排列的字典,包含了根据推断的优先框架来测试函数,优先顺序为我们从repr中能猜测到的框架首先,
# 然后是Numpy,最后是其他框架。
def _get_frameworks_and_test_func(x):
    framework_to_test = {
        "pt": is_torch_tensor,
        "tf": is_tf_tensor,
        "jax": is_jax_tensor,
        "np": is_numpy_array,
        "mlx": is_mlx_array,
    }
    preferred_framework = infer_framework_from_repr(x)
    # 首先测试推断的优先框架,然后是Numpy,最后是其他框架。
    frameworks = [] if preferred_framework is None else [preferred_framework]
    if preferred_framework != "np":
        frameworks.append("np")
    frameworks.extend([f for f in framework_to_test if f not in [preferred_framework, "np"]])
    return {f: framework_to_test[f] for f in frameworks}


# 测试是否 `x` 是 `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray`, `np.ndarray` 或 `mlx.array`,
# 按照 `infer_framework_from_repr` 定义的顺序进行测试。
def is_tensor(x):
    framework_to_test_func = _get_frameworks_and_test_func(x)
    for test_func in framework_to_test_func.values():
        if test_func(x):
            return True

    # 检查是否是跟踪器
    if is_torch_fx_proxy(x):
        return True

    if is_flax_available():
        from jax.core import Tracer

        if isinstance(x, Tracer):
            return True

    return False


# 测试是否 `x` 是一个 numpy 数组。
def is_numpy_array(x):
    return _is_numpy(x)


# 判断 `x` 是否是 torch 的 tensor。
def is_torch_tensor(x):
    return False if not is_torch_available() else _is_torch(x)


# 判断 `x` 是否是 torch 的 device。
def is_torch_device(x):
    return False if not is_torch_available() else _is_torch_device(x)


# 判断 `x` 是否是 torch 的 dtype。
def is_torch_dtype(x):
    return False if not is_torch_available() else _is_torch_dtype(x)


# 判断 `x` 是否是 tensorflow 的 tensor。
def is_tf_tensor(x):
    return False if not is_tf_available() else _is_tensorflow(x)
    # 检查 TensorFlow 模块是否具有 `is_symbolic_tensor` 属性,该属性从 TensorFlow 2.14 开始可用
    if hasattr(tf, "is_symbolic_tensor"):
        # 如果有 `is_symbolic_tensor` 方法,则调用该方法来检查 x 是否为符号张量
        return tf.is_symbolic_tensor(x)
    # 如果 TensorFlow 模块没有 `is_symbolic_tensor` 方法,则直接比较 x 的类型是否为 tf.Tensor 类型
    return type(x) == tf.Tensor
# 测试 `x` 是否为 TensorFlow 符号张量(即非即时执行模式)。即使没有安装 TensorFlow 也可以安全调用。
def is_tf_symbolic_tensor(x):
    return False if not is_tf_available() else _is_tf_symbolic_tensor(x)


# 检查 `x` 是否为 Jax 数组。
def _is_jax(x):
    import jax.numpy as jnp  # noqa: F811
    return isinstance(x, jnp.ndarray)


# 测试 `x` 是否为 Jax 张量。即使没有安装 Jax 也可以安全调用。
def is_jax_tensor(x):
    return False if not is_flax_available() else _is_jax(x)


# 检查 `x` 是否为 MLX 数组。
def _is_mlx(x):
    import mlx.core as mx
    return isinstance(x, mx.array)


# 测试 `x` 是否为 MLX 数组。即使没有安装 MLX 也可以安全调用。
def is_mlx_array(x):
    return False if not is_mlx_available() else _is_mlx(x)


# 将 TensorFlow 张量、PyTorch 张量、Numpy 数组或 Python 列表转换为 Python 列表。
def to_py_obj(obj):
    framework_to_py_obj = {
        "pt": lambda obj: obj.detach().cpu().tolist(),
        "tf": lambda obj: obj.numpy().tolist(),
        "jax": lambda obj: np.asarray(obj).tolist(),
        "np": lambda obj: obj.tolist(),
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]

    # 根据测试函数智能确定使用哪个框架的转换函数
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_py_obj[framework](obj)

    # tolist 也适用于 0 维的 np 数组
    if isinstance(obj, np.number):
        return obj.tolist()
    else:
        return obj


# 将 TensorFlow 张量、PyTorch 张量、Numpy 数组或 Python 列表转换为 Numpy 数组。
def to_numpy(obj):
    framework_to_numpy = {
        "pt": lambda obj: obj.detach().cpu().numpy(),
        "tf": lambda obj: obj.numpy(),
        "jax": lambda obj: np.asarray(obj),
        "np": lambda obj: obj,
    }

    if isinstance(obj, (dict, UserDict)):
        return {k: to_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return np.array(obj)

    # 根据测试函数智能确定使用哪个框架的转换函数
    framework_to_test_func = _get_frameworks_and_test_func(obj)
    for framework, test_func in framework_to_test_func.items():
        if test_func(obj):
            return framework_to_numpy[framework](obj)

    return obj


# 表示模型输出的基类,继承自 OrderedDict,作为数据类。具有一个 `__getitem__` 方法,允许按整数或切片(如元组)或字符串(如字典)进行索引,忽略 `None` 属性。否则表现类似于普通的 Python 字典。
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.
    
    <Tip warning={true}>
    """
    """
    # 注册子类作为 pytree 节点
    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        # 如果 PyTorch 可用且版本大于等于 2.2,则注册 pytree 节点
        if is_torch_available():
            if version.parse(get_torch_version()) >= version.parse("2.2"):
                _torch_pytree.register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                    serialized_type_name=f"{cls.__module__}.{cls.__name__}",
                )
            else:
                # 对于低版本的 PyTorch,使用旧的注册方式
                _torch_pytree._register_pytree_node(
                    cls,
                    _model_output_flatten,
                    partial(_model_output_unflatten, output_type=cls),
                )

    # 初始化函数,检查是否为 ModelOutput 的子类,并且必须使用 @dataclass 装饰器
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # 子类必须使用 @dataclass 装饰器,这个检查在 __init__ 中进行,因为 @dataclass 装饰器
        # 在 __init_subclass__ 之后才生效
        # 如果当前类不是 ModelOutput 本身,即当前类是其子类
        is_modeloutput_subclass = self.__class__ != ModelOutput

        # 如果当前类是 ModelOutput 的子类,并且没有使用 @dataclass 装饰器,则抛出 TypeError
        if is_modeloutput_subclass and not is_dataclass(self):
            raise TypeError(
                f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
                " This is a subclass of ModelOutput and so must use the @dataclass decorator."
            )
    def __post_init__(self):
        """初始化后检查ModelOutput数据类。

        仅在使用@dataclass装饰器时发生。
        """
        # 获取数据类的所有字段
        class_fields = fields(self)

        # 安全性和一致性检查
        if not len(class_fields):
            # 如果没有字段,则引发值错误异常
            raise ValueError(f"{self.__class__.__name__} has no fields.")
        if not all(field.default is None for field in class_fields[1:]):
            # 如果有超过一个必需字段,则引发值错误异常
            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

        # 获取第一个字段的值
        first_field = getattr(self, class_fields[0].name)
        # 检查其它字段是否都为None
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            if isinstance(first_field, dict):
                # 如果第一个字段是字典,则遍历字典项
                iterator = first_field.items()
                first_field_iterator = True
            else:
                try:
                    # 尝试迭代第一个字段
                    iterator = iter(first_field)
                    first_field_iterator = True
                except TypeError:
                    first_field_iterator = False

            # 如果第一个字段是迭代器且是(key, value)形式的迭代器
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # 如果不是(key, value)形式的迭代器,将其设置为属性
                            self[class_fields[0].name] = first_field
                        else:
                            # 如果是混合迭代器,引发值错误异常
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    # 设置属性为(key, value)对
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                # 如果第一个字段不为空,则将其设置为属性
                self[class_fields[0].name] = first_field
        else:
            # 如果存在非None的字段,则将其设置为属性
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        """阻止对ModelOutput实例使用``__delitem__``方法。"""
        # 抛出异常,不允许使用``__delitem__``方法
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        """阻止对ModelOutput实例使用``setdefault``方法。"""
        # 抛出异常,不允许使用``setdefault``方法
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        """阻止对ModelOutput实例使用``pop``方法。"""
        # 抛出异常,不允许使用``pop``方法
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
    def update(self, *args, **kwargs):
        # 抛出异常,阻止在该类实例上使用 `update` 方法
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        if isinstance(k, str):
            # 将内部数据转换为字典,然后返回键 `k` 对应的值
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            # 调用 `to_tuple()` 方法返回的元组,并使用 `k` 作为索引获取元组中的值
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # 避免递归错误,不调用 `self.__setitem__` 方法
            super().__setitem__(name, value)
        # 设置对象的属性 `name` 为 `value`
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # 调用父类的 `__setitem__` 方法设置键 `key` 对应的值 `value`
        super().__setitem__(key, value)
        # 避免递归错误,不调用 `self.__setattr__` 方法
        super().__setattr__(key, value)

    def __reduce__(self):
        if not is_dataclass(self):
            # 如果对象不是数据类,则调用父类的 `__reduce__` 方法
            return super().__reduce__()
        # 否则,获取对象所有非 `None` 属性或键的元组,并返回
        callable, _args, *remaining = super().__reduce__()
        args = tuple(getattr(self, field.name) for field in fields(self))
        return callable, args, *remaining

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.
        """
        # 返回包含所有非 `None` 属性或键的元组
        return tuple(self[k] for k in self.keys())
# 检查是否安装了 Torch
if is_torch_available():
    # 导入 Torch 的私有模块 _pytree
    import torch.utils._pytree as _torch_pytree

    # 将模型输出展平化的函数,返回值和上下文信息
    def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
        return list(output.values()), list(output.keys())

    # 将模型输出还原为原始结构的函数
    def _model_output_unflatten(
        values: Iterable[Any],
        context: "_torch_pytree.Context",
        output_type=None,
    ) -> ModelOutput:
        return output_type(**dict(zip(context, values)))

    # 如果 Torch 的版本大于等于 2.2,则注册 PyTree 节点
    if version.parse(get_torch_version()) >= version.parse("2.2"):
        _torch_pytree.register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
            serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
        )
    else:
        # 否则使用旧的注册方式
        _torch_pytree._register_pytree_node(
            ModelOutput,
            _model_output_flatten,
            partial(_model_output_unflatten, output_type=ModelOutput),
        )


# 定义一个显式枚举类 ExplicitEnum,继承自 str 和 Enum
class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """

    # 当枚举值缺失时,提供更明确的错误消息
    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


# 定义一个填充策略枚举类 PaddingStrategy,继承自 ExplicitEnum
class PaddingStrategy(ExplicitEnum):
    """
    Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
    IDE.
    """

    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"


# 定义一个张量类型枚举类 TensorType,继承自 ExplicitEnum
class TensorType(ExplicitEnum):
    """
    Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
    tab-completion in an IDE.
    """

    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
    MLX = "mlx"


# 定义一个上下文管理器类 ContextManagers
class ContextManagers:
    """
    Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
    in the `fastcore` library.
    """

    # 初始化方法,接受一个上下文管理器列表作为参数
    def __init__(self, context_managers: List[ContextManager]):
        self.context_managers = context_managers
        self.stack = ExitStack()  # 使用 contextlib.ExitStack 创建堆栈

    # 进入上下文管理器的方法
    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    # 退出上下文管理器的方法
    def __exit__(self, *args, **kwargs):
        self.stack.__exit__(*args, **kwargs)


# 定义一个函数,检查给定的模型类是否能返回损失值
def can_return_loss(model_class):
    """
    Check if a given model can return loss.

    Args:
        model_class (`type`): The class of the model.
    """
    framework = infer_framework(model_class)  # 推断模型所属的框架
    if framework == "tf":
        signature = inspect.signature(model_class.call)  # TensorFlow 模型
    elif framework == "pt":
        signature = inspect.signature(model_class.forward)  # PyTorch 模型
    else:
        signature = inspect.signature(model_class.__call__)  # Flax 模型
    # 遍历函数签名的参数列表
    for p in signature.parameters:
        # 检查当前参数是否为 "return_loss",且其默认值为 True
        if p == "return_loss" and signature.parameters[p].default is True:
            # 如果满足条件,返回 True
            return True
    
    # 如果未找到符合条件的参数,返回 False
    return False
# 查找给定模型使用的标签参数列表
def find_labels(model_class):
    model_name = model_class.__name__  # 获取模型类的名称
    framework = infer_framework(model_class)  # 推断模型使用的框架
    if framework == "tf":
        signature = inspect.signature(model_class.call)  # 获取TensorFlow模型的调用签名
    elif framework == "pt":
        signature = inspect.signature(model_class.forward)  # 获取PyTorch模型的前向方法签名
    else:
        signature = inspect.signature(model_class.__call__)  # 获取Flax模型的调用方法签名

    if "QuestionAnswering" in model_name:  # 如果模型名称中包含"QuestionAnswering"
        return [p for p in signature.parameters if "label" in p or p in ("start_positions", "end_positions")]  # 返回标签相关的参数列表
    else:
        return [p for p in signature.parameters if "label" in p]  # 返回标签相关的参数列表


# 将嵌套字典展开为单层字典
def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
    def _flatten_dict(d, parent_key="", delimiter="."):
        for k, v in d.items():
            key = str(parent_key) + delimiter + str(k) if parent_key else k
            if v and isinstance(v, MutableMapping):
                yield from flatten_dict(v, key, delimiter=delimiter).items()  # 递归展开嵌套字典
            else:
                yield key, v  # 直接添加键值对到展开的字典中

    return dict(_flatten_dict(d, parent_key, delimiter))


# 提供工作目录或临时目录的上下文管理器
@contextmanager
def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
    if use_temp_dir:
        with tempfile.TemporaryDirectory() as tmp_dir:
            yield tmp_dir  # 使用临时目录作为上下文环境
    else:
        yield working_dir  # 使用指定的工作目录作为上下文环境


# 框架无关的数组转置函数,支持numpy、torch、tensorflow和jax的数组
def transpose(array, axes=None):
    if is_numpy_array(array):  # 如果是numpy数组
        return np.transpose(array, axes=axes)  # 使用numpy的转置函数
    elif is_torch_tensor(array):  # 如果是torch张量
        return array.T if axes is None else array.permute(*axes)  # 使用torch的转置或者按指定轴排列
    elif is_tf_tensor(array):  # 如果是tensorflow张量
        import tensorflow as tf
        return tf.transpose(array, perm=axes)  # 使用tensorflow的转置函数
    elif is_jax_tensor(array):  # 如果是jax张量
        return jnp.transpose(array, axes=axes)  # 使用jax的转置函数
    else:
        raise ValueError(f"Type not supported for transpose: {type(array)}.")  # 抛出类型不支持的异常


# 框架无关的数组重塑函数,支持numpy、torch、tensorflow和jax的数组
def reshape(array, newshape):
    if is_numpy_array(array):  # 如果是numpy数组
        return np.reshape(array, newshape)  # 使用numpy的重塑函数
    elif is_torch_tensor(array):  # 如果是torch张量
        return array.reshape(*newshape)  # 使用torch的重塑方法
    elif is_tf_tensor(array):  # 如果是tensorflow张量
        import tensorflow as tf
        return tf.reshape(array, newshape)  # 使用tensorflow的重塑函数
    elif is_jax_tensor(array):  # 如果是jax张量
        return jnp.reshape(array, newshape)  # 使用jax的重塑函数
    else:
        raise ValueError(f"Type not supported for reshape: {type(array)}.")  # 抛出类型不支持的异常


# 框架无关的数组挤压函数,支持numpy、torch、tensorflow和jax的数组
def squeeze(array, axis=None):
    if is_numpy_array(array):  # 如果是numpy数组
        return np.squeeze(array, axis=axis)  # 使用numpy的挤压函数
    # 如果输入的数组是 PyTorch 张量,则进行挤压操作,去除维度为1的轴
    elif is_torch_tensor(array):
        return array.squeeze() if axis is None else array.squeeze(dim=axis)
    # 如果输入的数组是 TensorFlow 张量,则导入 TensorFlow 库并进行挤压操作,去除指定的轴
    elif is_tf_tensor(array):
        import tensorflow as tf

        return tf.squeeze(array, axis=axis)
    # 如果输入的数组是 JAX 张量,则进行挤压操作,去除指定的轴
    elif is_jax_tensor(array):
        return jnp.squeeze(array, axis=axis)
    # 如果输入的数组类型不被支持,则抛出异常并显示错误信息
    else:
        raise ValueError(f"Type not supported for squeeze: {type(array)}.")
# 定义一个函数,用于在不同深度学习框架下扩展张量的维度
def expand_dims(array, axis):
    """
    Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
    arrays.
    """
    # 如果输入数组是 NumPy 数组,则使用 NumPy 的 `expand_dims` 函数
    if is_numpy_array(array):
        return np.expand_dims(array, axis)
    # 如果输入数组是 PyTorch 张量,则使用 PyTorch 的 `unsqueeze` 函数
    elif is_torch_tensor(array):
        return array.unsqueeze(dim=axis)
    # 如果输入数组是 TensorFlow 张量,则使用 TensorFlow 的 `expand_dims` 函数
    elif is_tf_tensor(array):
        import tensorflow as tf
        
        return tf.expand_dims(array, axis=axis)
    # 如果输入数组是 Jax 张量,则使用 Jax 的 `expand_dims` 函数
    elif is_jax_tensor(array):
        return jnp.expand_dims(array, axis=axis)
    else:
        # 如果输入数组类型不被支持,则抛出 ValueError 异常
        raise ValueError(f"Type not supported for expand_dims: {type(array)}.")


# 定义一个函数,用于计算不同深度学习框架下张量的大小
def tensor_size(array):
    """
    Framework-agnostic version of `numpy.size` that will work on torch/TensorFlow/Jax tensors as well as NumPy arrays.
    """
    # 如果输入数组是 NumPy 数组,则返回数组的大小
    if is_numpy_array(array):
        return np.size(array)
    # 如果输入数组是 PyTorch 张量,则返回张量的元素个数
    elif is_torch_tensor(array):
        return array.numel()
    # 如果输入数组是 TensorFlow 张量,则返回张量的大小
    elif is_tf_tensor(array):
        import tensorflow as tf
        
        return tf.size(array)
    # 如果输入数组是 Jax 张量,则返回张量的大小
    elif is_jax_tensor(array):
        return array.size
    else:
        # 如果输入数组类型不被支持,则抛出 ValueError 异常
        raise ValueError(f"Type not supported for tensor_size: {type(array)}.")


# 定义一个函数,将 repo_id 的信息添加到给定的自动映射 auto_map 中
def add_model_info_to_auto_map(auto_map, repo_id):
    """
    Adds the information of the repo_id to a given auto map.
    """
    # 遍历 auto_map 的键值对
    for key, value in auto_map.items():
        # 如果值是列表或元组,则将每个元素前添加 repo_id,避免重复添加
        if isinstance(value, (tuple, list)):
            auto_map[key] = [f"{repo_id}--{v}" if (v is not None and "--" not in v) else v for v in value]
        # 如果值不是 None 且不包含 "--",则在值前添加 repo_id
        elif value is not None and "--" not in value:
            auto_map[key] = f"{repo_id}--{value}"

    # 返回更新后的 auto_map
    return auto_map


# 定义一个函数,推断给定模型类的深度学习框架
def infer_framework(model_class):
    """
    Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
    classes are imported or available.
    """
    # 遍历模型类的方法解析顺序(Method Resolution Order)
    for base_class in inspect.getmro(model_class):
        module = base_class.__module__
        name = base_class.__name__
        # 如果基类模块名以 "tensorflow" 或 "keras" 开头,或者基类名为 "TFPreTrainedModel",则推断为 TensorFlow 框架
        if module.startswith("tensorflow") or module.startswith("keras") or name == "TFPreTrainedModel":
            return "tf"
        # 如果基类模块名以 "torch" 开头,或者基类名为 "PreTrainedModel",则推断为 PyTorch 框架
        elif module.startswith("torch") or name == "PreTrainedModel":
            return "pt"
        # 如果基类模块名以 "flax" 或 "jax" 开头,或者基类名为 "FlaxPreTrainedModel",则推断为 Jax/Flax 框架
        elif module.startswith("flax") or module.startswith("jax") or name == "FlaxPreTrainedModel":
            return "flax"
    else:
        # 如果无法推断出框架,则抛出 TypeError 异常
        raise TypeError(f"Could not infer framework from class {model_class}.")

.\utils\hp_naming.py

    # 复制标准库 copy 和 re
    import copy
    import re

# 试验短命名器类
class TrialShortNamer:
    # 类变量 PREFIX 初始化为 "hp"
    PREFIX = "hp"
    # 类变量 DEFAULTS 初始化为空字典
    DEFAULTS = {}
    # 类变量 NAMING_INFO 初始化为 None
    NAMING_INFO = None

    # 类方法,设置类变量 PREFIX 和 DEFAULTS,并调用 build_naming_info 方法
    @classmethod
    def set_defaults(cls, prefix, defaults):
        cls.PREFIX = prefix
        cls.DEFAULTS = defaults
        cls.build_naming_info()

    # 静态方法,为单词生成短名称
    @staticmethod
    def shortname_for_word(info, word):
        # 如果单词长度为0,返回空字符串
        if len(word) == 0:
            return ""
        # 初始化 short_word 为 None
        short_word = None
        # 如果单词中包含数字,抛出异常
        if any(char.isdigit() for char in word):
            raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
        # 如果单词已经在 info 的 "short_word" 中,直接返回其短名称
        if word in info["short_word"]:
            return info["short_word"][word]
        # 尝试生成单词的短前缀,避免与已有的短名称冲突
        for prefix_len in range(1, len(word) + 1):
            prefix = word[:prefix_len]
            if prefix in info["reverse_short_word"]:
                continue
            else:
                short_word = prefix
                break

        # 如果未能生成短前缀,则采用备用方法生成唯一的短名称
        if short_word is None:
            # 备用方法:将数字转换为字母
            def int_to_alphabetic(integer):
                s = ""
                while integer != 0:
                    s = chr(ord("A") + integer % 10) + s
                    integer //= 10
                return s

            i = 0
            while True:
                sword = word + "#" + int_to_alphabetic(i)
                if sword in info["reverse_short_word"]:
                    continue
                else:
                    short_word = sword
                    break

        # 将生成的短名称存储在 info 中,并更新反向映射
        info["short_word"][word] = short_word
        info["reverse_short_word"][short_word] = word
        return short_word

    # 静态方法,为参数名生成短名称
    @staticmethod
    def shortname_for_key(info, param_name):
        # 将参数名分割成单词列表
        words = param_name.split("_")

        # 为每个单词生成短名称部分
        shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]

        # 尝试创建无分隔符的短名称,若存在冲突则使用分隔符分隔单词
        separators = ["", "_"]

        for separator in separators:
            shortname = separator.join(shortname_parts)
            if shortname not in info["reverse_short_param"]:
                info["short_param"][param_name] = shortname
                info["reverse_short_param"][shortname] = param_name
                return shortname

        # 如果无法避免冲突,则返回原参数名
        return param_name

    @staticmethod
    def add_new_param_name(info, param_name):
        # 使用 TrialShortNamer 提供的方法生成 param_name 的短名称,并添加到 info 字典中
        short_name = TrialShortNamer.shortname_for_key(info, param_name)
        # 将 param_name 和其对应的 short_name 存储在 info 字典的 "short_param" 和 "reverse_short_param" 中
        info["short_param"][param_name] = short_name
        info["reverse_short_param"][short_name] = param_name

    @classmethod
    def build_naming_info(cls):
        # 如果 NAMING_INFO 已经存在,则直接返回,避免重复构建
        if cls.NAMING_INFO is not None:
            return

        # 初始化一个空的命名信息字典
        info = {
            "short_word": {},
            "reverse_short_word": {},
            "short_param": {},
            "reverse_short_param": {},
        }

        # 获取类的默认参数列表
        field_keys = list(cls.DEFAULTS.keys())

        # 为每个参数调用 add_new_param_name 方法,构建参数名和短名称的映射关系
        for k in field_keys:
            cls.add_new_param_name(info, k)

        # 将构建好的命名信息保存到类的 NAMING_INFO 属性中
        cls.NAMING_INFO = info

    @classmethod
    def shortname(cls, params):
        # 确保命名信息已经构建
        cls.build_naming_info()
        # 断言类的 PREFIX 属性不为空
        assert cls.PREFIX is not None
        # 创建一个名称列表,起始部分为类的 PREFIX 属性的拷贝
        name = [copy.copy(cls.PREFIX)]

        # 遍历传入的参数字典
        for k, v in params.items():
            # 如果参数 k 不在默认参数列表中,则抛出异常
            if k not in cls.DEFAULTS:
                raise Exception(f"You should provide a default value for the param name {k} with value {v}")
            # 如果参数 v 等于默认值,则不将其添加到名称中
            if v == cls.DEFAULTS[k]:
                continue

            # 根据参数名 k 获取其短名称
            key = cls.NAMING_INFO["short_param"][k]

            # 如果参数值是布尔类型,则转换为整数形式
            if isinstance(v, bool):
                v = 1 if v else 0

            # 确定连接参数名和参数值的分隔符
            sep = "" if isinstance(v, (int, float)) else "-"
            # 构建参数名和参数值的字符串表示,并添加到名称列表中
            e = f"{key}{sep}{v}"
            name.append(e)

        # 返回连接后的名称字符串,使用下划线连接各部分
        return "_".join(name)

    @classmethod
    def parse_repr(cls, repr):
        # 截取 repr 字符串,去除前缀部分,得到实际的参数表示部分
        repr = repr[len(cls.PREFIX) + 1 :]
        # 如果 repr 为空字符串,则初始化值列表为空
        if repr == "":
            values = []
        else:
            # 否则,按下划线分割 repr 字符串,得到值列表
            values = repr.split("_")

        # 初始化参数字典
        parameters = {}

        # 遍历值列表中的每个值
        for value in values:
            # 如果值中包含 "-" 符号,则按照该符号分割键和值
            if "-" in value:
                p_k, p_v = value.split("-")
            else:
                # 否则,提取键部分,去除数字和小数点,并转换为字符串
                p_k = re.sub("[0-9.]", "", value)
                # 提取值部分,去除非数字和小数点字符,并转换为浮点数
                p_v = float(re.sub("[^0-9.]", "", value))

            # 根据短参数名 p_k 获取原始参数名,并将其与值 p_v 存储到参数字典中
            key = cls.NAMING_INFO["reverse_short_param"][p_k]
            parameters[key] = p_v

        # 对于每个默认参数,如果其在参数字典中不存在,则添加默认值
        for k in cls.DEFAULTS:
            if k not in parameters:
                parameters[k] = cls.DEFAULTS[k]

        # 返回解析得到的参数字典
        return parameters

.\utils\hub.py

# 标准版权声明,声明此代码版权归 HuggingFace 团队所有
#
# 根据 Apache License, Version 2.0 许可证进行许可,除非符合许可证要求,否则不得使用此文件
#
# 导入必要的库和模块
import json  # 导入处理 JSON 的模块
import os  # 导入操作系统相关功能的模块
import re  # 导入正则表达式模块
import shutil  # 导入文件操作相关模块
import sys  # 导入系统相关的模块
import tempfile  # 导入临时文件目录相关模块
import traceback  # 导入追踪异常的模块
import warnings  # 导入警告处理模块
from concurrent import futures  # 导入并发处理模块
from pathlib import Path  # 导入处理路径相关功能的模块
from typing import Dict, List, Optional, Tuple, Union  # 导入类型提示相关模块
from urllib.parse import urlparse  # 导入处理 URL 解析的模块
from uuid import uuid4  # 导入生成 UUID 的模块

import huggingface_hub  # 导入 HuggingFace Hub 库
import requests  # 导入处理 HTTP 请求的模块
from huggingface_hub import (
    _CACHED_NO_EXIST,
    CommitOperationAdd,
    ModelCard,
    ModelCardData,
    constants,
    create_branch,
    create_commit,
    create_repo,
    get_hf_file_metadata,
    hf_hub_download,
    hf_hub_url,
    try_to_load_from_cache,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get  # 导入文件下载相关模块
from huggingface_hub.utils import (
    EntryNotFoundError,
    GatedRepoError,
    HFValidationError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    build_hf_headers,
    hf_raise_for_status,
    send_telemetry,
)
from huggingface_hub.utils._deprecation import _deprecate_method  # 导入废弃方法相关模块
from requests.exceptions import HTTPError  # 导入处理 HTTP 错误的模块

from . import __version__, logging  # 导入当前模块的版本和日志模块
from .generic import working_or_temp_dir  # 导入通用功能的临时工作目录处理模块
from .import_utils import (
    ENV_VARS_TRUE_VALUES,
    _tf_version,
    _torch_version,
    is_tf_available,
    is_torch_available,
    is_training_run_on_sagemaker,
)  # 导入导入相关的实用工具模块

from .logging import tqdm  # 导入显示进度条的日志模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,命名为 logger,禁止 pylint 提示

_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False


def is_offline_mode():
    # 返回当前是否处于离线模式的布尔值
    return _is_offline_mode


torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
default_cache_path = constants.default_cache_path  # 获取默认缓存路径
old_default_cache_path = os.path.join(torch_cache_home, "transformers")

# 确定默认缓存目录。为了向后兼容性,考虑了大量遗留环境变量。
# 最佳设置缓存路径的方式是使用环境变量 HF_HOME。详情请查看文档页:https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
#
# 在代码中,使用 `HF_HUB_CACHE` 作为默认缓存路径。这个变量由库设置,保证设置为正确的值。
#
# TODO: 为 v5 版本进行清理?
# 设置用于缓存预训练 BERT 模型的环境变量,如果未设置则使用默认值 constants.HF_HUB_CACHE
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
# 设置用于缓存 PyTorch Transformers 的环境变量,如果未设置则使用上面设置的 PYTORCH_PRETRAINED_BERT_CACHE
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
# 设置用于缓存 Transformers 的环境变量,如果未设置则使用上面设置的 PYTORCH_TRANSFORMERS_CACHE
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# 如果旧的默认缓存路径存在,并且新的缓存路径 constants.HF_HUB_CACHE 不存在,并且没有设置相关环境变量,则执行一次性迁移
if (
    os.path.isdir(old_default_cache_path)
    and not os.path.isdir(constants.HF_HUB_CACHE)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
    # 发出警告说明在 Transformers v4.22.0 中,默认的模型下载缓存路径从 '~/.cache/torch/transformers' 变更为 '~/.cache/huggingface/hub'
    # 由于当前 '~/.cache/torch/transformers' 存在且未被覆盖,所以执行移动操作到 '~/.cache/huggingface/hub',避免重复下载已缓存的模型
    logger.warning(
        "In Transformers v4.22.0, the default path to cache downloaded models changed from"
        " '~/.cache/torch/transformers' to '~/.cache/huggingface/hub'. Since you don't seem to have"
        " overridden and '~/.cache/torch/transformers' is a directory that exists, we're moving it to"
        " '~/.cache/huggingface/hub' to avoid redownloading models you have already in the cache. You should"
        " only see this message once."
    )
    # 执行实际的文件移动操作
    shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)

# 设置用于缓存 HF 模块的环境变量,默认路径为 constants.HF_HOME 下的 modules 文件夹
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
# 定义 Transformers 动态模块的名称
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
# 生成一个新的会话 ID,用于标识当前会话的唯一性
SESSION_ID = uuid4().hex

# 对旧的环境变量进行弃用警告
for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
    if os.getenv(key) is not None:
        # 发出警告,提示使用这些环境变量已经被弃用,建议在 Transformers v5 中使用 `HF_HOME` 替代
        warnings.warn(
            f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
            FutureWarning,
        )

# 定义 Hugging Face 模型存储在 S3 桶中的前缀
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
# 定义 Hugging Face CDN 的分发前缀
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"

# 检查是否处于 Hugging Face CO 的测试阶段,根据环境变量 HUGGINGFACE_CO_STAGING 来判断
_staging_mode = os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
# 根据是否处于测试阶段选择默认的 API 终端地址
_default_endpoint = "https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co"

# 设置用于解析模型的终端地址,默认为 _default_endpoint
HUGGINGFACE_CO_RESOLVE_ENDPOINT = _default_endpoint
# 如果设置了环境变量 HUGGINGFACE_CO_RESOLVE_ENDPOINT,则发出警告提示使用 HF_ENDPOINT 替代
if os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None) is not None:
    warnings.warn(
        "Using the environment variable `HUGGINGFACE_CO_RESOLVE_ENDPOINT` is deprecated and will be removed in "
        "Transformers v5. Use `HF_ENDPOINT` instead.",
        FutureWarning,
    )
    # 使用环境变量 HF_ENDPOINT 的值来更新 HUGGINGFACE_CO_RESOLVE_ENDPOINT
    HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOINT", None)
# 使用环境变量 HF_ENDPOINT 的值来更新 HUGGINGFACE_CO_RESOLVE_ENDPOINT
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_RESOLVE_ENDPOINT)
# 构建 Hugging Face CO 的模型解析路径模板
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
# 构建用于上报示例数据的 Telemetry API 地址
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"


def _get_cache_file_to_return(
    path_or_repo_id: str, full_filename: str, cache_dir: Union[str, Path, None] = None, revision: Optional[str] = None
):
    # 尝试查找缓存文件,如果存在且未过期,则返回该文件(未完成的部分)
    # 定义函数:尝试从缓存中加载文件并返回路径,若找不到或不存在,则返回None
    def try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=None, revision=None):
        # 调用尝试从缓存中直接加载文件路径的内部函数(未提供,需根据实际情况补充实现)
        resolved_file = try_to_load_from_cache_inner(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
        # 若返回的路径不为None且不是缓存不存在的标志(_CACHED_NO_EXIST),说明找到了正确路径
        if resolved_file is not None and resolved_file != _CACHED_NO_EXIST:
            # 返回找到的文件路径
            return resolved_file
        # 如果上述条件均不满足,说明未找到文件或存在异常情况,返回None
        return None
    
    
    # 注意事项中未提供try_to_load_from_cache_inner的具体实现,以下是假设实现供参考
    
    
    def try_to_load_from_cache_inner(path_or_repo_id, full_filename, cache_dir=None, revision=None):
        # 数据处理逻辑:尝试使用给定的repo_id和full_filename从缓存目录cache_dir中加载文件路径
        cache_file_path = None
        if cache_dir is not None:
            # 构建缓存文件路径
            cache_file_path = os.path.join(cache_dir, full_filename)
            # 检查缓存文件是否存在
            if os.path.exists(cache_file_path):
                # 省略读取缓存、验证等具体逻辑,假设通过验证后返回有效文件路径
                # 这里返回一个伪有效路径,代表实际应从缓存中加载文件路径的具体处理
                # 在实际应用中应增强逻辑以正确处理缓存文件的读取和验证
                return os.fspath(cache_file_path)
            else:
                # 文件不存在于缓存目录,不返回任何值表示未找到缓存文件
                return None
        else:
            # 缓存目录未指定,返回None表示不尝试从缓存加载文件
            return None
# 检查给定的 URL 或文件名是否是远程 URL
def is_remote_url(url_or_filename):
    # 解析 URL 或文件名,获取其组成部分
    parsed = urlparse(url_or_filename)
    # 判断解析结果中的 scheme 是否为 http 或 https
    return parsed.scheme in ("http", "https")


# TODO: 在完全弃用后删除此函数
# TODO? 同时也要从 './examples/research_projects/lxmert/utils.py' 中移除
# TODO? 同时也要从 './examples/research_projects/visual_bert/utils.py' 中移除
@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
    返回一个列表,表示本地缓存的模型二进制文件。每个元组的形式为 `(model_url, etag, size_MB)`。
    只有以 *.bin* 结尾的 URL 文件名会被添加到列表中。

    Args:
        cache_dir (`Union[str, Path]`, *optional*):
            要在其中搜索模型的缓存目录。如果未设置,将默认使用 transformers 的缓存目录。

    Returns:
        List[Tuple]: 包含 `(model_url, etag, size_MB)` 形式的元组列表
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    # 如果缓存目录不存在,返回空列表
    if not os.path.isdir(cache_dir):
        return []

    cached_models = []
    # 遍历缓存目录中的文件
    for file in os.listdir(cache_dir):
        if file.endswith(".json"):
            meta_path = os.path.join(cache_dir, file)
            # 打开元数据文件并加载 JSON 数据
            with open(meta_path, encoding="utf-8") as meta_file:
                metadata = json.load(meta_file)
                url = metadata["url"]
                etag = metadata["etag"]
                # 如果 URL 以 .bin 结尾,计算文件大小并添加到列表中
                if url.endswith(".bin"):
                    size_MB = os.path.getsize(meta_path.strip(".json")) / 1e6
                    cached_models.append((url, etag, size_MB))

    return cached_models


def define_sagemaker_information():
    try:
        # 获取当前实例的容器元数据
        instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
        dlc_container_used = instance_data["Image"]
        dlc_tag = instance_data["Image"].split(":")[1]
    except Exception:
        # 如果获取失败,设置为 None
        dlc_container_used = None
        dlc_tag = None

    # 解析 SageMaker 框架的参数
    sagemaker_params = json.loads(os.getenv("SM_FRAMEWORK_PARAMS", "{}"))
    # 检查是否启用了 SageMaker 的分布式训练
    runs_distributed_training = True if "sagemaker_distributed_dataparallel_enabled" in sagemaker_params else False
    # 从环境变量中提取账户 ID
    account_id = os.getenv("TRAINING_JOB_ARN").split(":")[4] if "TRAINING_JOB_ARN" in os.environ else None

    # 构建包含 SageMaker 相关信息的字典对象
    sagemaker_object = {
        "sm_framework": os.getenv("SM_FRAMEWORK_MODULE", None),
        "sm_region": os.getenv("AWS_REGION", None),
        "sm_number_gpu": os.getenv("SM_NUM_GPUS", 0),
        "sm_number_cpu": os.getenv("SM_NUM_CPUS", 0),
        "sm_distributed_training": runs_distributed_training,
        "sm_deep_learning_container": dlc_container_used,
        "sm_deep_learning_container_tag": dlc_tag,
        "sm_account_id": account_id,
    }
    return sagemaker_object


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    # 格式化用户代理字符串,包含请求的基本信息
    """
    ua = f"transformers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
    if is_torch_available():
        ua += f"; torch/{_torch_version}"
    if is_tf_available():
        ua += f"; tensorflow/{_tf_version}"
    if constants.HF_HUB_DISABLE_TELEMETRY:
        return ua + "; telemetry/off"
    if is_training_run_on_sagemaker():
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in define_sagemaker_information().items())
    # CI will set this value to True
    if os.environ.get("TRANSFORMERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        ua += "; is_ci/true"
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    返回格式化后的用户代理字符串
    """
# 从已解析的文件名中提取提交哈希值,并用于缓存文件。
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
    # 如果 resolved_file 为 None 或者 commit_hash 不为 None,则直接返回 commit_hash
    if resolved_file is None or commit_hash is not None:
        return commit_hash
    
    # 将 resolved_file 转换为标准的 POSIX 路径字符串
    resolved_file = str(Path(resolved_file).as_posix())
    
    # 使用正则表达式在 resolved_file 中搜索匹配 'snapshots/([^/]+)/' 的内容
    search = re.search(r"snapshots/([^/]+)/", resolved_file)
    
    # 如果未找到匹配项,则返回 None
    if search is None:
        return None
    
    # 从搜索结果中获取第一个捕获组,即提取的 commit_hash
    commit_hash = search.groups()[0]
    
    # 如果提取的 commit_hash 符合预期的格式(由 REGEX_COMMIT_HASH 定义),则返回 commit_hash,否则返回 None
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None


# 尝试在本地文件夹或存储库中定位文件,如果必要则下载并缓存它。
def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    repo_type: Optional[str] = None,
    user_agent: Optional[Union[str, Dict[str, str]]] = None,
    _raise_exceptions_for_gated_repo: bool = True,
    _raise_exceptions_for_missing_entries: bool = True,
    _raise_exceptions_for_connection_errors: bool = True,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> Optional[str]:
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
    """
    # 获取 deprecated_kwargs 字典中的 use_auth_token 键对应的值,并将其从字典中移除
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    # 如果 use_auth_token 参数不为 None,则发出警告信息,说明该参数已弃用,并将在 Transformers 版本 v5 中移除。建议使用 `token` 参数替代。
    # 引发 FutureWarning 警告。
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        # 如果同时指定了 token 参数和 use_auth_token 参数,则抛出 ValueError 异常。
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        # 将 token 参数设置为 use_auth_token 参数的值。
        token = use_auth_token

    # Private arguments
    #     _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
    #         None.
    #     _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
    #         None.
    #     _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
    #         None.
    #     _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
    #         a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.

    # 如果处于离线模式且 local_files_only 参数为 False,则设置 local_files_only 参数为 True,并输出相应的日志信息。
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # 如果 subfolder 参数为 None,则将其设置为空字符串。
    if subfolder is None:
        subfolder = ""

    # 将 path_or_repo_id 参数转换为字符串。
    path_or_repo_id = str(path_or_repo_id)
    # 将 subfolder 和 filename 参数拼接成完整的文件路径。
    full_filename = os.path.join(subfolder, filename)

    # 如果 path_or_repo_id 参数指定的路径是一个目录,则解析文件路径并检查文件是否存在。
    if os.path.isdir(path_or_repo_id):
        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
        # 如果解析后的文件路径不是一个文件,并且 _raise_exceptions_for_missing_entries 参数为 True,则抛出 EnvironmentError 异常。
        if not os.path.isfile(resolved_file):
            if _raise_exceptions_for_missing_entries:
                raise EnvironmentError(
                    f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
                    f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
                )
            # 如果 _raise_exceptions_for_missing_entries 参数为 False,则返回 None。
            else:
                return None
        # 返回解析后的文件路径。
        return resolved_file

    # 如果 cache_dir 参数为 None,则将其设置为 TRANSFORMERS_CACHE 变量的值。
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE

    # 如果 cache_dir 参数是 Path 对象,则将其转换为字符串。
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    # 如果 _commit_hash 参数不为 None 并且 force_download 参数为 False,则尝试从缓存中加载文件。
    if _commit_hash is not None and not force_download:
        # 如果文件在指定的 _commit_hash 下被缓存,则直接返回该文件。
        resolved_file = try_to_load_from_cache(
            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )
        # 如果成功加载文件,则根据情况返回解析后的文件路径、None 或抛出异常。
        if resolved_file is not None:
            if resolved_file is not _CACHED_NO_EXIST:
                return resolved_file
            elif not _raise_exceptions_for_missing_entries:
                return None
            else:
                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")

    # 调用 http_user_agent 函数来处理 user_agent 参数。
    user_agent = http_user_agent(user_agent)
    try:
        # 尝试从 URL 或缓存加载文件
        resolved_file = hf_hub_download(
            path_or_repo_id,
            filename,
            subfolder=None if len(subfolder) == 0 else subfolder,
            repo_type=repo_type,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            token=token,
            local_files_only=local_files_only,
        )
    except GatedRepoError as e:
        # 如果遇到受限制的仓库错误,则尝试从缓存中获取文件以返回
        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
        # 如果已获取文件或不应为受限制的仓库错误引发异常,则返回解析的文件
        if resolved_file is not None or not _raise_exceptions_for_gated_repo:
            return resolved_file
        # 否则,引发环境错误并显示详细信息
        raise EnvironmentError(
            "You are trying to access a gated repo.\nMake sure to have access to it at "
            f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
        ) from e
    except RepositoryNotFoundError as e:
        # 如果仓库未找到,则引发环境错误并显示详细信息
        raise EnvironmentError(
            f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
            "having permission to this repo either by logging in with `huggingface-cli login` or by passing "
            "`token=<your_token>`"
        ) from e
    except RevisionNotFoundError as e:
        # 如果找不到指定的版本号,则引发环境错误并显示详细信息
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
            "for this model name. Check the model page at "
            f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
        ) from e
    except LocalEntryNotFoundError as e:
        # 如果本地条目未找到,则尝试从缓存获取文件以返回
        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
        # 如果已获取文件或不应为丢失条目或连接错误引发异常,则返回解析的文件
        if (
            resolved_file is not None
            or not _raise_exceptions_for_missing_entries
            or not _raise_exceptions_for_connection_errors
        ):
            return resolved_file
        # 否则,引发环境错误并显示详细信息
        raise EnvironmentError(
            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
            f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
            f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
            " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
        ) from e
    # 处理 EntryNotFoundError 异常,如果设置了不抛出缺失条目异常,则返回 None
    except EntryNotFoundError as e:
        if not _raise_exceptions_for_missing_entries:
            return None
        # 如果未指定修订版本,则默认为 "main"
        if revision is None:
            revision = "main"
        # 抛出环境错误,指示指定的路径或 repo_id 中不存在指定的完整文件名
        raise EnvironmentError(
            f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
            f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
        ) from e
    # 处理 HTTPError 异常
    except HTTPError as err:
        # 尝试获取缓存中已解决的文件,如果存在或设置了不抛出连接错误异常,则返回该文件
        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
        if resolved_file is not None or not _raise_exceptions_for_connection_errors:
            return resolved_file
        # 抛出环境错误,指示加载指定路径或 repo_id 时发生特定的连接错误
        raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
    # 处理 HFValidationError 异常
    except HFValidationError as e:
        # 抛出环境错误,指示路径或模型 ID 的提供方式不正确,应提供本地文件夹的路径或 Hub 上模型的 repo_id
        raise EnvironmentError(
            f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
        ) from e
    # 返回已解决的文件(如果有)
    return resolved_file
# TODO: deprecate `get_file_from_repo` or document it differently?
#       Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
#       there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
#       IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
# 定义了一个函数 `get_file_from_repo`,用于从本地文件夹或仓库中获取文件,并在需要时下载和缓存它。
def get_file_from_repo(
    path_or_repo: Union[str, os.PathLike],  # 参数1: 文件路径或仓库位置,可以是字符串或PathLike对象
    filename: str,                          # 参数2: 文件名,表示需要获取的文件名
    cache_dir: Optional[Union[str, os.PathLike]] = None,  # 参数3: 缓存目录的路径,可选,默认为None
    force_download: bool = False,           # 参数4: 是否强制下载文件,默认为False
    resume_download: bool = False,          # 参数5: 是否继续下载(即断点续传),默认为False
    proxies: Optional[Dict[str, str]] = None,  # 参数6: 代理设置,可选,默认为None
    token: Optional[Union[bool, str]] = None,   # 参数7: 访问令牌,可选,默认为None
    revision: Optional[str] = None,         # 参数8: 仓库的版本或分支,可选,默认为None
    local_files_only: bool = False,          # 参数9: 是否只使用本地文件,不从仓库下载,默认为False
    subfolder: str = "",                    # 参数10: 仓库中的子文件夹路径,默认为空字符串
    **deprecated_kwargs,                    # 其他已废弃的关键字参数将被收集到deprecated_kwargs中
):
    """
    Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
    """
    Args:
        path_or_repo (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
        file does not exist.

    Examples:

    ```
    # Download a tokenizer configuration from huggingface.co and cache.
    tokenizer_config = get_file_from_repo("google-bert/bert-base-uncased", "tokenizer_config.json")
    # This model does not have a tokenizer config so the result will be None.
    tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
    ```
    """
    # Check for deprecated argument and assign its value to use_auth_token
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    # 如果 use_auth_token 参数不为 None,则发出警告,指出该参数将在 Transformers 的 v5 版本中被移除,建议使用 token 参数。
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        # 如果 token 参数也不为 None,则抛出 ValueError,因为不能同时指定 `token` 和 `use_auth_token`。
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        # 将 use_auth_token 的值赋给 token 参数
        token = use_auth_token

    # 调用 cached_file 函数,传入各种参数,并返回其结果
    return cached_file(
        path_or_repo_id=path_or_repo,
        filename=filename,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
# 下载指定的 URL 对应的文件到临时文件中
def download_url(url, proxies=None):
    # 发出警告,提醒使用者此函数即将在 Transformers v5 版本中被移除
    warnings.warn(
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
        " multiple processes (each process will download the file in a different temporary file).",
        FutureWarning,
    )
    # 创建临时文件并返回其文件名
    tmp_fd, tmp_file = tempfile.mkstemp()
    with os.fdopen(tmp_fd, "wb") as f:
        # 使用 HTTP GET 方法下载 URL 对应的文件到临时文件中
        http_get(url, f, proxies=proxies)
    # 返回临时文件的路径
    return tmp_file


# 检查指定的仓库或路径是否包含指定的文件,支持远程仓库和本地文件夹
def has_file(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    revision: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    **deprecated_kwargs,
):
    # 如果使用的是已弃用的参数 `use_auth_token`,则发出警告并将其转换到 `token` 参数
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token

    # 如果路径为一个目录,则直接检查目录下是否有指定的文件
    if os.path.isdir(path_or_repo):
        return os.path.isfile(os.path.join(path_or_repo, filename))

    # 构建 Hub 的 URL,并获取相应的 headers
    url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
    headers = build_hf_headers(token=token, user_agent=http_user_agent())

    # 发送 HEAD 请求到指定的 URL,检查文件是否存在
    r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
    try:
        # 检查请求的状态,如果正常则返回 True,否则引发异常
        hf_raise_for_status(r)
        return True
    # 如果捕获到 GatedRepoError 异常,则记录错误信息并抛出 EnvironmentError 异常
    except GatedRepoError as e:
        logger.error(e)
        raise EnvironmentError(
            # 指定路径或资源 {path_or_repo} 是一个受保护的仓库。请确保在 'https://huggingface.co/{path_or_repo}' 请求访问权限,
            # 并通过 `huggingface-cli login` 登录或通过 `token=<your_token>` 传递具有访问权限的令牌。
            f"{path_or_repo} is a gated repository. Make sure to request access at "
            f"https://huggingface.co/{path_or_repo} and pass a token having permission to this repo either by "
            "logging in with `huggingface-cli login` or by passing `token=<your_token>`."
        ) from e
    # 如果捕获到 RepositoryNotFoundError 异常,则记录错误信息并抛出 EnvironmentError 异常
    except RepositoryNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            # 指定路径或资源 {path_or_repo} 不是本地文件夹或 'https://hf.co' 上的有效仓库名称。
            f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'."
        )
    # 如果捕获到 RevisionNotFoundError 异常,则记录错误信息并抛出 EnvironmentError 异常
    except RevisionNotFoundError as e:
        logger.error(e)
        raise EnvironmentError(
            # 指定的 {revision} 不是此模型名称存在的有效 git 标识符(分支名称、标签名称或提交 ID)。
            # 查看 'https://huggingface.co/{path_or_repo}' 上模型页面获取可用的修订版本。
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
            f"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    # 如果捕获到 requests.HTTPError 异常,则返回 False,处理 EntryNotFoundError 和所有连接错误
    except requests.HTTPError:
        return False
    def _create_repo(
        self,
        repo_id: str,
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        repo_url: Optional[str] = None,
        organization: Optional[str] = None,
    ) -> str:
        """
        创建仓库(如果需要),清理使用了已弃用参数 `repo_url` 和 `organization` 的 `repo_id`,并获取 token。
        """
        # 如果指定了 repo_url 参数,则发出警告并处理 repo_id
        if repo_url is not None:
            warnings.warn(
                "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
                "instead."
            )
            if repo_id is not None:
                raise ValueError(
                    "`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`."
                )
            # 根据约定的终结点,修改 repo_id
            repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
        
        # 如果指定了 organization 参数,则发出警告并调整 repo_id
        if organization is not None:
            warnings.warn(
                "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
                "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
            )
            # 如果 repo_id 不以 organization 开头,则进行调整
            if not repo_id.startswith(organization):
                if "/" in repo_id:
                    repo_id = repo_id.split("/")[-1]
                repo_id = f"{organization}/{repo_id}"

        # 调用 create_repo 函数创建仓库,并返回 repo_id
        url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
        return url.repo_id

    def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
        """
        返回工作目录下文件及其最后修改时间戳的字典。
        """
        # 遍历工作目录中的文件,获取它们的最后修改时间戳
        return {f: os.path.getmtime(os.path.join(working_dir, f)) for f in os.listdir(working_dir)}

    def _upload_modified_files(
        self,
        working_dir: Union[str, os.PathLike],
        repo_id: str,
        files_timestamps: Dict[str, float],
        commit_message: Optional[str] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
        revision: str = None,
        commit_description: str = None,
    ):
        """
        上传修改过的文件到指定的仓库,并支持创建 Pull Request 功能。
        """
    ):
        """
        Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
        """

        # Determine the commit message if not provided explicitly
        if commit_message is None:
            # Set default commit messages based on the class name
            if "Model" in self.__class__.__name__:
                commit_message = "Upload model"
            elif "Config" in self.__class__.__name__:
                commit_message = "Upload config"
            elif "Tokenizer" in self.__class__.__name__:
                commit_message = "Upload tokenizer"
            elif "FeatureExtractor" in self.__class__.__name__:
                commit_message = "Upload feature extractor"
            elif "Processor" in self.__class__.__name__:
                commit_message = "Upload processor"
            else:
                commit_message = f"Upload {self.__class__.__name__}"

        # Identify modified files based on timestamps and existence in `working_dir`
        modified_files = [
            f
            for f in os.listdir(working_dir)
            if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
        ]

        # Filter for actual files and folders at the root level of `working_dir`
        modified_files = [
            f
            for f in modified_files
            if os.path.isfile(os.path.join(working_dir, f)) or os.path.isdir(os.path.join(working_dir, f))
        ]

        operations = []

        # Upload individual files or files within directories
        for file in modified_files:
            if os.path.isdir(os.path.join(working_dir, file)):
                # Upload files within the directory individually
                for f in os.listdir(os.path.join(working_dir, file)):
                    operations.append(
                        CommitOperationAdd(
                            path_or_fileobj=os.path.join(working_dir, file, f), path_in_repo=os.path.join(file, f)
                        )
                    )
            else:
                # Upload standalone file
                operations.append(
                    CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
                )

        # Optionally create a new branch if `revision` is specified
        if revision is not None:
            create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)

        # Log the files being uploaded
        logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")

        # Create and return a commit with specified parameters
        return create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            token=token,
            create_pr=create_pr,
            revision=revision,
        )

    def push_to_hub(
        self,
        repo_id: str,
        use_temp_dir: Optional[bool] = None,
        commit_message: Optional[str] = None,
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        max_shard_size: Optional[Union[int, str]] = "5GB",
        create_pr: bool = False,
        safe_serialization: bool = True,
        revision: str = None,
        commit_description: str = None,
        tags: Optional[List[str]] = None,
        **deprecated_kwargs,
# 发送示例的遥测数据,用于跟踪示例的使用情况
def send_example_telemetry(example_name, *example_args, framework="pytorch"):
    """
    Sends telemetry that helps tracking the examples use.

    Args:
        example_name (`str`): The name of the example.
        *example_args (dataclasses or `argparse.ArgumentParser`): The arguments to the script. This function will only
            try to extract the model and dataset name from those. Nothing else is tracked.
        framework (`str`, *optional*, defaults to `"pytorch"`): The framework for the example.
    """
    # 如果处于离线模式,则直接返回
    if is_offline_mode():
        return

    # 准备遥测数据的基本信息
    data = {"example": example_name, "framework": framework}

    # 遍历传入的每组参数
    for args in example_args:
        # 将参数对象转换为字典,过滤掉以"_"开头的私有属性,并且值不为None的项
        args_as_dict = {k: v for k, v in args.__dict__.items() if not k.startswith("_") and v is not None}
        
        # 如果参数字典中包含模型名或路径
        if "model_name_or_path" in args_as_dict:
            model_name = args_as_dict["model_name_or_path"]
            # 如果模型名不是一个目录,则记录模型名
            if not os.path.isdir(model_name):
                data["model_name"] = args_as_dict["model_name_or_path"]
        
        # 如果参数字典中包含数据集名
        if "dataset_name" in args_as_dict:
            data["dataset_name"] = args_as_dict["dataset_name"]
        elif "task_name" in args_as_dict:
            # 从示例名中提取脚本名
            script_name = example_name.replace("tf_", "").replace("flax_", "").replace("run_", "")
            script_name = script_name.replace("_no_trainer", "")
            # 构建数据集名,由脚本名和任务名组成
            data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"

    # 在后台发送遥测数据
    send_telemetry(
        topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
    )


# 将文件大小转换为整数表示的字节数
def convert_file_size_to_int(size: Union[int, str]):
    """
    Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).

    Args:
        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.

    Example:
    ```
    >>> convert_file_size_to_int("1MiB")
    1048576
    ```
    """
    # 如果大小已经是整数类型,则直接返回
    if isinstance(size, int):
        return size
    
    # 根据大小字符串的单位后缀进行转换
    if size.upper().endswith("GIB"):
        return int(size[:-3]) * (2**30)
    if size.upper().endswith("MIB"):
        return int(size[:-3]) * (2**20)
    if size.upper().endswith("KIB"):
        return int(size[:-3]) * (2**10)
    if size.upper().endswith("GB"):
        int_size = int(size[:-2]) * (10**9)
        return int_size // 8 if size.endswith("b") else int_size
    if size.upper().endswith("MB"):
        int_size = int(size[:-2]) * (10**6)
        return int_size // 8 if size.endswith("b") else int_size
    if size.upper().endswith("KB"):
        int_size = int(size[:-2]) * (10**3)
        return int_size // 8 if size.endswith("b") else int_size
    
    # 若无法识别大小的格式,则抛出错误
    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")


# 获取检查点分片文件列表
def get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,  # 是否继续下载,默认为False,表示不继续下载
    local_files_only=False,  # 是否仅使用本地文件,默认为False,表示不仅使用本地文件
    token=None,  # 访问令牌,通常为None,表示未指定特定的访问令牌
    user_agent=None,  # 用户代理信息,通常为None,表示未指定特定的用户代理
    revision=None,  # 版本号,通常为None,表示未指定特定的版本号
    subfolder="",  # 子文件夹路径,默认为空字符串,表示没有特定的子文件夹
    _commit_hash=None,  # 提交哈希值,通常为None,表示未指定特定的提交哈希值
    **deprecated_kwargs,  # 其他过时的关键字参数,通过**kwargs收集
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - returns the list of paths to all the shards, as well as some metadata.

    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
    """
    # 导入 json 模块
    import json

    # 处理已弃用的参数 `use_auth_token`
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        # 引发未来警告,提醒用户 `use_auth_token` 参数将在 Transformers v5 版本中删除
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        # 如果同时指定了 `use_auth_token` 和 `token`,则引发值错误
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        # 将 `use_auth_token` 赋值给 `token`
        token = use_auth_token

    # 如果指定的 `index_filename` 不是文件,则引发值错误
    if not os.path.isfile(index_filename):
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    # 读取 `index_filename` 文件内容并解析为 JSON 格式,赋值给 `index` 变量
    with open(index_filename, "r") as f:
        index = json.loads(f.read())

    # 获取所有分片文件名并排序,去重
    shard_filenames = sorted(set(index["weight_map"].values()))
    # 获取分片元数据
    sharded_metadata = index["metadata"]
    # 添加额外的检查点键到 `sharded_metadata` 中
    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
    # 复制 `weight_map` 到 `sharded_metadata` 中的 `weight_map` 键
    sharded_metadata["weight_map"] = index["weight_map"].copy()

    # 首先处理本地文件夹的情况
    if os.path.isdir(pretrained_model_name_or_path):
        # 构建所有分片文件的完整路径
        shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
        # 返回分片文件名列表和分片元数据
        return shard_filenames, sharded_metadata

    # 如果代码执行到这里,说明 `pretrained_model_name_or_path` 是 Hub 上的模型标识符
    cached_filenames = []
    # 检查模型是否已经缓存。我们只尝试最后一个检查点,这应该涵盖大多数下载情况(如果下载被中断)。
    last_shard = try_to_load_from_cache(
        pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
    )
    # 如果 `last_shard` 为 None 或者强制下载被设置,则显示进度条
    show_progress_bar = last_shard is None or force_download
    # 遍历每个分片文件名列表,显示下载进度条,如果禁用了进度条则不显示
    for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
        try:
            # 尝试从URL加载文件到缓存,并返回缓存后的文件名
            cached_filename = cached_file(
                pretrained_model_name_or_path,  # 预训练模型名称或路径
                shard_filename,                 # 当前分片文件名
                cache_dir=cache_dir,            # 缓存目录路径
                force_download=force_download,  # 是否强制重新下载
                proxies=proxies,                # 代理设置
                resume_download=resume_download,  # 是否从上次中断处继续下载
                local_files_only=local_files_only,  # 仅使用本地文件
                token=token,                    # 访问令牌
                user_agent=user_agent,          # 用户代理
                revision=revision,              # 版本号
                subfolder=subfolder,            # 子文件夹
                _commit_hash=_commit_hash,      # 提交哈希值
            )
        # 已经在获取索引时处理了 RepositoryNotFoundError 和 RevisionNotFoundError,因此这里不需要捕获它们
        except EntryNotFoundError:
            # 如果缓存中找不到指定的文件,则抛出环境错误异常
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
                "required according to the checkpoint index."
            )
        except HTTPError:
            # 如果无法连接到指定的下载端点,抛出环境错误异常,提示检查网络连接后重试
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
                " again after checking your internet connection."
            )

        # 将缓存后的文件名添加到缓存文件名列表中
        cached_filenames.append(cached_filename)

    # 返回缓存文件名列表和分片元数据
    return cached_filenames, sharded_metadata
# 所有以下代码都是用于旧缓存格式和新缓存格式之间的转换。

# 返回所有已缓存文件的列表,包括相应的元数据
def get_all_cached_files(cache_dir=None):
    # 如果未指定缓存目录,则使用默认的 TRANSFORMERS_CACHE
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    else:
        # 将缓存目录转换为字符串形式
        cache_dir = str(cache_dir)
    # 如果缓存目录不存在,则返回空列表
    if not os.path.isdir(cache_dir):
        return []

    # 初始化一个空列表,用于存储已缓存的文件及其元数据
    cached_files = []
    # 遍历缓存目录下的所有文件
    for file in os.listdir(cache_dir):
        # 构建元数据文件的路径
        meta_path = os.path.join(cache_dir, f"{file}.json")
        # 如果元数据文件不存在,则跳过当前文件
        if not os.path.isfile(meta_path):
            continue

        # 打开元数据文件并加载其中的 JSON 数据
        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
            # 提取 URL 和 ETag,并移除双引号
            url = metadata["url"]
            etag = metadata["etag"].replace('"', "")
            # 将文件名、URL 和 ETag 组成字典,加入到 cached_files 列表中
            cached_files.append({"file": file, "url": url, "etag": etag})

    # 返回所有已缓存文件及其元数据的列表
    return cached_files


# 从 URL 中提取仓库名、版本和文件名
def extract_info_from_url(url):
    # 使用正则表达式从 URL 中提取仓库名、版本和文件名
    search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
    # 如果未能匹配到结果,则返回 None
    if search is None:
        return None
    # 提取仓库名、版本和文件名,并拼接成字典返回
    repo, revision, filename = search.groups()
    cache_repo = "--".join(["models"] + repo.split("/"))
    return {"repo": cache_repo, "revision": revision, "filename": filename}


# 创建或加载现有的模型卡,并添加标签
def create_and_tag_model_card(
    repo_id: str,
    tags: Optional[List[str]] = None,
    token: Optional[str] = None,
    ignore_metadata_errors: bool = False,
):
    try:
        # 尝试从远程仓库加载模型卡
        model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
    except EntryNotFoundError:
        # 如果未找到模型卡,则从模板创建一个简单的模型卡
        model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
        card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
        model_card = ModelCard.from_template(card_data, model_description=model_description)

    # 如果提供了标签列表,则逐个添加到模型卡的标签中
    if tags is not None:
        for model_tag in tags:
            if model_tag not in model_card.data.tags:
                model_card.data.tags.append(model_tag)

    # 返回创建或加载的模型卡对象
    return model_card


# 删除指定文件及其相关的文件(如文件的元数据文件和锁文件),如果存在的话
def clean_files_for(file):
    pass  # 这个函数定义了一个清理文件的方法,但在这里没有实现具体功能,只是占位用的
    # 对文件和其关联的 .json 和 .lock 文件进行循环操作
    for f in [file, f"{file}.json", f"{file}.lock"]:
        # 检查当前路径下是否存在指定的文件
        if os.path.isfile(f):
            # 如果存在,删除该文件
            os.remove(f)
# 创建一个函数,将文件移动到新的缓存组织中,按照新的 huggingface hub 缓存组织规则操作
def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
    # 确保目标 repo 目录存在,如果不存在则创建
    os.makedirs(repo, exist_ok=True)

    # refs 目录
    os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
    # 如果 revision 和 commit_hash 不相同,将 commit_hash 写入到相应的 ref 文件中
    if revision != commit_hash:
        ref_path = os.path.join(repo, "refs", revision)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # blobs 目录
    os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
    # 将文件移动到 blobs 目录下的以 etag 命名的文件中
    blob_path = os.path.join(repo, "blobs", etag)
    shutil.move(file, blob_path)

    # snapshots 目录
    os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
    os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
    # 在 snapshots 目录下的 commit_hash 子目录中创建文件名为 filename 的指针文件
    pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
    # 使用 huggingface_hub.file_download._create_relative_symlink 创建相对符号链接
    huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
    # 清理原始文件
    clean_files_for(file)


# 创建一个函数,用于迁移缓存
def move_cache(cache_dir=None, new_cache_dir=None, token=None):
    # 如果未提供 new_cache_dir,则使用默认 TRANSFORMERS_CACHE
    if new_cache_dir is None:
        new_cache_dir = TRANSFORMERS_CACHE
    # 如果未提供 cache_dir,则尝试从旧的缓存目录 .cache/huggingface/transformers 中获取
    if cache_dir is None:
        old_cache = Path(TRANSFORMERS_CACHE).parent / "transformers"
        if os.path.isdir(str(old_cache)):
            cache_dir = str(old_cache)
        else:
            cache_dir = new_cache_dir
    # 获取所有缓存文件列表
    cached_files = get_all_cached_files(cache_dir=cache_dir)
    # 记录迁移过程中的日志信息
    logger.info(f"Moving {len(cached_files)} files to the new cache system")

    # 存储 hub 元数据的字典
    hub_metadata = {}
    # 遍历所有缓存文件
    for file_info in tqdm(cached_files):
        # 获取文件的 URL,并移除文件信息中的 URL 键
        url = file_info.pop("url")
        # 如果 URL 不在 hub_metadata 中,则尝试获取文件的元数据信息并存储
        if url not in hub_metadata:
            try:
                hub_metadata[url] = get_hf_file_metadata(url, token=token)
            except requests.HTTPError:
                continue

        # 获取文件的 etag 和 commit_hash
        etag, commit_hash = hub_metadata[url].etag, hub_metadata[url].commit_hash
        # 如果 etag 或 commit_hash 为空,则跳过当前文件
        if etag is None or commit_hash is None:
            continue

        # 如果文件信息中的 etag 与当前文件的 etag 不同,清理当前文件并跳过
        if file_info["etag"] != etag:
            # 缓存文件不是最新版本,清理该文件,因为将会下载新版本
            clean_files_for(os.path.join(cache_dir, file_info["file"]))
            continue

        # 从 URL 中提取信息
        url_info = extract_info_from_url(url)
        # 如果无法从 URL 中提取信息,则跳过当前文件
        if url_info is None:
            # 不是来自 huggingface.co 的文件
            continue

        # 构建目标 repo 的路径
        repo = os.path.join(new_cache_dir, url_info["repo"])
        # 调用 move_to_new_cache 函数,将文件移动到新缓存中
        move_to_new_cache(
            file=os.path.join(cache_dir, file_info["file"]),
            repo=repo,
            filename=url_info["filename"],
            revision=url_info["revision"],
            etag=etag,
            commit_hash=commit_hash,
        )


# PushInProgress 类,用于跟踪进行中的推送(可能包含多个 Future 任务)
class PushInProgress:
    """
    Internal class to keep track of a push in progress (which might contain multiple `Future` jobs).
    """

    def __init__(self, jobs: Optional[futures.Future] = None) -> None:
        # 初始化 jobs 列表,默认为空列表
        self.jobs = [] if jobs is None else jobs

    # 检查所有任务是否完成
    def is_done(self):
        # 返回所有任务是否都已完成的布尔值
        return all(job.done() for job in self.jobs)
    # 等待所有任务完成
    def wait_until_done(self):
        # 使用 futures 模块等待所有任务完成
        futures.wait(self.jobs)

    # 取消所有未开始或已取消/已完成的任务
    def cancel(self) -> None:
        self.jobs = [
            job
            for job in self.jobs
            # 如果任务还未开始,则取消该任务;同时移除已取消或已完成的任务
            if not (job.cancel() or job.done())
        ]
# 拼接得到缓存版本文件的路径
cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")

# 检查缓存版本文件是否存在,如果不存在则默认缓存版本为0
if not os.path.isfile(cache_version_file):
    cache_version = 0
else:
    # 如果文件存在,则读取其中的内容并尝试转换为整数,如果无法转换则默认缓存版本为0
    with open(cache_version_file) as f:
        try:
            cache_version = int(f.read())
        except ValueError:
            cache_version = 0

# 检查 TRANSFORMERS_CACHE 目录是否存在且不为空
cache_is_not_empty = os.path.isdir(TRANSFORMERS_CACHE) and len(os.listdir(TRANSFORMERS_CACHE)) > 0

# 如果缓存版本小于1且 TRANSFORMERS_CACHE 目录不为空
if cache_version < 1 and cache_is_not_empty:
    # 如果处于离线模式,则记录警告信息
    if is_offline_mode():
        logger.warning(
            "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
            "cache seems to be the one of a previous version. It is very likely that all your calls to any "
            "`from_pretrained()` method will fail. Remove the offline mode and enable internet connection to have "
            "your cache be updated automatically, then you can go back to offline mode."
        )
    else:
        # 否则记录警告信息,说明模型文件的缓存已更新
        logger.warning(
            "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
            "one-time only operation. You can interrupt this and resume the migration later on by calling "
            "`transformers.utils.move_cache()`."
        )

    try:
        # 尝试迁移缓存,根据是否设置了自定义的缓存路径
        if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
            # 用户设置了某些环境变量以自定义缓存存储
            move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
        else:
            # 否则执行默认的缓存迁移
            move_cache()
    except Exception as e:
        # 如果迁移过程中发生异常,记录错误信息和堆栈追踪
        trace = "\n".join(traceback.format_tb(e.__traceback__))
        logger.error(
            f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
            "file an issue at https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole "
            "message and we will do our best to help."
        )

# 如果缓存版本小于1,则尝试创建 TRANSFORMERS_CACHE 目录,并在其中写入缓存版本号为"1"
if cache_version < 1:
    try:
        os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
        with open(cache_version_file, "w") as f:
            f.write("1")
    except Exception:
        # 如果创建过程中发生异常,则记录警告信息
        logger.warning(
            f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
            "the environment variable TRANSFORMERS_CACHE to a writable directory."
        )

.\utils\import_utils.py

# 导入模块:与导入和懒初始化相关的实用工具
import importlib.metadata  # 导入标准库中的 importlib.metadata 模块
import importlib.util  # 导入标准库中的 importlib.util 模块
import json  # 导入标准库中的 json 模块
import os  # 导入标准库中的 os 模块
import shutil  # 导入标准库中的 shutil 模块
import subprocess  # 导入标准库中的 subprocess 模块
import sys  # 导入标准库中的 sys 模块
import warnings  # 导入标准库中的 warnings 模块
from collections import OrderedDict  # 从标准库的 collections 模块中导入 OrderedDict 类
from functools import lru_cache  # 从标准库的 functools 模块中导入 lru_cache 装饰器
from itertools import chain  # 从标准库的 itertools 模块中导入 chain 函数
from types import ModuleType  # 从标准库的 types 模块中导入 ModuleType 类
from typing import Any, Tuple, Union  # 导入 typing 模块中的 Any、Tuple、Union 类型

from packaging import version  # 从 packaging 库中导入 version 模块

from . import logging  # 从当前包中导入 logging 模块


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# 获取当前模块的 logger 实例,用于记录日志,名称为当前模块的名称
# pylint: disable=invalid-name 是禁止 pylint 检查器发出的无效名称警告


# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    """
    检查指定的包是否可用,并返回其版本信息(如果指定)。

    Args:
        pkg_name (str): 要检查的包的名称。
        return_version (bool, optional): 是否返回包的版本信息。默认为 False。

    Returns:
        Union[Tuple[bool, str], bool]: 如果 return_version 为 True,则返回包的存在状态和版本信息的元组;
        否则,仅返回包的存在状态(布尔值)。

    Notes:
        如果包存在,则尝试获取其版本信息,如果无法获取则使用特定的后备方法。
        使用 logging 模块记录调试信息,包括检测到的包的版本信息。
    """
    # 检查包是否存在,并获取其版本信息以避免导入本地目录
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # 主要方法获取包的版本信息
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # 备用方法:仅针对 "torch" 和包含 "dev" 的版本
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # 检查版本信息中是否包含 "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # 如果无法导入包,则表示不可用
                    package_exists = False
            else:
                # 对于除了 "torch" 外的包,不尝试后备方法,直接设置为不可用
                package_exists = False
        logger.debug(f"Detected {pkg_name} version: {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists


ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
# 获取环境变量 USE_TF 的值,并转换为大写形式,如果未设置则默认为 "AUTO"
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()

# 尝试通过设置该值为0,在安装了TorchXLA的环境中运行原生的PyTorch作业。
USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper()

FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()

# `transformers`需要`torch>=1.11`,但此变量对外公开,因此不能简单地删除它。
# 这是运行torch.fx特性和torch.onnx与字典输入所需的torch版本。
TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

ACCELERATE_MIN_VERSION = "0.21.0"
FSDP_MIN_VERSION = "1.12.0"

# 检查是否安装了accelerate包,并返回其是否可用及其版本号。
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
# 检查是否安装了apex包。
_apex_available = _is_package_available("apex")
# 检查是否安装了aqlm包。
_aqlm_available = _is_package_available("aqlm")
# 检查是否安装了bitsandbytes包。
_bitsandbytes_available = _is_package_available("bitsandbytes")
# 检查是否安装了galore_torch包。
_galore_torch_available = _is_package_available("galore_torch")
# 检查是否安装了beautifulsoup4包(注意,使用的是find_spec函数,因为导入的名称与包名称不同)。
_bs4_available = importlib.util.find_spec("bs4") is not None
# 检查是否安装了coloredlogs包。
_coloredlogs_available = _is_package_available("coloredlogs")
# 检查是否安装了cv2(opencv-python-headless)包。
_cv2_available = importlib.util.find_spec("cv2") is not None
# 检查是否安装了datasets包。
_datasets_available = _is_package_available("datasets")
# 检查是否安装了decord包。
_decord_available = importlib.util.find_spec("decord") is not None
# 检查是否安装了detectron2包。
_detectron2_available = _is_package_available("detectron2")
# 检查是否安装了faiss或faiss-cpu包。
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
    # 尝试获取faiss包的版本信息。
    _faiss_version = importlib.metadata.version("faiss")
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib.metadata.PackageNotFoundError:
    try:
        # 如果faiss包未找到,则尝试获取faiss-cpu包的版本信息。
        _faiss_version = importlib.metadata.version("faiss-cpu")
        logger.debug(f"Successfully imported faiss version {_faiss_version}")
    except importlib.metadata.PackageNotFoundError:
        # 如果faiss和faiss-cpu包都未找到,则标记_faiss_available为False。
        _faiss_available = False
# 检查是否安装了ftfy包。
_ftfy_available = _is_package_available("ftfy")
# 检查是否安装了g2p_en包。
_g2p_en_available = _is_package_available("g2p_en")
# 检查是否安装了intel_extension_for_pytorch包,并返回其是否可用及其版本号。
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
# 检查是否安装了jieba包。
_jieba_available = _is_package_available("jieba")
# 检查是否安装了jinja2包。
_jinja_available = _is_package_available("jinja2")
# 检查是否安装了kenlm包。
_kenlm_available = _is_package_available("kenlm")
# 检查是否安装了keras_nlp包。
_keras_nlp_available = _is_package_available("keras_nlp")
# 检查是否安装了Levenshtein包。
_levenshtein_available = _is_package_available("Levenshtein")
# 检查是否安装了librosa包。
_librosa_available = _is_package_available("librosa")
# 检查是否安装了natten包。
_natten_available = _is_package_available("natten")
# 检查是否安装了nltk包。
_nltk_available = _is_package_available("nltk")
# 检查是否安装了onnx包。
_onnx_available = _is_package_available("onnx")
# 检查是否安装了openai包。
_openai_available = _is_package_available("openai")
# 检查是否安装了optimum包。
_optimum_available = _is_package_available("optimum")
# 检查是否安装了auto_gptq包。
_auto_gptq_available = _is_package_available("auto_gptq")
# 检查是否安装了awq包。
# (注意,此处应有代码,但由于未找到正确的导入方式,省略了相关部分)
# 检查是否可以导入名为 "awq" 的模块
_auto_awq_available = importlib.util.find_spec("awq") is not None

# 检查名为 "quanto" 的包是否可用
_quanto_available = _is_package_available("quanto")

# 检查名为 "pandas" 的包是否可用
_pandas_available = _is_package_available("pandas")

# 检查名为 "peft" 的包是否可用
_peft_available = _is_package_available("peft")

# 检查名为 "phonemizer" 的包是否可用
_phonemizer_available = _is_package_available("phonemizer")

# 检查名为 "psutil" 的包是否可用
_psutil_available = _is_package_available("psutil")

# 检查名为 "py3nvml" 的包是否可用
_py3nvml_available = _is_package_available("py3nvml")

# 检查名为 "pyctcdecode" 的包是否可用
_pyctcdecode_available = _is_package_available("pyctcdecode")

# 检查名为 "pytesseract" 的包是否可用
_pytesseract_available = _is_package_available("pytesseract")

# 检查名为 "pytest" 的包是否可用
_pytest_available = _is_package_available("pytest")

# 检查名为 "pytorch_quantization" 的包是否可用
_pytorch_quantization_available = _is_package_available("pytorch_quantization")

# 检查名为 "rjieba" 的包是否可用
_rjieba_available = _is_package_available("rjieba")

# 检查名为 "sacremoses" 的包是否可用
_sacremoses_available = _is_package_available("sacremoses")

# 检查名为 "safetensors" 的包是否可用
_safetensors_available = _is_package_available("safetensors")

# 检查名为 "scipy" 的包是否可用
_scipy_available = _is_package_available("scipy")

# 检查名为 "sentencepiece" 的包是否可用
_sentencepiece_available = _is_package_available("sentencepiece")

# 检查名为 "seqio" 的包是否可用
_is_seqio_available = _is_package_available("seqio")

# 检查是否可以导入名为 "sklearn" 的模块
_sklearn_available = importlib.util.find_spec("sklearn") is not None
if _sklearn_available:
    try:
        # 尝试获取 "scikit-learn" 的版本信息
        importlib.metadata.version("scikit-learn")
    except importlib.metadata.PackageNotFoundError:
        # 如果找不到 "scikit-learn" 包,将 _sklearn_available 设为 False
        _sklearn_available = False

# 检查是否可以导入名为 "smdistributed" 的模块
_smdistributed_available = importlib.util.find_spec("smdistributed") is not None

# 检查名为 "soundfile" 的包是否可用
_soundfile_available = _is_package_available("soundfile")

# 检查名为 "spacy" 的包是否可用
_spacy_available = _is_package_available("spacy")

# 检查名为 "sudachipy" 的包是否可用,并获取其版本信息
_sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True)

# 检查名为 "tensorflow_probability" 的包是否可用
_tensorflow_probability_available = _is_package_available("tensorflow_probability")

# 检查名为 "tensorflow_text" 的包是否可用
_tensorflow_text_available = _is_package_available("tensorflow_text")

# 检查名为 "tf2onnx" 的包是否可用
_tf2onnx_available = _is_package_available("tf2onnx")

# 检查名为 "timm" 的包是否可用
_timm_available = _is_package_available("timm")

# 检查名为 "tokenizers" 的包是否可用
_tokenizers_available = _is_package_available("tokenizers")

# 检查名为 "torchaudio" 的包是否可用
_torchaudio_available = _is_package_available("torchaudio")

# 检查名为 "torchdistx" 的包是否可用
_torchdistx_available = _is_package_available("torchdistx")

# 检查名为 "torchvision" 的包是否可用
_torchvision_available = _is_package_available("torchvision")

# 检查名为 "mlx" 的包是否可用
_mlx_available = _is_package_available("mlx")

# 初始化 _torch_version 变量为 "N/A",_torch_available 变量为 False
_torch_version = "N/A"
_torch_available = False

# 如果 USE_TORCH 在 ENV_VARS_TRUE_AND_AUTO_VALUES 中且 USE_TF 不在 ENV_VARS_TRUE_VALUES 中
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    # 尝试获取 "torch" 包的版本信息,并设置 _torch_available 为 True
    _torch_available, _torch_version = _is_package_available("torch", return_version=True)
else:
    # 记录信息表明禁用 PyTorch 因为 USE_TF 已设置
    logger.info("Disabling PyTorch because USE_TF is set")
    # 设置 _torch_available 为 False
    _torch_available = False

# 初始化 _tf_version 变量为 "N/A",_tf_available 变量为 False
_tf_version = "N/A"
_tf_available = False

# 如果 FORCE_TF_AVAILABLE 在 ENV_VARS_TRUE_VALUES 中
if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES:
    # 设置 _tf_available 为 True
    _tf_available = True
else:
    # 检查环境变量中是否启用了 TensorFlow,并且未启用 Torch
    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
        # 注意:_is_package_available("tensorflow") 对 tensorflow-cpu 会失败,请测试下面的代码行
        # 在使用 tensorflow-cpu 时确保它仍然有效!

        # 检查是否可以导入 tensorflow 库
        _tf_available = importlib.util.find_spec("tensorflow") is not None
        if _tf_available:
            # 可选的 TensorFlow 包列表
            candidates = (
                "tensorflow",
                "tensorflow-cpu",
                "tensorflow-gpu",
                "tf-nightly",
                "tf-nightly-cpu",
                "tf-nightly-gpu",
                "tf-nightly-rocm",
                "intel-tensorflow",
                "intel-tensorflow-avx512",
                "tensorflow-rocm",
                "tensorflow-macos",
                "tensorflow-aarch64",
            )
            _tf_version = None
            # 在候选包列表中查找 TensorFlow 的版本信息
            for pkg in candidates:
                try:
                    _tf_version = importlib.metadata.version(pkg)
                    break
                except importlib.metadata.PackageNotFoundError:
                    pass
            # 更新 _tf_available 状态为找到的 TensorFlow 版本是否非空
            _tf_available = _tf_version is not None

        if _tf_available:
            # 如果找到 TensorFlow 并且版本小于 2,则记录警告信息并将 _tf_available 置为 False
            if version.parse(_tf_version) < version.parse("2"):
                logger.info(
                    f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum."
                )
                _tf_available = False
    else:
        # 如果 USE_TORCH 已设置,则记录禁用 TensorFlow 的信息
        logger.info("Disabling Tensorflow because USE_TORCH is set")
# 检查是否安装了 Essentia 库
_essentia_available = importlib.util.find_spec("essentia") is not None
try:
    # 获取 Essentia 库的版本信息
    _essentia_version = importlib.metadata.version("essentia")
    logger.debug(f"Successfully imported essentia version {_essentia_version}")
except importlib.metadata.PackageNotFoundError:
    # 如果 Essentia 库未找到,则标记为不可用
    _essentia_version = False


# 检查是否安装了 Pretty MIDI 库
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
    # 获取 Pretty MIDI 库的版本信息
    _pretty_midi_version = importlib.metadata.version("pretty_midi")
    logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
except importlib.metadata.PackageNotFoundError:
    # 如果 Pretty MIDI 库未找到,则标记为不可用
    _pretty_midi_available = False


# 初始化 CCL 版本信息,默认为 "N/A",检查是否安装了 CCL 相关库
ccl_version = "N/A"
_is_ccl_available = (
    importlib.util.find_spec("torch_ccl") is not None
    or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None
)
try:
    # 获取 oneccl_bind_pt 库的版本信息
    ccl_version = importlib.metadata.version("oneccl_bind_pt")
    logger.debug(f"Detected oneccl_bind_pt version {ccl_version}")
except importlib.metadata.PackageNotFoundError:
    # 如果 oneccl_bind_pt 库未找到,则标记 CCL 不可用
    _is_ccl_available = False


# 初始化 Flax 是否可用,默认为 False
_flax_available = False
# 如果使用 JAX 环境变量指定为 True
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    # 检查 Flax 包是否可用,并获取其版本信息
    _flax_available, _flax_version = _is_package_available("flax", return_version=True)
    if _flax_available:
        # 如果 Flax 可用,则检查 JAX 包是否也可用,并获取其版本信息
        _jax_available, _jax_version = _is_package_available("jax", return_version=True)
        if _jax_available:
            # 如果 JAX 可用,则记录日志显示 JAX 和 Flax 的版本信息
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
        else:
            # 如果 JAX 不可用,则将 Flax 和 JAX 的可用性都设为 False,并将版本信息置为 "N/A"
            _flax_available = _jax_available = False
            _jax_version = _flax_version = "N/A"


# 初始化 Torch FX 是否可用,默认为 False
_torch_fx_available = False
# 如果 Torch 可用
if _torch_available:
    # 解析 Torch 版本信息
    torch_version = version.parse(_torch_version)
    # 检查 Torch FX 是否可用,需满足指定的最低版本要求
    _torch_fx_available = (torch_version.major, torch_version.minor) >= (
        TORCH_FX_REQUIRED_VERSION.major,
        TORCH_FX_REQUIRED_VERSION.minor,
    )


# 初始化 Torch XLA 是否可用,默认为 False
_torch_xla_available = False
# 如果使用 Torch XLA 环境变量指定为 True
if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES:
    # 检查 Torch XLA 包是否可用,并获取其版本信息
    _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True)
    if _torch_xla_available:
        # 如果 Torch XLA 可用,则记录日志显示 Torch XLA 的版本信息
        logger.info(f"Torch XLA version {_torch_xla_version} available.")


# 返回 KenLM 库是否可用的函数
def is_kenlm_available():
    return _kenlm_available


# 返回 OpenCV 库是否可用的函数
def is_cv2_available():
    return _cv2_available


# 返回 Torch 库是否可用的函数
def is_torch_available():
    return _torch_available


# 返回当前使用的 Torch 版本信息的函数
def get_torch_version():
    return _torch_version


# 检查是否安装了 Torch SDPA 库
def is_torch_sdpa_available():
    # 如果 Torch 不可用,则 SDPA 也不可用
    if not is_torch_available():
        return False
    # 如果 Torch 版本信息为 "N/A",则 SDPA 也不可用
    elif _torch_version == "N/A":
        return False

    # 笔记: 我们要求 torch>=2.1(而不是torch>=2.0)以在 Transformers 中使用 SDPA 有两个原因:
    # - 允许全局使用在 https://github.com/pytorch/pytorch/pull/95259 中引入的 `scale` 参数
    # - 内存高效的注意力支持任意的 attention_mask: https://github.com/pytorch/pytorch/pull/104310
    # 笔记: 我们要求 torch>=2.1.1 以避免 SDPA 在非连续输入中出现的数值问题:https://github.com/pytorch/pytorch/issues/112577
    return version.parse(_torch_version) >= version.parse("2.1.1")


# 返回 Torch Vision 库是否可用的函数
def is_torchvision_available():
    return _torchvision_available


# 返回变量 _torchvision_available 的值作为函数的返回结果
# 检查是否 galore_torch 可用,返回对应的状态
def is_galore_torch_available():
    return _galore_torch_available


# 检查是否 pyctcdecode 可用,返回对应的状态
def is_pyctcdecode_available():
    return _pyctcdecode_available


# 检查是否 librosa 可用,返回对应的状态
def is_librosa_available():
    return _librosa_available


# 检查是否 essentia 可用,返回对应的状态
def is_essentia_available():
    return _essentia_available


# 检查是否 pretty_midi 可用,返回对应的状态
def is_pretty_midi_available():
    return _pretty_midi_available


# 检查是否 torch 可用,并且 CUDA 是否可用
def is_torch_cuda_available():
    if is_torch_available():
        import torch

        return torch.cuda.is_available()
    else:
        return False


# 检查是否 torch 可用,并且检查是否 mamba_ssm 包可用
def is_mamba_ssm_available():
    if is_torch_available():
        import torch

        if not torch.cuda.is_available():
            return False
        else:
            return _is_package_available("mamba_ssm")
    return False


# 检查是否 torch 可用,并且检查是否 causal_conv1d 包可用
def is_causal_conv1d_available():
    if is_torch_available():
        import torch

        if not torch.cuda.is_available():
            return False
        return _is_package_available("causal_conv1d")
    return False


# 检查是否 torch 可用,并且检查是否 torch.backends.mps 可用
def is_torch_mps_available():
    if is_torch_available():
        import torch

        if hasattr(torch.backends, "mps"):
            return torch.backends.mps.is_available()
    return False


# 检查是否 torch 可用,并且检查是否 CUDA BF16 支持
def is_torch_bf16_gpu_available():
    if not is_torch_available():
        return False

    import torch

    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


# 检查是否 torch 可用,并且检查是否 CPU BF16 支持
def is_torch_bf16_cpu_available():
    if not is_torch_available():
        return False

    import torch

    try:
        _ = torch.cpu.amp.autocast
    except AttributeError:
        return False

    return True


# 检查是否 torch 可用,并且检查是否 GPU 或 CPU 上 BF16 支持
def is_torch_bf16_available():
    warnings.warn(
        "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available "
        "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu",
        FutureWarning,
    )
    return is_torch_bf16_gpu_available()


# 使用 lru_cache 修饰器,检查在指定设备上是否 torch 的 FP16 可用
def is_torch_fp16_available_on_device(device):
    if not is_torch_available():
        return False

    import torch

    try:
        # 创建一个小张量,并执行矩阵乘法操作以检查 FP16 支持
        x = torch.zeros(2, 2, dtype=torch.float16).to(device)
        _ = x @ x

        # 检查在设备上是否支持 LayerNorm 操作,因为许多模型使用此操作
        batch, sentence_length, embedding_dim = 3, 4, 5
        embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device)
        layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device)
        _ = layer_norm(embedding)

    except:  # 捕获所有异常,返回 False
        return False

    return True
# 使用 LRU 缓存装饰器缓存函数结果,避免重复计算
@lru_cache()
# 检查指定设备上是否可用 Torch 的 BF16 支持
def is_torch_bf16_available_on_device(device):
    # 如果 Torch 不可用,则返回 False
    if not is_torch_available():
        return False

    # 导入 Torch 库
    import torch

    # 如果设备是 "cuda",则检查 GPU 上是否可用 BF16 支持
    if device == "cuda":
        return is_torch_bf16_gpu_available()

    # 尝试在指定设备上创建一个 bfloat16 类型的张量并执行矩阵乘法操作
    try:
        x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
        _ = x @ x
    except:  # noqa: E722
        # 捕获所有异常,通常返回 RuntimeError,但不保证
        # TODO: 如果可能的话,进行更精确的异常匹配
        return False

    # 如果以上尝试成功,则返回 True,表示 BF16 在该设备上可用
    return True


# 检查当前环境是否支持 Torch 的 TF32 支持
def is_torch_tf32_available():
    # 如果 Torch 不可用,则返回 False
    if not is_torch_available():
        return False

    # 导入 Torch 库
    import torch

    # 如果 CUDA 不可用或者 CUDA 版本为 None,则返回 False
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    # 如果 CUDA 设备的主版本号小于 8,则返回 False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    # 如果 CUDA 版本的主版本号小于 11,则返回 False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    # 如果 Torch 版本小于 1.7,则返回 False
    if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
        return False

    # 如果以上条件都满足,则返回 True,表示 TF32 在当前环境中可用
    return True


# 返回 Torch FX 是否可用的标志
def is_torch_fx_available():
    return _torch_fx_available


# 返回 PEFT 是否可用的标志
def is_peft_available():
    return _peft_available


# 返回 Beautiful Soup (bs4) 是否可用的标志
def is_bs4_available():
    return _bs4_available


# 返回 TensorFlow 是否可用的标志
def is_tf_available():
    return _tf_available


# 返回 coloredlogs 是否可用的标志
def is_coloredlogs_available():
    return _coloredlogs_available


# 返回 TF2ONNX 是否可用的标志
def is_tf2onnx_available():
    return _tf2onnx_available


# 返回 ONNX 是否可用的标志
def is_onnx_available():
    return _onnx_available


# 返回 OpenAI 的库是否可用的标志
def is_openai_available():
    return _openai_available


# 返回 Flax 是否可用的标志
def is_flax_available():
    return _flax_available


# 返回 ftfy 是否可用的标志
def is_ftfy_available():
    return _ftfy_available


# 返回 g2p_en 是否可用的标志
def is_g2p_en_available():
    return _g2p_en_available


# 使用 LRU 缓存装饰器缓存函数结果,避免重复计算
@lru_cache()
# 检查是否 Torch TPU 可用(即是否安装了 torch_xla 并且环境中存在 TPU)
def is_torch_tpu_available(check_device=True):
    # 发出警告,提示函数即将被弃用
    warnings.warn(
        "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. "
        "Please use the `is_torch_xla_available` instead.",
        FutureWarning,
    )

    # 如果 Torch 不可用,则返回 False
    if not _torch_available:
        return False
    # 如果安装了 torch_xla,则进一步检查是否存在 TPU 设备
    if importlib.util.find_spec("torch_xla") is not None:
        if check_device:
            # 需要检查是否可以找到 `xla_device`,如果找不到将引发 RuntimeError
            try:
                import torch_xla.core.xla_model as xm

                _ = xm.xla_device()
                return True
            except RuntimeError:
                return False
        return True
    return False


# 使用 LRU 缓存装饰器缓存函数结果,避免重复计算
@lru_cache
# 检查 Torch XLA 是否可用
def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False):
    """
    Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set
    the USE_TORCH_XLA to false.
    """
    # 断言 `check_is_tpu` 和 `check_is_gpu` 不能同时为 True
    assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true."

    # 如果 Torch XLA 不可用,则返回 False
    if not _torch_xla_available:
        return False

    # 导入 torch_xla 库
    import torch_xla

    # 如果需要检查 GPU,则返回当前设备类型是否为 GPU 或 CUDA
    if check_is_gpu:
        return torch_xla.runtime.device_type() in ["GPU", "CUDA"]
    # 如果检测到是TPU设备,则返回是否为TPU
    elif check_is_tpu:
        return torch_xla.runtime.device_type() == "TPU"
    # 否则返回True
    return True
# 使用 lru_cache 装饰器,缓存函数调用结果,提升性能
@lru_cache()
# 检查是否存在 torch_neuronx 模块,若存在则调用 is_torch_xla_available 函数
def is_torch_neuroncore_available(check_device=True):
    if importlib.util.find_spec("torch_neuronx") is not None:
        return is_torch_xla_available()
    return False


# 使用 lru_cache 装饰器,缓存函数调用结果,提升性能
@lru_cache()
# 检查是否安装了 torch_npu 模块,并可选地检查环境中是否存在 NPU 设备
def is_torch_npu_available(check_device=False):
    # 如果 _torch_available 为 False 或者找不到 torch_npu 模块,则返回 False
    if not _torch_available or importlib.util.find_spec("torch_npu") is None:
        return False

    import torch
    import torch_npu  # noqa: F401

    if check_device:
        try:
            # 如果没有找到 NPU 设备会抛出 RuntimeError
            _ = torch.npu.device_count()
            return torch.npu.is_available()
        except RuntimeError:
            return False
    # 检查 torch 是否有 npu 属性并且 NPU 可用
    return hasattr(torch, "npu") and torch.npu.is_available()


# 检查是否存在 torch _dynamo 模块以判断是否可用
def is_torchdynamo_available():
    if not is_torch_available():  # 如果 torch 不可用,则返回 False
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return True
    except Exception:
        return False


# 检查是否存在 torch.compile 属性来判断是否可用
def is_torch_compile_available():
    if not is_torch_available():  # 如果 torch 不可用,则返回 False
        return False

    import torch

    # 不进行版本检查以支持夜间版本标记为 1.14。最终需要与 2.0 版本进行检查,但暂时不处理
    return hasattr(torch, "compile")


# 检查是否在编译 torch _dynamo 模块
def is_torchdynamo_compiling():
    if not is_torch_available():  # 如果 torch 不可用,则返回 False
        return False
    try:
        import torch._dynamo as dynamo  # noqa: F401

        return dynamo.is_compiling()
    except Exception:
        return False


# 检查是否安装了 torch_tensorrt 模块,并且是否存在 torch_tensorrt.fx 子模块
def is_torch_tensorrt_fx_available():
    if importlib.util.find_spec("torch_tensorrt") is None:  # 如果找不到 torch_tensorrt 模块,则返回 False
        return False
    return importlib.util.find_spec("torch_tensorrt.fx") is not None  # 检查是否存在 torch_tensorrt.fx 子模块


# 返回 _datasets_available 变量的值
def is_datasets_available():
    return _datasets_available


# 返回 _detectron2_available 变量的值
def is_detectron2_available():
    return _detectron2_available


# 返回 _rjieba_available 变量的值
def is_rjieba_available():
    return _rjieba_available


# 返回 _psutil_available 变量的值
def is_psutil_available():
    return _psutil_available


# 返回 _py3nvml_available 变量的值
def is_py3nvml_available():
    return _py3nvml_available


# 返回 _sacremoses_available 变量的值
def is_sacremoses_available():
    return _sacremoses_available


# 返回 _apex_available 变量的值
def is_apex_available():
    return _apex_available


# 返回 _aqlm_available 变量的值
def is_aqlm_available():
    return _aqlm_available


# 检查系统是否安装了 ninja 构建系统
def is_ninja_available():
    r"""
    Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
    [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise.
    """
    try:
        subprocess.check_output("ninja --version".split())  # 执行命令检查 ninja 版本
    except Exception:
        return False  # 捕获异常则返回 False
    else:
        return True  # 执行成功则返回 True


# 检查是否安装了 ipex 模块以及 torch 可用性和 _ipex_available 变量
def is_ipex_available():
    def get_major_and_minor_from_version(full_version):
        return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)

    if not is_torch_available() or not _ipex_available:
        return False  # 如果 torch 不可用或者 _ipex_available 为 False,则返回 False

    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
    # 检查当前安装的 PyTorch 主版本和次版本是否与 Intel Extension for PyTorch 所需版本匹配
    if torch_major_and_minor != ipex_major_and_minor:
        # 如果不匹配,记录警告信息,提示用户切换到匹配的 PyTorch 版本后重新运行
        logger.warning(
            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        # 返回 False 表示版本不匹配
        return False
    # 如果版本匹配,返回 True
    return True
# 使用 lru_cache 装饰器来缓存函数的结果,提升函数性能
@lru_cache
# 检查是否安装了 intel_extension_for_pytorch 并且可能存在 XPU 设备
def is_torch_xpu_available(check_device=False):
    if not is_ipex_available():  # 如果没有安装 intel_extension_for_pytorch,则返回 False
        return False

    import intel_extension_for_pytorch  # 引入 intel_extension_for_pytorch 模块,用于检查是否安装
    import torch  # 引入 torch 模块

    if check_device:
        try:
            # 尝试获取 XPU 设备的数量,如果没有 XPU 设备会抛出 RuntimeError
            _ = torch.xpu.device_count()
            # 返回当前是否有可用的 XPU 设备
            return torch.xpu.is_available()
        except RuntimeError:
            return False
    # 检查是否存在 torch.xpu 模块,并且该模块当前是否可用
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_bitsandbytes_available():
    if not is_torch_available():  # 如果 torch 不可用,则返回 False
        return False

    # bitsandbytes 在没有 cuda 可用时会抛出错误,这里添加简单检查避免异常
    import torch  # 引入 torch 模块

    return _bitsandbytes_available and torch.cuda.is_available()  # 返回 bitsandbytes 是否可用以及当前是否有 cuda 可用


def is_flash_attn_2_available():
    if not is_torch_available():  # 如果 torch 不可用,则返回 False
        return False

    if not _is_package_available("flash_attn"):  # 如果 flash_attn 包不可用,则返回 False
        return False

    import torch  # 引入 torch 模块

    if not torch.cuda.is_available():  # 如果没有 cuda 可用,则返回 False
        return False

    if torch.version.cuda:  # 如果是 CUDA 版本
        # 检查 flash_attn 包的版本是否大于等于 2.1.0
        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
    elif torch.version.hip:  # 如果是 HIP 版本
        # TODO: 一旦在 https://github.com/ROCmSoftwarePlatform/flash-attention 发布,将要求将要求版本提升至 2.1.0
        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
    else:
        return False


def is_flash_attn_greater_or_equal_2_10():
    if not _is_package_available("flash_attn"):  # 如果 flash_attn 包不可用,则返回 False
        return False

    # 检查 flash_attn 包的版本是否大于等于 2.1.0
    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")


def is_torchdistx_available():
    return _torchdistx_available  # 返回 _torchdistx_available 变量的值


def is_faiss_available():
    return _faiss_available  # 返回 _faiss_available 变量的值


def is_scipy_available():
    return _scipy_available  # 返回 _scipy_available 变量的值


def is_sklearn_available():
    return _sklearn_available  # 返回 _sklearn_available 变量的值


def is_sentencepiece_available():
    return _sentencepiece_available  # 返回 _sentencepiece_available 变量的值


def is_seqio_available():
    return _is_seqio_available  # 返回 _is_seqio_available 变量的值


def is_protobuf_available():
    if importlib.util.find_spec("google") is None:  # 如果找不到 google 模块,则返回 False
        return False
    # 检查是否找到 google.protobuf 模块,并返回结果
    return importlib.util.find_spec("google.protobuf") is not None


def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
    if min_version is not None:
        # 检查 _accelerate_available 变量的值,并且检查其版本是否大于等于 min_version
        return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
    return _accelerate_available  # 返回 _accelerate_available 变量的值


def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
    if is_torch_available():  # 如果 torch 可用
        # 检查 _torch_version 的版本是否大于等于 min_version
        return version.parse(_torch_version) >= version.parse(min_version)
    return False  # 如果 torch 不可用,则返回 False


def is_optimum_available():
    return _optimum_available  # 返回 _optimum_available 变量的值


def is_auto_awq_available():
    return _auto_awq_available  # 返回 _auto_awq_available 变量的值


def is_quanto_available():
    return _quanto_available  # 返回 _quanto_available 变量的值


def is_auto_gptq_available():
    return _auto_gptq_available  # 返回 _auto_gptq_available 变量的值


def is_levenshtein_available():
    # 此函数未实现,没有返回值
    return _levenshtein_available


    # 返回变量 _levenshtein_available 的值作为函数的返回结果
# 检查是否已经安装了 optimum.neuron 包并且 _optimum_available 变量为真
def is_optimum_neuron_available():
    return _optimum_available and _is_package_available("optimum.neuron")


# 返回 _safetensors_available 变量的值
def is_safetensors_available():
    return _safetensors_available


# 返回 _tokenizers_available 变量的值
def is_tokenizers_available():
    return _tokenizers_available


# 使用 lru_cache 装饰器缓存函数结果,检查 PIL 库是否可用
@lru_cache
def is_vision_available():
    _pil_available = importlib.util.find_spec("PIL") is not None
    if _pil_available:
        try:
            package_version = importlib.metadata.version("Pillow")
        except importlib.metadata.PackageNotFoundError:
            try:
                package_version = importlib.metadata.version("Pillow-SIMD")
            except importlib.metadata.PackageNotFoundError:
                return False
        logger.debug(f"Detected PIL version {package_version}")
    return _pil_available


# 返回 _pytesseract_available 变量的值
def is_pytesseract_available():
    return _pytesseract_available


# 返回 _pytest_available 变量的值
def is_pytest_available():
    return _pytest_available


# 返回 _spacy_available 变量的值
def is_spacy_available():
    return _spacy_available


# 返回 is_tf_available() 和 _tensorflow_text_available 变量的逻辑与结果
def is_tensorflow_text_available():
    return is_tf_available() and _tensorflow_text_available


# 返回 is_tensorflow_text_available() 和 _keras_nlp_available 变量的逻辑与结果
def is_keras_nlp_available():
    return is_tensorflow_text_available() and _keras_nlp_available


# 在 Notebook 环境中检查 IPython 模块的存在
def is_in_notebook():
    try:
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
            raise ImportError("console")
        if "VSCODE_PID" in os.environ:
            raise ImportError("vscode")
        if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0":
            raise ImportError("databricks")
        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        return False


# 返回 _pytorch_quantization_available 变量的值
def is_pytorch_quantization_available():
    return _pytorch_quantization_available


# 返回 _tensorflow_probability_available 变量的值
def is_tensorflow_probability_available():
    return _tensorflow_probability_available


# 返回 _pandas_available 变量的值
def is_pandas_available():
    return _pandas_available


# 检查 SageMaker 是否启用了分布式数据并行 (Distributed Data Parallel, DDP)
# 通过解析环境变量 SM_FRAMEWORK_PARAMS 检查 sagemaker_distributed_dataparallel_enabled 字段
def is_sagemaker_dp_enabled():
    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        sagemaker_params = json.loads(sagemaker_params)
        if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
            return False
    except json.JSONDecodeError:
        return False
    return _smdistributed_available


# 获取 SageMaker 的 MP 参数变量 SM_HP_MP_PARAMETERS
def is_sagemaker_mp_enabled():
    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
    try:
        # 尝试解析 smp_options 变量,并检查是否包含 "partitions" 字段,这是模型并行所需的。
        smp_options = json.loads(smp_options)
        if "partitions" not in smp_options:
            return False
    except json.JSONDecodeError:
        # 解析失败或格式错误,返回 False
        return False

    # 从 mpi_options 变量中获取 SageMaker 特定的框架参数。
    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # 尝试解析 mpi_options 变量,并检查是否包含 "sagemaker_mpi_enabled" 字段。
        mpi_options = json.loads(mpi_options)
        if not mpi_options.get("sagemaker_mpi_enabled", False):
            return False
    except json.JSONDecodeError:
        # 解析失败或格式错误,返回 False
        return False
    
    # 最后,检查是否存在 `smdistributed` 模块。
    return _smdistributed_available
# 检查当前运行环境是否为 SageMaker 环境,通过检查环境变量中是否存在 "SAGEMAKER_JOB_NAME"
def is_training_run_on_sagemaker():
    return "SAGEMAKER_JOB_NAME" in os.environ


# 返回一个布尔值,指示是否安装了 soundfile 库
def is_soundfile_availble():
    return _soundfile_available


# 返回一个布尔值,指示是否安装了 timm 库
def is_timm_available():
    return _timm_available


# 返回一个布尔值,指示是否安装了 natten 库
def is_natten_available():
    return _natten_available


# 返回一个布尔值,指示是否安装了 nltk 库
def is_nltk_available():
    return _nltk_available


# 返回一个布尔值,指示是否安装了 torchaudio 库
def is_torchaudio_available():
    return _torchaudio_available


# 返回一个布尔值,指示是否安装了与语音处理相关的库,目前依赖于 torchaudio
def is_speech_available():
    return _torchaudio_available


# 返回一个布尔值,指示是否安装了 phonemizer 库
def is_phonemizer_available():
    return _phonemizer_available


# 返回一个装饰器函数,用于检查是否安装了 torch 库,如果未安装则抛出 ImportError
def torch_only_method(fn):
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        else:
            return fn(*args, **kwargs)

    return wrapper


# 返回一个布尔值,指示是否安装了 ccl 库
def is_ccl_available():
    return _is_ccl_available


# 返回一个布尔值,指示是否安装了 decord 库
def is_decord_available():
    return _decord_available


# 返回一个布尔值,指示是否安装了 sudachipy 库
def is_sudachi_available():
    return _sudachipy_available


# 返回当前 sudachipy 库的版本信息
def get_sudachi_version():
    return _sudachipy_version


# 返回一个布尔值,指示是否安装了 sudachipy 并且支持 projection 选项
def is_sudachi_projection_available():
    if not is_sudachi_available():
        return False

    # 检查 sudachipy 版本是否大于等于 0.6.8,以确定是否支持 projection 选项
    return version.parse(_sudachipy_version) >= version.parse("0.6.8")


# 返回一个布尔值,指示是否安装了 jumanpp 库
def is_jumanpp_available():
    # 使用 importlib.util.find_spec 检查 rhoknp 模块和 shutil.which 检查 jumanpp 是否存在
    return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None)


# 返回一个布尔值,指示是否安装了 cython 库
def is_cython_available():
    return importlib.util.find_spec("pyximport") is not None


# 返回一个布尔值,指示是否安装了 jieba 库
def is_jieba_available():
    return _jieba_available


# 返回一个布尔值,指示是否安装了 jinja 库
def is_jinja_available():
    return _jinja_available


# 返回一个布尔值,指示是否安装了 mlx 库
def is_mlx_available():
    return _mlx_available


# CV2_IMPORT_ERROR 的文本内容,提醒用户需要安装 OpenCV 库才能继续执行相关操作
CV2_IMPORT_ERROR = """
{0} requires the OpenCV library but it was not found in your environment. You can install it with:

pip install opencv-python

Please note that you may need to restart your runtime after installation.
"""


# DATASETS_IMPORT_ERROR 的文本内容,提醒用户需要安装 🤗 Datasets 库才能继续执行相关操作
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:

pip install datasets

In a notebook or a colab, you can install it by executing a cell with

!pip install datasets

then restarting your kernel.

Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case. Please note that you may need to restart your runtime after installation.
"""


# TOKENIZERS_IMPORT_ERROR 是空字符串,没有具体的内容或注释
TOKENIZERS_IMPORT_ERROR = """
# 格式化字符串,用于给定模块名的导入错误提示信息
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# 格式化字符串,用于给定模块名的导入错误提示信息
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# 格式化字符串,用于给定模块名的导入错误提示信息
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""


# 格式化字符串,用于给定模块名的导入错误提示信息
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""


# 格式化字符串,用于给定模块名的导入错误提示信息
TORCHVISION_IMPORT_ERROR = """
{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# 格式化字符串,用于给定模块名的导入错误提示信息,同时提供了关于 TensorFlow 和 PyTorch 的信息
PYTORCH_IMPORT_ERROR_WITH_TF = """
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to our PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!

If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""

# 格式化字符串,用于给定模块名的导入错误提示信息,同时提供了关于 TensorFlow 和 PyTorch 的信息
TF_IMPORT_ERROR_WITH_PYTORCH = """
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
# 定义错误消息模板,用于缺少 Beautiful Soup 库时显示
BS4_IMPORT_ERROR = """
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 scikit-learn 库时显示
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:

pip install -U scikit-learn

In a notebook or a colab, you can install it by executing a cell with

!pip install -U scikit-learn

Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 TensorFlow 库时显示
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 detectron2 库时显示
DETECTRON2_IMPORT_ERROR = """
{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 FLAX 库时显示
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 ftfy 库时显示
FTFY_IMPORT_ERROR = """
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 python-Levenshtein 库时显示
LEVENSHTEIN_IMPORT_ERROR = """
{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip
install python-Levenshtein`. Please note that you may need to restart your runtime after installation.
"""

# 定义错误消息模板,用于缺少 g2p-en 库时显示
G2P_EN_IMPORT_ERROR = """
{0} requires the g2p-en library but it was not found in your environment. You can install it with pip:
`pip install g2p-en`. Please note that you may need to restart your runtime after installation.
"""

# 空白的错误消息模板,用于缺少 PyTorch Quantization 库时显示
PYTORCH_QUANTIZATION_IMPORT_ERROR = """
"""
# 定义当缺少 pytorch-quantization 库时所需的错误消息模板
TENSORFLOW_PROBABILITY_IMPORT_ERROR = """
{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 tensorflow_text 库时所需的错误消息模板
TENSORFLOW_TEXT_IMPORT_ERROR = """
{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as
explained here: https://www.tensorflow.org/text/guide/tf_text_intro.
Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 pandas 库时所需的错误消息模板
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 phonemizer 库时所需的错误消息模板
PHONEMIZER_IMPORT_ERROR = """
{0} requires the phonemizer library but it was not found in your environment. You can install it with pip:
`pip install phonemizer`. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 sacremoses 库时所需的错误消息模板
SACREMOSES_IMPORT_ERROR = """
{0} requires the sacremoses library but it was not found in your environment. You can install it with pip:
`pip install sacremoses`. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 scipy 库时所需的错误消息模板
SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip:
`pip install scipy`. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 torchaudio 库时所需的错误消息模板
SPEECH_IMPORT_ERROR = """
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
`pip install torchaudio`. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 timm 库时所需的错误消息模板
TIMM_IMPORT_ERROR = """
{0} requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 natten 库时所需的错误消息模板
NATTEN_IMPORT_ERROR = """
{0} requires the natten library but it was not found in your environment. You can install it by referring to:
shi-labs.com/natten . You can also install it with pip (may take longer to build):
`pip install natten`. Please note that you may need to restart your runtime after installation.
"""

# 定义当缺少 NLTK 库时所需的错误消息模板
NLTK_IMPORT_ERROR = """
{0} requires the NLTK library but it was not found in your environment. You can install it by referring to:
# 引入 docstyle-ignore,以下注释内容是一些导入错误消息的字符串模板
# 引入 Vision 模块时的导入错误消息模板
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`. Please note that you may need to restart your runtime after installation.
"""

# 引入 PyTesseract 模块时的导入错误消息模板
PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
`pip install pytesseract`. Please note that you may need to restart your runtime after installation.
"""

# 引入 pyctcdecode 模块时的导入错误消息模板
PYCTCDECODE_IMPORT_ERROR = """
{0} requires the pyctcdecode library but it was not found in your environment. You can install it with pip:
`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation.
"""

# 引入 accelerate 模块时的导入错误消息模板
ACCELERATE_IMPORT_ERROR = """
{0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} it was not found in your environment.
You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your
runtime after installation.
"""

# 引入 torch ccl 模块时的导入错误消息模板
CCL_IMPORT_ERROR = """
{0} requires the torch ccl library but it was not found in your environment. You can install it with pip:
`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable`
Please note that you may need to restart your runtime after installation.
"""

# 引入 essentia 模块时的导入错误消息模板
ESSENTIA_IMPORT_ERROR = """
{0} requires essentia library. But that was not found in your environment. You can install them with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""

# 引入 librosa 模块时的导入错误消息模板
LIBROSA_IMPORT_ERROR = """
{0} requires thes librosa library. But that was not found in your environment. You can install them with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""

# 引入 pretty_midi 模块时的导入错误消息模板
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires thes pretty_midi library. But that was not found in your environment. You can install them with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""

# 引入 decord 模块时的导入错误消息模板
DECORD_IMPORT_ERROR = """
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
decord`. Please note that you may need to restart your runtime after installation.
"""

# 引入 Cython 模块时的导入错误消息模板
CYTHON_IMPORT_ERROR = """
{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install
Cython`. Please note that you may need to restart your runtime after installation.
"""

# 引入 jieba 模块时的导入错误消息模板
JIEBA_IMPORT_ERROR = """
{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install
jieba`. Please note that you may need to restart your runtime after installation.
"""

# 引入 PEFT 模块时的注释内容为空,因此无需添加任何注释
PEFT_IMPORT_ERROR = """
# 引入 OrderedDict 类型,用于定义一个有序的映射关系
BACKENDS_MAPPING = OrderedDict(
    # 列表包含了各个库及其可用性检查函数和导入错误常量的元组
    [
        # BeautifulSoup4 库的可用性检查和导入错误常量
        ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
        # OpenCV 库的可用性检查和导入错误常量
        ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)),
        # Datasets 库的可用性检查和导入错误常量
        ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
        # Detectron2 库的可用性检查和导入错误常量
        ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
        # Essentia 库的可用性检查和导入错误常量
        ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
        # Faiss 库的可用性检查和导入错误常量
        ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
        # Flax 库的可用性检查和导入错误常量
        ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
        # FTFY 库的可用性检查和导入错误常量
        ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
        # g2p_en 库的可用性检查和导入错误常量
        ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)),
        # Pandas 库的可用性检查和导入错误常量
        ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
        # Phonemizer 库的可用性检查和导入错误常量
        ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
        # Pretty MIDI 库的可用性检查和导入错误常量
        ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
        # Levenshtein 库的可用性检查和导入错误常量
        ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)),
        # Librosa 库的可用性检查和导入错误常量
        ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
        # Protobuf 库的可用性检查和导入错误常量
        ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
        # PyCTCDecode 库的可用性检查和导入错误常量
        ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
        # Pytesseract 库的可用性检查和导入错误常量
        ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
        # Sacremoses 库的可用性检查和导入错误常量
        ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
        # PyTorch Quantization 库的可用性检查和导入错误常量
        ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
        # SentencePiece 库的可用性检查和导入错误常量
        ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
        # Scikit-learn 库的可用性检查和导入错误常量
        ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)),
        # Speech 库的可用性检查和导入错误常量
        ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)),
        # TensorFlow Probability 库的可用性检查和导入错误常量
        ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)),
        # TensorFlow 库的可用性检查和导入错误常量
        ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
        # TensorFlow Text 库的可用性检查和导入错误常量
        ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)),
        # Timm 库的可用性检查和导入错误常量
        ("timm", (is_timm_available, TIMM_IMPORT_ERROR)),
        # Natten 库的可用性检查和导入错误常量
        ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)),
        # NLTK 库的可用性检查和导入错误常量
        ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
        # Tokenizers 库的可用性检查和导入错误常量
        ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)),
        # PyTorch 库的可用性检查和导入错误常量
        ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
        # Torchvision 库的可用性检查和导入错误常量
        ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)),
        # Vision 库的可用性检查和导入错误常量
        ("vision", (is_vision_available, VISION_IMPORT_ERROR)),
        # SciPy 库的可用性检查和导入错误常量
        ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
        # Accelerate 库的可用性检查和导入错误常量
        ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
        # OneCCL 绑定库的可用性检查和导入错误常量
        ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)),
        # Decord 库的可用性检查和导入错误常量
        ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
        # Cython 库的可用性检查和导入错误常量
        ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
        # 结巴分词 库的可用性检查和导入错误常量
        ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
        # PEFT 库的可用性检查和导入错误常量
        ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
        # Jinja 库的可用性检查和导入错误常量
        ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
    ]
    # 定义一个名为 `DummyObject` 的元类,用于创建虚拟对象类
    class DummyObject(type):
        """
        Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
        `requires_backend` each time a user tries to access any method of that class.
        """

        # 拦截对类属性和方法的访问,检查所需的后端是否可用
        def __getattribute__(cls, key):
            if key.startswith("_") and key != "_from_config":
                return super().__getattribute__(key)
            # 调用 `requires_backends` 函数,检查类 `cls` 所需的后端是否可用
            requires_backends(cls, cls._backends)


    # 判断对象 `x` 是否为 Torch FX 的代理对象
    def is_torch_fx_proxy(x):
        if is_torch_fx_available():
            import torch.fx

            return isinstance(x, torch.fx.Proxy)
        return False


    # 定义一个 `_LazyModule` 类,用于惰性加载模块
    class _LazyModule(ModuleType):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        # 构造函数,初始化惰性加载模块
        def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
            super().__init__(name)
            # 设置模块的导入结构和相关属性
            self._modules = set(import_structure.keys())
            self._class_to_module = {}
            # 为类和其模块之间的映射建立字典
            for key, values in import_structure.items():
                for value in values:
                    self._class_to_module[value] = key
            # 设置模块的 `__all__` 属性,用于 IDE 的自动补全
            self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
            self.__file__ = module_file
            self.__spec__ = module_spec
            self.__path__ = [os.path.dirname(module_file)]
            # 设置模块的额外对象属性
            self._objects = {} if extra_objects is None else extra_objects
            self._name = name
            self._import_structure = import_structure

        # 为了在 IDE 中进行自动补全而需要的特殊方法
    # 继承父类的 __dir__() 方法,获取默认的属性列表
    def __dir__(self):
        result = super().__dir__()
        # 检查 self.__all__ 中的元素是否是子模块,有些可能已经在属性列表中,取决于是否已被访问
        # 只添加那些尚未在属性列表中的 self.__all__ 元素
        for attr in self.__all__:
            if attr not in result:
                result.append(attr)
        # 返回更新后的属性列表
        return result

    # 获取属性值的方法,支持动态获取 self._objects 中的对象或者通过模块名称获取模块中的属性
    def __getattr__(self, name: str) -> Any:
        if name in self._objects:
            return self._objects[name]  # 如果属性在 self._objects 中,直接返回其值
        if name in self._modules:
            value = self._get_module(name)  # 如果属性在 self._modules 中,调用 _get_module 获取模块对象
        elif name in self._class_to_module.keys():
            # 如果属性在 self._class_to_module 中,获取相应的模块对象,并从中获取属性值
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            # 如果属性不存在于以上三种情况,则引发 AttributeError
            raise AttributeError(f"module {self.__name__} has no attribute {name}")

        setattr(self, name, value)  # 将获取到的属性值设置为实例的属性,以便下次直接访问
        return value

    # 根据模块名称导入模块的方法
    def _get_module(self, module_name: str):
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            # 如果导入失败,抛出 RuntimeError 异常
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    # 序列化对象时调用的方法,返回对象的类、名称、导入结构元组
    def __reduce__(self):
        return (self.__class__, (self._name, self.__file__, self._import_structure))
class OptionalDependencyNotAvailable(BaseException):
    """用于表示未找到可选依赖项的内部错误类。"""


def direct_transformers_import(path: str, file="__init__.py") -> ModuleType:
    """直接导入 transformers 模块

    Args:
        path (`str`): 源文件的路径
        file (`str`, optional): 要与路径拼接的文件名。默认为 "__init__.py".

    Returns:
        `ModuleType`: 导入的结果模块对象
    """
    # 设置模块名为 "transformers"
    name = "transformers"
    # 构建文件的完整路径
    location = os.path.join(path, file)
    # 创建模块的规范对象
    spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path])
    # 根据规范对象创建模块
    module = importlib.util.module_from_spec(spec)
    # 执行模块的代码,加载模块
    spec.loader.exec_module(module)
    # 获取已加载的模块对象
    module = sys.modules[name]
    # 返回导入的模块对象
    return module
posted @ 2024-07-01 10:54  绝不原创的飞龙  阅读(45)  评论(0编辑  收藏  举报