Transformers-源码解析-一百三十三-

Transformers 源码解析(一百三十三)

.\processing_utils.py

# 设置文件编码为 UTF-8
# 版权声明,声明代码的版权归 The HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证版本 2.0 使用此文件,除非遵守许可证,否则不得使用此文件
# 可以在以下网址获取许可证副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件根据"原样"提供,不附带任何形式的明示或暗示的担保或条件
# 有关详细信息,请参阅许可证
"""
通用处理器的保存/加载类。
"""

import copy  # 导入复制模块
import inspect  # 导入检查模块
import json  # 导入 JSON 模块
import os  # 导入操作系统模块
import warnings  # 导入警告模块
from pathlib import Path  # 导入 Path 类
from typing import Any, Dict, Optional, Tuple, Union  # 导入类型提示

from .dynamic_module_utils import custom_object_save  # 从动态模块工具导入自定义对象保存函数
from .tokenization_utils_base import PreTrainedTokenizerBase  # 从基础标记化工具导入预训练分词器基类
from .utils import (
    PROCESSOR_NAME,  # 从工具模块导入处理器名称常量
    PushToHubMixin,  # 从工具模块导入推送至 Hub 的 Mixin 类
    add_model_info_to_auto_map,  # 从工具模块导入将模型信息添加到自动映射的函数
    cached_file,  # 从工具模块导入缓存文件函数
    copy_func,  # 从工具模块导入复制函数函数
    direct_transformers_import,  # 从工具模块导入直接导入 Transformers 模块的函数
    download_url,  # 从工具模块导入下载 URL 函数
    is_offline_mode,  # 从工具模块导入检查是否为离线模式的函数
    is_remote_url,  # 从工具模块导入检查是否为远程 URL 的函数
    logging,  # 从工具模块导入日志记录对象
)

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象

# 动态导入 Transformers 模块,以获取处理器类的属性类
transformers_module = direct_transformers_import(Path(__file__).parent)

# 自动映射到基类的映射表,用于自动模型加载时的类关联
AUTO_TO_BASE_CLASS_MAPPING = {
    "AutoTokenizer": "PreTrainedTokenizerBase",  # 自动分词器映射到基础分词器基类
    "AutoFeatureExtractor": "FeatureExtractionMixin",  # 自动特征提取器映射到特征提取混合类
    "AutoImageProcessor": "ImageProcessingMixin",  # 自动图像处理器映射到图像处理混合类
}


class ProcessorMixin(PushToHubMixin):
    """
    这是一个 Mixin 类,用于为所有处理器类提供保存/加载功能。
    """

    attributes = ["feature_extractor", "tokenizer"]  # 处理器类中需要保存的属性列表
    # 对应属性列表中的类属性定义
    feature_extractor_class = None  # 特征提取器类属性初始化为空
    tokenizer_class = None  # 分词器类属性初始化为空
    _auto_class = None  # 自动加载的类属性初始化为空

    # args have to match the attributes class attribute
    def __init__(self, *args, **kwargs):
        # 对传入的参数和关键字参数进行清理和验证
        for key in kwargs:
            # 检查关键字参数是否在对象的属性列表中,否则引发异常
            if key not in self.attributes:
                raise TypeError(f"Unexpected keyword argument {key}.")
        
        for arg, attribute_name in zip(args, self.attributes):
            # 检查位置参数是否与属性名匹配的关键字参数冲突,如果有冲突则引发异常
            if attribute_name in kwargs:
                raise TypeError(f"Got multiple values for argument {attribute_name}.")
            else:
                kwargs[attribute_name] = arg

        if len(kwargs) != len(self.attributes):
            # 检查最终的关键字参数数量是否与对象属性数量匹配,不匹配则引发数值错误异常
            raise ValueError(
                f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got "
                f"{len(args)} arguments instead."
            )

        # 检查每个参数是否属于其对应的预期类别,这也会捕获用户错误顺序初始化的情况
        for attribute_name, arg in kwargs.items():
            class_name = getattr(self, f"{attribute_name}_class")
            # 如果类名为"AutoXxx",则检查其对应的基类
            class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
            if isinstance(class_name, tuple):
                # 如果类名是元组,则获取模块中对应的类列表
                proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
            else:
                # 否则直接获取模块中的类
                proper_class = getattr(transformers_module, class_name)

            # 检查参数是否属于预期的类别,不属于则引发数值错误异常
            if not isinstance(arg, proper_class):
                raise ValueError(
                    f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected."
                )

            # 将参数设置为对象的属性
            setattr(self, attribute_name, arg)
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this processor instance.
        """
        # Create a deep copy of the instance's __dict__ to prevent unintended modifications
        output = copy.deepcopy(self.__dict__)

        # Retrieve the signature of the __init__ method to get its parameters
        sig = inspect.signature(self.__init__)
        
        # Filter out attributes that are not listed in the __init__ parameters
        attrs_to_save = sig.parameters
        attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes]
        
        # Add "auto_map" to the list of attributes to be saved
        attrs_to_save += ["auto_map"]

        # Filter the output dictionary to include only the attributes to be saved
        output = {k: v for k, v in output.items() if k in attrs_to_save}

        # Add the class name of the processor instance to the output dictionary
        output["processor_class"] = self.__class__.__name__

        # Remove specific attributes that should not be included in the output
        if "tokenizer" in output:
            del output["tokenizer"]
        if "image_processor" in output:
            del output["image_processor"]
        if "feature_extractor" in output:
            del output["feature_extractor"]

        # Filter out attributes with names indicating objects not suitable for serialization
        output = {
            k: v
            for k, v in output.items()
            if not (isinstance(v, PushToHubMixin) or v.__class__.__name__ == "BeamSearchDecoderCTC")
        }

        return output

    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
        """
        # Convert the instance to a dictionary
        dictionary = self.to_dict()

        # Serialize the dictionary to a JSON string with formatting
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this processor instance's parameters will be saved.
        """
        # Open the JSON file for writing
        with open(json_file_path, "w", encoding="utf-8") as writer:
            # Write the instance's JSON representation to the file
            writer.write(self.to_json_string())

    def __repr__(self):
        """
        Returns a string representation of the processor instance.

        Returns:
            `str`: String representation of the processor instance, including key attributes and JSON serialization.
        """
        # Generate representations of all attributes specified in self.attributes
        attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes]
        
        # Concatenate attribute representations into a single string
        attributes_repr = "\n".join(attributes_repr)
        
        # Return a formatted string including class name, attributes, and JSON serialization
        return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}"

    @classmethod
    def get_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ):
        """
        Placeholder method for defining how to get processor dictionary.

        This method is not implemented in the provided code snippet.
        """
        pass
    def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
        """
        从参数字典和额外关键字参数实例化一个 [`~processing_utils.ProcessingMixin`] 类型的对象。

        Args:
            processor_dict (`Dict[str, Any]`):
                用于实例化处理器对象的参数字典。可以利用预训练检查点的
                [`~processing_utils.ProcessingMixin.to_dict`] 方法来获取这样一个字典。
            kwargs (`Dict[str, Any]`):
                初始化处理器对象的额外参数。

        Returns:
            [`~processing_utils.ProcessingMixin`]: 从这些参数实例化的处理器对象。
        """
        processor_dict = processor_dict.copy()
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # 不像图像处理器或特征提取器那样,处理器的 `__init__` 方法不接受 `kwargs`。
        # 我们必须弹出一些未使用的(但是特定的)参数才能使其正常工作。
        if "processor_class" in processor_dict:
            del processor_dict["processor_class"]

        if "auto_map" in processor_dict:
            del processor_dict["auto_map"]

        # 使用给定的 `args` 和 `processor_dict` 实例化处理器对象
        processor = cls(*args, **processor_dict)

        # 如果需要,使用 `kwargs` 更新处理器对象
        for key in set(kwargs.keys()):
            if hasattr(processor, key):
                setattr(processor, key, kwargs.pop(key))

        # 记录处理器对象的信息
        logger.info(f"Processor {processor}")
        if return_unused_kwargs:
            return processor, kwargs
        else:
            return processor

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
        ):
        r"""
        Instantiate a processor associated with a pretrained model.

        <Tip>

        This class method is simply calling the feature extractor
        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor
        [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer
        [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the
        methods above for more information.

        </Tip>

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
                  huggingface.co.
                - a path to a *directory* containing a feature extractor file saved using the
                  [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved feature extractor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            **kwargs
                Additional keyword arguments passed along to both
                [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
                [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
        """
        kwargs["cache_dir"] = cache_dir
        kwargs["force_download"] = force_download
        kwargs["local_files_only"] = local_files_only
        kwargs["revision"] = revision

        # Check and handle deprecated use_auth_token argument
        use_auth_token = kwargs.pop("use_auth_token", None)
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        # If token is provided, set it in kwargs
        if token is not None:
            kwargs["token"] = token

        # Get arguments from pretrained model and process kwargs
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Obtain processor dictionary and update kwargs
        processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)

        # Instantiate the class using obtained arguments and processor dictionary
        return cls.from_args_and_dict(args, processor_dict, **kwargs)

    @classmethod
    # 注册一个自动类别名,用于自定义特征提取器,这应仅用于自定义的特征提取器,因为库中的提取器已经与 `AutoProcessor` 映射好了。
    def register_for_auto_class(cls, auto_class="AutoProcessor"):
        """
        Register this class with a given auto class. This should only be used for custom feature extractors as the ones
        in the library are already mapped with `AutoProcessor`.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
                The auto class to register this new feature extractor with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        # 导入 transformers.models.auto 模块,用于检查 auto_class 是否存在
        import transformers.models.auto as auto_module

        # 如果 auto_module 中没有找到指定的 auto_class,则抛出 ValueError
        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        # 将 auto_class 赋值给当前类的 _auto_class 属性
        cls._auto_class = auto_class

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # 初始化一个空列表,用于存储从预训练模型中获取的参数
        args = []
        # 遍历类的 attributes 列表
        for attribute_name in cls.attributes:
            # 获取当前属性对应的类名
            class_name = getattr(cls, f"{attribute_name}_class")

            # 如果 class_name 是一个元组
            if isinstance(class_name, tuple):
                # 从 transformers_module 中获取类,如果为 None 则跳过
                classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
                # 获取 kwargs 中的 use_fast 参数,默认为 True
                use_fast = kwargs.get("use_fast", True)
                # 如果 use_fast 为 True 并且 classes[1] 不为 None,则使用 classes[1],否则使用 classes[0]
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                # 如果 class_name 不是元组,则直接从 transformers_module 中获取对应的类
                attribute_class = getattr(transformers_module, class_name)

            # 使用 from_pretrained 方法从预训练模型加载参数,并添加到 args 列表中
            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
        return args

    @property
    def model_input_names(self):
        # 获取当前对象的第一个属性,并尝试获取其 model_input_names 属性,如果不存在则返回 None
        first_attribute = getattr(self, self.attributes[0])
        return getattr(first_attribute, "model_input_names", None)
# 将 ProcessorMixin 类的 push_to_hub 方法复制一份,赋值给原方法
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
# 检查 push_to_hub 方法的文档字符串是否不为空
if ProcessorMixin.push_to_hub.__doc__ is not None:
    # 如果文档字符串不为空,使用格式化字符串将文档字符串中的占位符替换为指定的内容
    ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
        object="processor", object_class="AutoProcessor", object_files="processor files"
    )

.\pytorch_utils.py

# 导入inspect模块,用于获取对象信息
import inspect
# 导入类型提示模块
from typing import Callable, List, Optional, Set, Tuple, Union

# 导入PyTorch库
import torch
# 导入版本管理模块
from packaging import version
# 导入safetensors库中的相关函数
from safetensors.torch import storage_ptr, storage_size
# 导入PyTorch的神经网络模块
from torch import nn

# 导入本地的is_torch_xla_available和logging函数
from .utils import is_torch_xla_available, logging

# 定义一个包含nn.LayerNorm的列表
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 解析当前使用的PyTorch版本的基础版本号
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

# 检查当前PyTorch版本是否大于等于2.2
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
# 检查当前PyTorch版本是否大于等于2.1
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
# 检查当前PyTorch版本是否大于等于2.0
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
# 检查当前PyTorch版本是否大于等于1.13
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
# 检查当前PyTorch版本是否大于等于1.12
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")


def softmax_backward_data(parent, grad_output, output, dim, self):
    """
    调用内部的`_softmax_backward_data` PyTorch方法,并根据检测到的torch版本调整参数。
    
    Args:
        parent: 父对象
        grad_output: 梯度输出
        output: 输出
        dim: 维度
        self: 当前对象

    Returns:
        返回内部方法`_softmax_backward_data`的调用结果
    """
    from torch import _softmax_backward_data

    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)


def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """
    对线性层进行修剪,仅保留index中的条目。

    用于移除神经网络中的头部。

    Args:
        layer (`torch.nn.Linear`): 需要修剪的线性层。
        index (`torch.LongTensor`): 要在层中保留的索引。
        dim (`int`, *可选*, 默认为0): 在哪个维度上保留索引。

    Returns:
        `torch.nn.Linear`: 作为新层的修剪后的层,具有`requires_grad=True`。
    """
    # 将索引移到与权重张量相同的设备上
    index = index.to(layer.weight.device)
    # 从权重张量中选择指定维度上的索引,并进行克隆和分离
    W = layer.weight.index_select(dim, index).clone().detach()
    # 如果存在偏置,则根据维度选择偏置,并进行克隆和分离
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    # 创建新层,其尺寸与权重相同,但在指定维度上为索引长度
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    # 设置新层的权重不需要梯度,并复制修剪后的权重
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True

    return new_layer
    # 检查原始层是否具有偏置项
    if layer.bias is not None:
        # 将新层的偏置项设置为不需要梯度计算
        new_layer.bias.requires_grad = False
        # 将新层的偏置项赋值为现有偏置项的连续拷贝
        new_layer.bias.copy_(b.contiguous())
        # 设置新层的偏置项为需要梯度计算
        new_layer.bias.requires_grad = True
    # 返回已经设置好的新层对象
    return new_layer
    def apply_chunking_to_forward(
        forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors: Any
    ) -> torch.Tensor:
        """
        This function applies chunking to a forward function to allow for large tensor inputs that
        exceed memory capacity.

        Args:
            forward_fn (`Callable[..., torch.Tensor]`): The forward function of the model.
            chunk_size (`int`): The size of each chunk in the specified dimension.
            chunk_dim (`int`): The dimension along which to chunk the input tensors.
            *input_tensors (`Any`): Input tensors to the forward function.

        Returns:
            `torch.Tensor`: The result tensor from the forward function after applying chunking.
        """
        assert isinstance(chunk_size, int) and chunk_size > 0
        assert isinstance(chunk_dim, int) and chunk_dim < len(input_tensors[0].size())

        chunked_input_tensors = list(zip(*map(lambda x: x.chunk(chunk_size, dim=chunk_dim), input_tensors)))
        outputs = []

        for chunked_inputs in chunked_input_tensors:
            outputs.append(forward_fn(*chunked_inputs))

        if len(outputs) == 1:
            return outputs[0]
        else:
            return torch.cat(outputs, dim=chunk_dim)
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors


# forward_fn: Callable[..., torch.Tensor] 定义了一个名为 forward_fn 的参数,它是一个可调用对象,接受任意数量和类型的参数并返回 torch.Tensor 类型的对象。
# chunk_size: int 是一个整数类型的参数,用于指定数据块的大小。
# chunk_dim: int 是一个整数类型的参数,表示数据块在输入张量中的维度。
# *input_tensors 指定了一个可变数量的输入张量参数,它们将被传递给 forward_fn 函数。
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.

    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
    applying `forward_fn` to `input_tensors`.

    Args:
        forward_fn (`Callable[..., torch.Tensor]`):
            The forward function of the model.
        chunk_size (`int`):
            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`Tuple[torch.Tensor]`):
            The input tensors of `forward_fn` which will be chunked

    Returns:
        `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied.

    Examples:

    ```
    # rename the usual forward() fn to forward_chunk()
    def forward_chunk(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states


    # implement a chunked forward function
    def forward(self, hidden_states):
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
    ```"""

    # Check if there are input tensors provided
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # Determine the number of arguments expected by the forward function
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)

    # Validate that the number of input tensors matches the number of expected arguments in forward_fn
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
            "tensors are given"
        )
    # 如果指定了有效的块大小
    if chunk_size > 0:
        # 获取输入张量列表中第一个张量在指定维度上的形状
        tensor_shape = input_tensors[0].shape[chunk_dim]
        
        # 遍历输入张量列表,检查它们在指定维度上的形状是否与第一个张量相同
        for input_tensor in input_tensors:
            if input_tensor.shape[chunk_dim] != tensor_shape:
                # 如果形状不同,则抛出数值错误异常
                raise ValueError(
                    f"All input tenors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )

        # 检查第一个张量在指定维度上的大小是否能被块大小整除
        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
            # 如果不能整除,则抛出数值错误异常
            raise ValueError(
                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
                f"size {chunk_size}"
            )

        # 计算需要分块的数量
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size

        # 将每个输入张量按指定维度分块成元组
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
        
        # 对每个元组应用前向函数,并生成输出块的元组
        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        
        # 在指定维度上连接输出块,生成最终的输出张量
        return torch.cat(output_chunks, dim=chunk_dim)

    # 如果未指定有效的块大小,则直接将前向函数应用于输入张量并返回结果
    return forward_fn(*input_tensors)
# 定义一个函数,用于查找可以裁剪的头部索引及其位置
def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    """
    找到需要裁剪的头部及其索引,考虑到已经裁剪的头部。

    Args:
        heads (`List[int]`): 需要裁剪的头部索引列表。
        n_heads (`int`): 模型中头部的数量。
        head_size (`int`): 每个头部的大小。
        already_pruned_heads (`Set[int]`): 已经裁剪的头部集合。

    Returns:
        `Tuple[Set[int], torch.LongTensor]`: 返回一个元组,包含考虑到 `already_pruned_heads` 后需要裁剪的头部索引,
        以及层权重中需要保留的行/列索引。
    """
    mask = torch.ones(n_heads, head_size)  # 创建一个全为1的掩码
    heads = set(heads) - already_pruned_heads  # 转换为集合并移除已经裁剪的头部
    for head in heads:
        # 计算在当前头部之前有多少已经裁剪的头部,并相应地调整索引
        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
        mask[head] = 0  # 将对应头部的掩码置为0,表示需要裁剪
    mask = mask.view(-1).contiguous().eq(1)  # 将掩码展平为一维,并保留值为1的位置
    index: torch.LongTensor = torch.arange(len(mask))[mask].long()  # 获取需要保留的索引
    return heads, index  # 返回需要裁剪的头部索引和需要保留的索引


# 定义一个函数,用于创建网格
def meshgrid(
    *tensors: Union[torch.Tensor, List[torch.Tensor]], indexing: Optional[str] = None
) -> Tuple[torch.Tensor, ...]:
    """
    对 torch.meshgrid 的包装,以避免关于引入的 `indexing` 参数的警告信息。

    Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
    """
    return torch.meshgrid(*tensors, indexing=indexing)


# 定义一个函数,用于获取张量的存储信息
def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
    """
    唯一标识符,用于标识张量的存储。多个不同的张量可以共享相同的底层存储。
    例如,“meta”张量共享相同的存储,因此它们的标识符将相等。
    此标识符在张量的生命周期内保证是唯一且常量的。两个存储生命周期不重叠的张量可能具有相同的id。
    """
    if tensor.device.type == "xla" and is_torch_xla_available():
        # 注意:xla 张量没有存储
        # 使用其他唯一标识符来区分。
        # 这是一个 XLA 张量,必须使用 torch_xla 的设备创建。
        # 所以以下导入是安全的:
        import torch_xla

        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
    else:
        unique_id = storage_ptr(tensor)

    return tensor.device, unique_id, storage_size(tensor)

.\quantizers\auto.py

# 导入警告模块,用于处理警告信息
import warnings
# 导入类型提示相关模块
from typing import Dict, Optional, Union

# 导入自动配置模块
from ..models.auto.configuration_auto import AutoConfig
# 导入量化配置相关类和方法
from ..utils.quantization_config import (
    AqlmConfig,
    AwqConfig,
    BitsAndBytesConfig,
    GPTQConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
)
# 导入各种量化器类
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer

# 自动量化器与量化器类的映射关系
AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
}

# 自动量化配置与量化配置类的映射关系
AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "awq": AwqConfig,
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "gptq": GPTQConfig,
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
}

# 自动量化配置类,用于根据给定的量化配置自动分发到正确的量化配置
class AutoQuantizationConfig:
    """
    The Auto-HF quantization config class that takes care of automatically dispatching to the correct
    quantization config given a quantization config stored in a dictionary.
    """

    @classmethod
    # 从给定的字典构建一个类方法,用于反序列化量化配置
    def from_dict(cls, quantization_config_dict: Dict):
        # 从量化配置字典中获取量化方法,如果不存在则设为 None
        quant_method = quantization_config_dict.get("quant_method", None)
        # 对于 bnb 模型,需要特别处理以确保兼容性
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            # 如果配置中指定了 load_in_4bit,则使用 4 位量化方法后缀;否则使用 8 位后缀
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            # 构建量化方法字符串,结合 BITS_AND_BYTES 常量
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        # 如果未指定量化方法,则抛出 ValueError 异常
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
            )

        # 如果量化方法不在自动量化配置映射的键中,则抛出 ValueError 异常
        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        # 根据量化方法从自动量化配置映射中获取目标类
        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        # 使用目标类的类方法将量化配置字典反序列化为对象
        return target_cls.from_dict(quantization_config_dict)

    # 类方法:从预训练模型中加载模型配置,并返回相应的量化配置
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # 从预训练模型名称或路径加载模型配置
        model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # 如果模型配置中没有量化配置,则抛出 ValueError 异常
        if getattr(model_config, "quantization_config", None) is None:
            raise ValueError(
                f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        # 获取模型配置中的量化配置字典
        quantization_config_dict = model_config.quantization_config
        # 使用 from_dict 方法将量化配置字典反序列化为量化配置对象
        quantization_config = cls.from_dict(quantization_config_dict)
        # 将传递自 from_pretrained 的额外参数更新到量化配置对象中
        quantization_config.update(kwargs)
        # 返回构建好的量化配置对象
        return quantization_config
    """
     The Auto-HF quantizer class that takes care of automatically instantiating to the correct
    `HfQuantizer` given the `QuantizationConfig`.
    """

    @classmethod
    # 类方法:从给定的量化配置创建一个实例
    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
        # 如果 quantization_config 是字典类型,则转换为 QuantizationConfig
        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        # 获取量化方法
        quant_method = quantization_config.quant_method

        # 对 BITS_AND_BYTES 方法进行特殊处理,因为我们有一个单独的量化配置类用于 4-bit 和 8-bit 量化
        if quant_method == QuantizationMethod.BITS_AND_BYTES:
            if quantization_config.load_in_8bit:
                quant_method += "_8bit"
            else:
                quant_method += "_4bit"

        # 检查量化方法是否在自动量化映射中
        if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        # 根据量化方法选择对应的类
        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
        # 使用选定的类创建实例并返回
        return target_cls(quantization_config, **kwargs)

    @classmethod
    # 类方法:从预训练模型名称或路径创建一个实例
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # 从预训练模型获取量化配置
        quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # 使用量化配置创建实例
        return cls.from_config(quantization_config)

    @classmethod
    # 类方法:合并两个量化配置
    def merge_quantization_configs(
        cls,
        quantization_config: Union[dict, QuantizationConfigMixin],
        quantization_config_from_args: Optional[QuantizationConfigMixin],
        """
        处理同时存在来自参数和模型配置中的量化配置的情况。
        """
        # 如果参数中有 quantization_config,则生成警告消息
        if quantization_config_from_args is not None:
            warning_msg = (
                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
            )
        else:
            warning_msg = ""

        # 如果 quantization_config 是字典类型,则转换为 AutoQuantizationConfig 对象
        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        # 对于 GPTQConfig 或 AwqConfig 类型的 quantization_config,并且 quantization_config_from_args 不为空的特殊情况处理
        if isinstance(quantization_config, (GPTQConfig, AwqConfig)) and quantization_config_from_args is not None:
            # 获取 quantization_config_from_args 中的加载属性,并设置到 quantization_config 对象中
            loading_attr_dict = quantization_config_from_args.get_loading_attributes()
            for attr, val in loading_attr_dict.items():
                setattr(quantization_config, attr, val)
            # 更新警告消息,说明加载属性将被传入的参数覆盖,其余属性将被忽略
            warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."

        # 如果有警告消息,发出警告
        if warning_msg != "":
            warnings.warn(warning_msg)

        # 返回处理后的 quantization_config 对象
        return quantization_config

.\quantizers\base.py

# 引入 ABC 类和类型检查相关的模块
from abc import ABC, abstractmethod
# 引入类型检查模块,用于检查是否支持特定类型的操作
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

# 从相对路径的模块中导入 is_torch_available 函数
from ..utils import is_torch_available
# 从相对路径的模块中导入 QuantizationConfigMixin 类
from ..utils.quantization_config import QuantizationConfigMixin

# 如果在类型检查模式下,导入 PreTrainedModel 类
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

# 如果 Torch 可用,则导入 Torch 模块
if is_torch_available():
    import torch

# 定义 HfQuantizer 抽象类
class HfQuantizer(ABC):
    """
    HuggingFace 量化器的抽象类。目前支持对 HF transformers 模型进行推断和/或量化。
    这个类仅用于 transformers.PreTrainedModel.from_pretrained 方法的范围内,无法在该方法范围外轻松使用。

    Attributes:
        quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
            定义要量化的模型的量化参数的配置。
        modules_to_not_convert (`List[str]`, *optional*):
            在量化模型时不希望转换的模块名称列表。
        required_packages (`List[str]`, *optional*):
            使用量化器之前需要安装的必需 pip 包列表。
        requires_calibration (`bool`):
            使用量化方法是否需要在使用模型之前进行校准。
        requires_parameters_quantization (`bool`):
            使用量化方法是否需要创建新的参数。例如,对于 bitsandbytes,需要创建新的 xxxParameter 来正确量化模型。
    """

    # 标识量化方法是否需要校准
    requires_calibration = False
    # 用于存储需要安装的必需 pip 包的列表
    required_packages = None
    # 标识量化方法是否需要对参数进行量化
    requires_parameters_quantization = False
    # 初始化函数,接受量化配置和其他关键字参数
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        # 将量化配置保存到实例变量中
        self.quantization_config = quantization_config

        # 处理额外的关键字参数
        self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
        self.pre_quantized = kwargs.pop("pre_quantized", True)

        # 如果未预量化但需要校准,引发值错误异常
        if not self.pre_quantized and self.requires_calibration:
            raise ValueError(
                f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized."
                f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to "
                f"pass `pre_quantized=True` while knowing what you are doing."
            )

    # 更新 Torch 数据类型的方法,通常由子类重写以确保行为一致性
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        """
        Some quantization methods require to explicitly set the dtype of the model to a
        target dtype. You need to override this method in case you want to make sure that behavior is
        preserved

        Args:
            torch_dtype (`torch.dtype`):
                The input dtype that is passed in `from_pretrained`
        """
        return torch_dtype

    # 更新设备映射的方法,通常由子类重写以允许传递新的设备映射
    def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Override this method if you want to pass a override the existing device map with a new
        one. E.g. for bitsandbytes, since `accelerate` is a hard requirement, if no device_map is
        passed, the device_map is set to `"auto"``

        Args:
            device_map (`Union[dict, str]`, *optional*):
                The device_map that is passed through the `from_pretrained` method.
        """
        return device_map

    # 调整目标 Torch 数据类型的方法,通常由子类重写以适应特定的量化需求
    def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        """
        Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained`
        to compute the device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype`
        to `torch.int8` and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`.

        Args:
            torch_dtype (`torch.dtype`, *optional*):
                The torch_dtype that is used to compute the device_map.
        """
        return torch_dtype

    # 更新缺失键列表的方法,通常由子类重写以适应特定的模型加载需求
    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        """
        Override this method if you want to adjust the `missing_keys`.

        Args:
            missing_keys (`List[str]`, *optional*):
                The list of missing keys in the checkpoint compared to the state dict of the model
        """
        return missing_keys
    def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
        """
        返回未量化模块的数据类型字典 - 用于在传递字符串作为 device_map 时计算 device_map。
        该方法将使用 `_process_model_before_weight_loading` 中修改的 `modules_to_not_convert`。
        
        Args:
            model (`~transformers.PreTrainedModel`):
                要量化的模型
            torch_dtype (`torch.dtype`):
                在 `from_pretrained` 方法中传递的数据类型
        """
        return {
            name: torch_dtype
            for name, _ in model.named_parameters()
            if any(m in name for m in self.modules_to_not_convert)
        }

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        """根据需要为量化调整 max_memory 参数,用于 infer_auto_device_map()。"""
        return max_memory

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        检查加载的 state_dict 组件是否是量化参数的一部分,同时进行一些验证;
        只有在 requires_parameters_quantization == True 时才会定义,用于需要为量化方法创建新参数的情况。
        """
        return False

    def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
        """
        从 state_dict 中获取必要的组件并创建量化参数;
        只有在 requires_parameters_quantization == True 时适用。
        如果不支持 requires_parameters_quantization,则会引发 AttributeError。
        """
        if not self.requires_parameters_quantization:
            raise AttributeError(
                f"`.create_quantized_param()` 方法不受量化器类 {self.__class__.__name__} 支持。"
            )

    def validate_environment(self, *args, **kwargs):
        """
        该方法用于潜在地检查在 `from_pretrained` 中传递的参数是否存在冲突。
        对于所有未来与 transformers 集成的量化器,都需要定义它。
        如果不需要显式检查,则简单返回即可。
        """
        return
    # 定义一个方法,用于预处理模型,在加载权重之前设置模型属性或转换模型。
    def preprocess_model(self, model: "PreTrainedModel", **kwargs):
        """
        设置模型属性和/或在加载权重之前对模型进行转换。此时模型应在元设备上初始化,
        因此可以自由地操纵模型的骨架以替换模块。确保覆盖抽象方法 `_process_model_before_weight_loading`。

        Args:
            model (`~transformers.PreTrainedModel`):
                要量化的模型
            kwargs (`dict`, *optional*):
                被传递到 `_process_model_before_weight_loading` 的关键字参数。
        """
        # 设置模型的量化标志为True
        model.is_quantized = True
        # 设置模型的量化方法为配置文件中指定的量化方法
        model.quantization_method = self.quantization_config.quant_method
        # 调用 `_process_model_before_weight_loading` 方法进行进一步处理
        return self._process_model_before_weight_loading(model, **kwargs)

    # 定义一个方法,用于在加载权重后对模型进行后处理。
    def postprocess_model(self, model: "PreTrainedModel", **kwargs):
        """
        在加载权重后对模型进行后处理。确保覆盖抽象方法 `_process_model_after_weight_loading`。

        Args:
            model (`~transformers.PreTrainedModel`):
                要量化的模型
            kwargs (`dict`, *optional*):
                被传递到 `_process_model_after_weight_loading` 的关键字参数。
        """
        # 调用 `_process_model_after_weight_loading` 方法进行后处理
        return self._process_model_after_weight_loading(model, **kwargs)

    # 抽象方法,用于在加载权重之前处理模型,需要在子类中实现
    @abstractmethod
    def _process_model_before_weight_loading(self, model, **kwargs):
        ...

    # 抽象方法,用于在加载权重之后处理模型,需要在子类中实现
    @abstractmethod
    def _process_model_after_weight_loading(self, model, **kwargs):
        ...

    # 抽象属性,用于指示对象是否可序列化,需要在子类中实现
    @property
    @abstractmethod
    def is_serializable(self):
        ...

    # 抽象属性,用于指示对象是否可训练,需要在子类中实现
    @property
    @abstractmethod
    def is_trainable(self):
        ...

.\quantizers\quantizers_utils.py

# 导入必要的类型
from typing import Any, Tuple

# 定义一个函数,接受一个模块和一个字符串作为参数,返回一个元组
def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
    # 如果字符串中包含".",则按"."分割字符串
    if "." in tensor_name:
        splits = tensor_name.split(".")
        # 遍历分割后的字符串列表,除了最后一个元素
        for split in splits[:-1]:
            # 获取模块中的属性
            new_module = getattr(module, split)
            # 如果获取的属性为None,则抛出异常
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            # 更新模块为获取到的属性
            module = new_module
        # 更新张量名称为分割后列表的最后一个元素
        tensor_name = splits[-1]
    # 返回更新后的模块和张量名称组成的元组
    return module, tensor_name

.\quantizers\quantizer_aqlm.py

# 版权声明及版权许可信息
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入必要的模块和库
import importlib
from typing import TYPE_CHECKING, Optional

# 导入版本管理相关的模块
from packaging import version

# 导入基类 HfQuantizer
from .base import HfQuantizer

# 检查类型,仅在类型检查时导入相关模块
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

# 导入 AQLM 相关函数和类
from ..integrations import replace_with_aqlm_linear
from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin

# 如果 Torch 可用,则导入 Torch 模块
if is_torch_available():
    import torch

# 获取日志记录器
logger = logging.get_logger(__name__)


class AqlmHfQuantizer(HfQuantizer):
    """
    AQLM 方法的量化器。支持加载预量化模型。
    """

    # 需要校准
    requires_calibration = True
    # 需要的包
    required_packages = ["aqlm"]
    # 最佳量化器,默认为 None
    optimum_quantizer = None

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        # 检查是否安装了 Accelerate 加速库
        if not is_accelerate_available():
            raise ImportError("Using `aqlm` quantization requires Accelerate: `pip install accelerate`")

        # 检查是否安装了 AQLM 库
        if not is_aqlm_available():
            raise ImportError("Using `aqlm` quantization requires AQLM: `pip install aqlm[gpu,cpu]`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        # 如果未指定 Torch 的数据类型
        if torch_dtype is None:
            # 如果 CUDA 可用,则默认为 torch.float16
            if torch.cuda.is_available():
                torch_dtype = torch.float16
                logger.info(
                    "CUDA available. Assuming AQLM inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
                )
            # 如果 CUDA 不可用,默认为 torch.float32
            else:
                torch_dtype = torch.float32
                logger.info(
                    "CUDA is unavailable. Assuming AQLM inference on CPU and loading the model in `torch.float32`. To overwrite it, set `torch_dtype` manually."
                )
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        # 替换模型中的线性层为 AQLM 线性层
        replace_with_aqlm_linear(
            model,
            quantization_config=self.quantization_config,
            linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
        )
        # 将量化配置信息保存到模型配置中
        model.config.quantization_config = self.quantization_config
    # 在模型加载权重后处理模型的方法,返回未修改的模型对象
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    # 属性方法,用于检查模型是否可训练
    # 如果当前安装的 `aqlm` 版本支持训练,返回 True
    # 否则记录警告信息并返回 False
    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        # 检查当前安装的 `aqlm` 版本是否大于等于 1.0.2
        aqlm_supports_training = version.parse(importlib.metadata.version("aqlm")) >= version.parse("1.0.2")
        if aqlm_supports_training:
            return True
        else:
            # 记录警告信息,提示用户更新 `aqlm` 版本以支持训练
            logger.warn(
                f"Currently installed `aqlm` version ({importlib.metadata.version('aqlm')}) doesn't support training. If you wish to train a quantized model, please update `aqlm` with `pip install aqlm>=1.0.2`"
            )
            return False

    # 属性方法,用于检查对象是否可序列化
    @property
    def is_serializable(self):
        return True

.\quantizers\quantizer_awq.py

# 版权声明及许可信息,指明HuggingFace Inc.团队拥有版权
#
# 根据Apache许可证2.0版("许可证")授权,除非符合许可证要求,
# 否则不得使用本文件。您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,本软件按"原样"分发,不提供任何明示或
# 含示的担保或条件。详细信息请参阅许可证。
import importlib.metadata
from typing import TYPE_CHECKING

from packaging import version

# 导入基础的HfQuantizer类
from .base import HfQuantizer

# 如果类型检查开启,则导入PreTrainedModel类
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

# 导入一些工具和依赖的模块
from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
from ..utils.quantization_config import AWQLinearVersion

# 如果torch可用,则导入torch库
if is_torch_available():
    import torch

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# AwqQuantizer类继承自HfQuantizer类,提供Activation-aware Weight Quantization(AWQ)的4位量化支持
class AwqQuantizer(HfQuantizer):
    """
    4-bit quantization for Activation-aware Weight Quantization(AWQ) (https://arxiv.org/abs/2306.00978)
    """

    # AWQ需要数据校准 - 我们只支持推断(inference)
    requires_calibration = True

    # 必需的包名称列表,包括"awq"和"accelerate"
    required_packages = ["awq", "accelerate"]

    # 初始化方法,接受quantization_config和其他关键字参数
    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    # 验证运行环境的方法,检查GPU是否可用以及必需的库是否已安装
    def validate_environment(self, device_map, **kwargs):
        # 如果没有CUDA设备可用,则抛出运行时错误
        if not torch.cuda.is_available():
            raise RuntimeError("GPU is required to run AWQ quantized model.")

        # 如果未安装auto-awq库,则抛出导入错误
        if not is_auto_awq_available():
            raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")

        # 如果未安装accelerate库,则抛出导入错误
        if not is_accelerate_available():
            raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")

        # 如果device_map为None,则发出警告,建议在GPU设备上运行模型
        if device_map is None:
            logger.warning_once(
                "You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
                "your model on a GPU device in order to run your model."
            )
        # 如果device_map不为None,则检查是否包含CPU或磁盘设备,如果是则抛出数值错误
        elif device_map is not None:
            if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
                raise ValueError(
                    "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
                    " This is not supported. Please remove the CPU or disk device from the device_map."
                )

    # 更新torch数据类型的方法,如果未提供torch_dtype,则使用torch.float16
    def update_torch_dtype(self, torch_dtype):
        if torch_dtype is None:
            torch_dtype = torch.float16
        elif torch_dtype != torch.float16:
            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
        return torch_dtype
    # 在加载权重前处理模型的方法,用于量化感知训练模型
    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        # 导入必要的集成模块:获取不需转换的模块键和替换 AWQ 线性模块
        from ..integrations import get_keys_to_not_convert, replace_with_awq_linear
        
        # 获取不需要转换的模块列表
        self.modules_to_not_convert = get_keys_to_not_convert(model)
        
        # 如果配置中有指定不需要转换的模块,则扩展已有的列表
        if self.quantization_config.modules_to_not_convert is not None:
            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
        
        # 替换模型中的线性层为 AWQ 线性层,并检查是否有替换发生
        model, has_been_replaced = replace_with_awq_linear(
            model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
        )
        
        # 如果没有进行替换,则发出警告信息
        if not has_been_replaced:
            logger.warning(
                "You are loading an AWQ model but no linear modules were found in your model."
                " Please double check your model architecture, or submit an issue on github if you think this is a bug."
            )
    
    # 在加载权重后处理模型的方法
    def _process_model_after_weight_loading(self, model):
        # 如果配置要求进行模块融合
        if self.quantization_config.do_fuse:
            # 导入模块:融合 AWQ 模块
            from ..integrations import fuse_awq_modules
            
            # 融合模型中的 AWQ 模块
            model = fuse_awq_modules(model, self.quantization_config)
            # 设置 AWQ 被融合的标志为真,考虑将此标志存储在 model.config 中
            model._awq_is_fused = True  # TODO: consider storing this flag in model.config instead
        
        # 如果使用的 AWQ 版本为 EXLLAMA
        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
            # 导入模块:初始化 AWQ EXLLAMA 后端模块
            from ..integrations import post_init_awq_exllama_modules
            
            # 对模型进行 AWQ EXLLAMA 后端模块的初始化
            model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
    
    # 判断模型是否可序列化的属性
    @property
    def is_serializable(self):
        # 如果配置要求进行模块融合,则不可保存
        if self.quantization_config.do_fuse:
            logger.warning("You cannot save an AWQ model that uses fused modules!")
            return False
        
        # 如果使用的 AWQ 版本为 EXLLAMA,则不可保存
        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
            logger.warning("You cannot save an AWQ model that uses Exllama backend!")
            return False
        
        # 否则可保存
        return True
    
    # 判断模型是否可训练的属性
    @property
    def is_trainable(self):
        # 定义 PEFT 微调所需的最小 AWQ 版本
        MIN_AWQ_VERSION_FOR_PEFT = "0.2.0"
        
        # 检查当前 autoawq 模块的版本是否支持 PEFT 微调
        return version.parse(importlib.metadata.version("autoawq")) >= version.parse(MIN_AWQ_VERSION_FOR_PEFT)

.\quantizers\quantizer_bnb_4bit.py

# 导入必要的模块和库

# 版权声明和许可证信息
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入 importlib 模块,用于动态加载模块
import importlib

# 导入类型检查相关模块和类型
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

# 导入版本控制相关模块
from packaging import version

# 导入基础量化类 HfQuantizer
from .base import HfQuantizer

# 导入工具函数 get_module_from_name
from .quantizers_utils import get_module_from_name

# 如果是类型检查模式,则导入 PreTrainedModel 类
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

# 导入加速库是否可用检查函数
from ..utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging

# 如果 torch 可用,则导入 torch 相关模块和类
if is_torch_available():
    import torch
    from ..pytorch_utils import Conv1D

# 获取 logger 对象
logger = logging.get_logger(__name__)

# 定义 Bnb4BitHfQuantizer 类,继承自 HfQuantizer
class Bnb4BitHfQuantizer(HfQuantizer):
    """
    从 bitsandbytes.py 量化方法中实现的 4 位量化:
        在加载之前: 将 transformer 层转换为 Linear4bit
        在加载期间: 加载 16 位权重并传递给层对象
        在量化后: 在第一次 .cuda() 调用时将 Linear4bit 中的单个权重量化为 4 位
        保存:
            从状态字典中,像往常一样; 保存权重和 `quant_state` 组件
        加载:
            需要定位 `quant_state` 组件并传递给 Param4bit 构造函数
    """

    # 使用 keep_in_fp32 模块
    use_keep_in_fp32_modules = True

    # 需要参数量化
    requires_parameters_quantization = True

    # 不需要校准
    requires_calibration = False

    # 必需的软件包
    required_packages = ["bitsandbytes", "accelerate"]

    # 初始化方法,接受量化配置和其他关键字参数
    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

        # 如果量化配置中定义了 llm_int8_skip_modules,则使用它来指定不转换的模块
        if self.quantization_config.llm_int8_skip_modules is not None:
            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

# 结束 Bnb4BitHfQuantizer 类定义
    # 检查是否安装了加速库和 bitsandbytes 库,如果没有则引发 ImportError 异常
    if not (is_accelerate_available() and is_bitsandbytes_available()):
        raise ImportError(
            "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
            "and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`"
        )

    # 检查是否从 TensorFlow 或 Flax 来源转换权重,这种情况下不支持转换,需使用 PyTorch 格式的权重
    if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
        raise ValueError(
            "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
            " sure the weights are in PyTorch format."
        )

    # 检查是否有可用的 CUDA 设备,如果没有找到 GPU 则引发 RuntimeError 异常
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU found. A GPU is needed for quantization.")

    # 获取参数中的 device_map,检查其类型为 dict,并且未开启 llm_int8_enable_fp32_cpu_offload 选项
    device_map = kwargs.get("device_map", None)
    if (
        device_map is not None
        and isinstance(device_map, dict)
        and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
    ):
        # 剔除 self.modules_to_not_convert 中的模块,生成不包含 lm_head 的新 device_map
        device_map_without_lm_head = {
            key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
        }
        # 如果 device_map_without_lm_head 中的值包含 "cpu" 或 "disk",则引发 ValueError 异常
        if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
            raise ValueError(
                """
                Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
                quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
                in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to
                `from_pretrained`. Check
                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
                for more details.
                """
            )

    # 检查 bitsandbytes 库的版本是否小于 0.39.0,如果是则引发 ValueError 异常
    if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.39.0"):
        raise ValueError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit inference and training"
            " make sure you have the latest version of `bitsandbytes` installed"
        )
    # 调整目标数据类型,确保与加速库的版本兼容
    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        # 检查加速库版本是否大于0.19.0
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
            # 如果是最新版本,导入加速库的自定义数据类型
            from accelerate.utils import CustomDtype

            # 如果目标数据类型不是 torch.int8
            if target_dtype != torch.int8:
                # 记录日志,说明目标数据类型被替换为 CustomDtype.INT4,用于4位BnB量化
                logger.info(f"target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
            # 返回 CustomDtype.INT4 作为新的目标数据类型
            return CustomDtype.INT4
        else:
            # 如果加速库版本过低,抛出数值错误,提示用户升级加速库以支持自动设备映射计算
            raise ValueError(
                "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute"
                " the appropriate device map, you should upgrade your `accelerate` library,"
                "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map"
                "calculation. You may encounter unexpected behavior, or pass your own device map"
            )

    # 检查量化参数是否符合预期
    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        # 导入 bitsandbytes 库作为 bnb
        import bitsandbytes as bnb

        # 从模型中获取指定参数的模块和张量名称
        module, tensor_name = get_module_from_name(model, param_name)
        
        # 如果参数是 bnb.nn.Params4bit 类型,返回 True
        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
            # TODO: 添加序列化实现后,添加加载组件的数据类型检查
            return True
        # 如果模块是 bnb.nn.Linear4bit 并且张量名称是 "bias",返回 True
        elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
            # bias 可能被 accelerate 的 regular set_module_tensor_to_device() 加载,
            # 但在那里会错误地使用未初始化的权重。
            return True
        else:
            # 其他情况返回 False
            return False

    # 创建量化参数
    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.adjust_max_memory
        # 调整最大内存限制,确保在量化过程中有足够的空间
        def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
            # 遍历每个键值对,将值乘以0.90以腾出更多的空间用于量化过程中创建的缓冲区
            max_memory = {key: val * 0.90 for key, val in max_memory.items()}
            # 返回调整后的最大内存限制字典
            return max_memory

        # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_torch_dtype
    # 更新 torch 数据类型,以及返回更新后的 torch 数据类型
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        # 如果 torch 数据类型为空
        if torch_dtype is None:
            # 强制将 `dtype` 设置为 float16,这是 `bitsandbytes` 的要求,以便在8位或4位下启用模型加载
            logger.info(
                "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
                "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
                "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
                " torch_dtype=torch.float16 to remove this warning.",
                torch_dtype,
            )
            # 将 torch 数据类型强制设为 float16
            torch_dtype = torch.float16
        # 返回更新后的 torch 数据类型
        return torch_dtype

    # 从 transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map 复制而来
    # 更新设备映射
    def update_device_map(self, device_map):
        # 如果设备映射为空
        if device_map is None:
            # 将设备映射设为当前 CUDA 设备
            device_map = {"": torch.cuda.current_device()}
            # 输出日志信息
            logger.info(
                "The device_map was not initialized. "
                "Setting device_map to {'':torch.cuda.current_device()}. "
                "If you want to use the model for inference, please set device_map ='auto' "
            )
        # 返回更新后的设备映射
        return device_map

    # 从 transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_before_weight_loading 复制而来
    # 在加载权重之前处理模型
    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
        # 从上层模块导入函数 `get_keys_to_not_convert` 和 `replace_with_bnb_linear`

        load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
        # 从量化配置中获取是否启用 8 位整数加载到 FP32 CPU 卸载的设置

        # 将一些模块(如 lm_head)保持在其原始数据类型中,以确保数值稳定性
        if self.quantization_config.llm_int8_skip_modules is None:
            self.modules_to_not_convert = get_keys_to_not_convert(model)
        else:
            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
        # 如果未指定不转换的模块列表,则使用 get_keys_to_not_convert 函数从模型中获取
        # 否则,使用量化配置中指定的不转换的模块列表

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]
        # 如果不转换的模块不是列表类型,则转换为列表类型

        self.modules_to_not_convert.extend(keep_in_fp32_modules)
        # 将 keep_in_fp32_modules 中的模块添加到 self.modules_to_not_convert 中

        # 扩展 `self.modules_to_not_convert` 到需要卸载到 `cpu` 或 `disk` 的键
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
            # 在 device_map 中查找值为 "disk" 或 "cpu" 的键

            if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
                raise ValueError(
                    "If you want to offload some keys to `cpu` or `disk`, you need to set "
                    "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
                    " converted to 8-bit but kept in 32-bit."
                )
            # 如果有键被卸载到 `cpu` 或 `disk` 但未设置加载到 FP32 CPU 卸载,则抛出 ValueError 异常

            self.modules_to_not_convert.extend(keys_on_cpu)
            # 将键添加到 self.modules_to_not_convert 中

        model = replace_with_bnb_linear(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        # 使用 replace_with_bnb_linear 函数替换模型中的模块,传入不转换的模块列表和量化配置

        # TODO: consider bringing replace_with_bnb_linear() code from ..integrations/bitsandbyter.py to here
        # TODO:考虑将来自 ..integrations/bitsandbyter.py 的 replace_with_bnb_linear() 代码引入到此处

        model.config.quantization_config = self.quantization_config
        # 设置模型配置的量化配置为当前的量化配置

    # 从 transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading 复制,将 8bit->4bit
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        model.is_loaded_in_4bit = True
        model.is_4bit_serializable = self.is_serializable
        # 将模型标记为已加载 4 位,并根据可序列化性设置 4 位模型是否可序列化
        return model

    @property
    def is_serializable(self):
        _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
        # 检查当前安装的 bitsandbytes 版本是否支持 4 位模型的序列化

        if not _is_4bit_serializable:
            logger.warning(
                "You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. "
                "If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed."
            )
            return False
        # 如果不支持 4 位模型的序列化,记录警告信息并返回 False

        return True
        # 否则返回 True,表示支持 4 位模型的序列化

    @property
    def is_trainable(self) -> bool:
        return True
    # 表示当前对象是可训练的,始终返回 True

.\quantizers\quantizer_bnb_8bit.py

# 导入 importlib 模块,用于动态加载其他模块
import importlib
# 导入类型提示相关模块:TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

# 导入版本控制相关模块
from packaging import version

# 从当前包中导入 HfQuantizer 基类
from .base import HfQuantizer

# 如果在类型检查模式下
if TYPE_CHECKING:
    # 从 ..modeling_utils 中导入 PreTrainedModel 类型
    from ..modeling_utils import PreTrainedModel

# 从 ..utils 模块导入加速相关函数:is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
from ..utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
# 从当前目录下 quantizers_utils 模块中导入 get_module_from_name 函数
from .quantizers_utils import get_module_from_name

# 如果 Torch 可用
if is_torch_available():
    # 导入 Torch 模块
    import torch
    # 从 ..pytorch_utils 模块导入 Conv1D 类

    from ..pytorch_utils import Conv1D

# 获取 logger 对象,用于记录日志
logger = logging.get_logger(__name__)


class Bnb8BitHfQuantizer(HfQuantizer):
    """
    从 bitsandbytes 量化方法中获得 8 位量化:
        加载前:将 transformer 层转换为 Linear8bitLt
        加载中:加载 16 位权重并传递给层对象
        加载后:在首次 .cuda() 调用时将 Linear8bitLt 中的单个权重量化为 8 位
    保存:
        与通常一样,从状态字典中保存权重和 'SCB' 组件
    加载:
        需要定位 'SCB' 组件并传递给 Linear8bitLt 对象
    """

    # 是否保持在 FP32 模块中
    use_keep_in_fp32_modules = True
    # 是否需要参数量化
    requires_parameters_quantization = True
    # 是否需要校准
    requires_calibration = False

    # 必需的包
    required_packages = ["bitsandbytes", "accelerate"]

    def __init__(self, quantization_config, **kwargs):
        # 调用父类的构造函数
        super().__init__(quantization_config, **kwargs)

        # 如果配置中指定了 llm_int8_skip_modules
        if self.quantization_config.llm_int8_skip_modules is not None:
            # 将其赋值给不需要转换的模块列表
            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
    # 验证运行环境是否支持加速库和 bitsandbytes 库
    def validate_environment(self, *args, **kwargs):
        # 检查是否安装了必要的加速库和 bitsandbytes 库
        if not (is_accelerate_available() and is_bitsandbytes_available()):
            raise ImportError(
                "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
                "and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`"
            )

        # 检查是否从 TensorFlow 或 Flax 权重进行转换,目前不支持这种转换
        if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
            raise ValueError(
                "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
                " sure the weights are in PyTorch format."
            )

        # 检查是否有可用的 GPU,量化需要 GPU 支持
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")

        # 获取并验证传入的设备映射信息
        device_map = kwargs.get("device_map", None)
        if (
            device_map is not None
            and isinstance(device_map, dict)
            and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
        ):
            # 创建一个不包含指定模块的设备映射副本
            device_map_without_lm_head = {
                key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
            }
            # 如果设备映射中包含 CPU 或 Disk,则抛出错误
            if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
                raise ValueError(
                    """
                    Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the
                    quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules
                    in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to
                    `from_pretrained`. Check
                    https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
                    for more details.
                    """
                )

        # 检查安装的 bitsandbytes 版本是否支持8位推断和训练
        if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.37.2"):
            raise ValueError(
                "You have a version of `bitsandbytes` that is not compatible with 8bit inference and training"
                " make sure you have the latest version of `bitsandbytes` installed"
            )

    # 调整最大内存配置以供量化期间使用
    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # 将最大内存配置按 90% 缩放,以便在量化期间创建的缓冲区有足够的空间
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory
    # 更新 Torch 张量数据类型为指定的 `torch.dtype`
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        # 如果输入的 torch_dtype 为 None,则强制设置为 float16,这是 `bitsandbytes` 的要求
        logger.info(
            "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
            "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
            "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
            " torch_dtype=torch.float16 to remove this warning.",
            torch_dtype,
        )
        torch_dtype = torch.float16
        return torch_dtype

    # 更新设备映射表,确保 device_map 不为 None
    def update_device_map(self, device_map):
        # 如果 device_map 为 None,则设置为当前 CUDA 设备的空映射
        if device_map is None:
            device_map = {"": torch.cuda.current_device()}
            logger.info(
                "The device_map was not initialized. "
                "Setting device_map to {'':torch.cuda.current_device()}. "
                "If you want to use the model for inference, please set device_map ='auto' "
            )
        return device_map

    # 调整目标数据类型为 torch.int8
    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        # 如果目标数据类型不是 torch.int8,则替换为 torch.int8,用于 8-bit BnB 量化
        if target_dtype != torch.int8:
            logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
        return torch.int8

    # 检查是否为量化参数
    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        import bitsandbytes as bnb

        # 获取模型和参数名称对应的模块
        module, tensor_name = get_module_from_name(model, param_name)
        # 检查参数是否为 Int8Params 类型(来自 bitsandbytes 库)
        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
            # 如果预先量化已启用
            if self.pre_quantized:
                # 如果参数名中不包含 `weight` 的替代项 `SCB`,则抛出异常
                if param_name.replace("weight", "SCB") not in state_dict.keys():
                    raise ValueError("Missing quantization component `SCB`")
                # 如果参数值的数据类型不是 torch.int8,则抛出异常
                if param_value.dtype != torch.int8:
                    raise ValueError(
                        f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`."
                    )
            return True
        return False

    # 创建量化参数
    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        # 此方法的具体实现在此省略,需要根据功能进一步补充
        """
        组合来自 _load_state_dict_into_meta_model 和 .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device() 的逻辑
        需要从状态字典中获取辅助项,如果找到的话,将其从 unexpected_keys 中移除
        """
        # 导入 bitsandbytes 库作为 bnb
        import bitsandbytes as bnb

        # 根据 param_name 构造 fp16 统计数据的键名
        fp16_statistics_key = param_name.replace("weight", "SCB")
        # 从 state_dict 中获取 fp16 统计数据
        fp16_statistics = state_dict.get(fp16_statistics_key, None)

        # 根据 param_name 获取模型中的模块和张量名
        module, tensor_name = get_module_from_name(model, param_name)
        # 检查张量名是否存在于模块的参数中
        if tensor_name not in module._parameters:
            raise ValueError(f"{module} 没有名为 {tensor_name} 的参数或缓冲区.")

        # 获取旧值
        old_value = getattr(module, tensor_name)

        # 检查模块的参数类型是否为 bnb.nn.Int8Params
        if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
            raise ValueError(f"参数 `{tensor_name}` 应该是 `bnb.nn.Int8Params` 的实例.")

        # 检查旧值的设备是否为 "meta",并且目标设备不是 "meta" 或 torch.device("meta"),且 param_value 为 None
        if (
            old_value.device == torch.device("meta")
            and target_device not in ["meta", torch.device("meta")]
            and param_value is None
        ):
            raise ValueError(f"{tensor_name} 在 meta 设备上,需要在 {target_device} 上放置一个 `value`.")

        # 将 param_value 转移到 CPU
        new_value = param_value.to("cpu")

        # 如果 self.pre_quantized 为真且 self.is_serializable 为假,则抛出异常
        if self.pre_quantized and not self.is_serializable:
            raise ValueError(
                "检测到 int8 权重,但 bitsandbytes 的版本不兼容 int8 序列化。"
                "请确保下载最新的 `bitsandbytes` 版本。`pip install --upgrade bitsandbytes`."
            )

        # 如果模块的源类是 Conv1D,则在量化之前对权重矩阵进行转置
        if issubclass(module.source_cls, Conv1D):
            if fp16_statistics is None:
                new_value = new_value.T

        # 将旧值的关键字参数赋给 kwargs
        kwargs = old_value.__dict__
        # 使用 bitsandbytes 创建一个新的 Int8Params 实例并将其移至目标设备
        new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)

        # 更新模块中的参数值
        module._parameters[tensor_name] = new_value

        # 如果存在 fp16_statistics,则将其设置为新值的 SCB 属性,并从 unexpected_keys 中移除 fp16_statistics_key
        if fp16_statistics is not None:
            setattr(module.weight, "SCB", fp16_statistics.to(target_device))
            if unexpected_keys is not None:
                unexpected_keys.remove(fp16_statistics_key)

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        # 将模型标记为已加载 8 位
        model.is_loaded_in_8bit = True
        # 设置模型的 8 位序列化属性为当前对象的 is_serializable 属性
        model.is_8bit_serializable = self.is_serializable
        return model

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
        # 从模块导入必要的函数和类,用于不转换的模块获取和替换操作

        load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
        # 从量化配置中获取是否启用在 8 位情况下的 FP32 CPU 卸载加载

        # 由于数值稳定性原因,保持某些模块(如 lm_head)在其原始 dtype 下不转换
        if self.quantization_config.llm_int8_skip_modules is None:
            self.modules_to_not_convert = get_keys_to_not_convert(model)
        else:
            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
        # 根据配置决定不转换的模块列表,如果未指定,则从模型中获取

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]
        # 如果不转换的模块不是列表类型,则转换为列表

        self.modules_to_not_convert.extend(keep_in_fp32_modules)
        # 将需要保持在 FP32 的模块列表扩展到不转换的模块列表中

        # 将需要卸载到 CPU 或磁盘的键扩展到 `self.modules_to_not_convert`
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]

            if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
                raise ValueError(
                    "If you want to offload some keys to `cpu` or `disk`, you need to set "
                    "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
                    " converted to 8-bit but kept in 32-bit."
                )
            self.modules_to_not_convert.extend(keys_on_cpu)
        # 如果设备映射是字典类型且键数大于1,则根据设备映射中的值为 "disk" 或 "cpu" 的键添加到不转换的模块列表中

        model = replace_with_bnb_linear(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        # 使用指定的替换函数替换模型中的部分模块,传入不转换的模块列表和量化配置

        # TODO: 考虑将 `replace_with_bnb_linear()` 函数从 ..integrations/bitsandbyter.py 文件中移到这里

        model.config.quantization_config = self.quantization_config
        # 设置模型配置的量化配置属性

    @property
    def is_serializable(self):
        _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
            "0.37.2"
        )
        # 检查当前安装的 bitsandbytes 版本是否支持 8 位序列化

        if not _bnb_supports_8bit_serialization:
            logger.warning(
                "You are calling `save_pretrained` to a 8-bit converted model, but your `bitsandbytes` version doesn't support it. "
                "If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed. You will most likely face errors or"
                " unexpected behaviours."
            )
            return False
        # 如果不支持 8 位序列化,则发出警告并返回 False

        return True
        # 如果支持 8 位序列化,则返回 True

    @property
    def is_trainable(self) -> bool:
        return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
    # 检查当前安装的 bitsandbytes 版本是否支持模型训练

.\quantizers\quantizer_gptq.py

# 导入必要的模块和函数
import importlib  # 导入 importlib 模块,用于动态导入
from typing import TYPE_CHECKING, Optional  # 导入 TYPE_CHECKING 和 Optional 类型提示

# 导入版本比较模块
from packaging import version

# 导入基础的 HfQuantizer 类
from .base import HfQuantizer  

# 如果是类型检查环境,则导入 PreTrainedModel 类
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel  

# 导入一些辅助函数和模块,例如自动量化、最优设置、Torch 是否可用以及日志记录
from ..utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging
from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin

# 如果 Torch 可用,则导入 Torch 模块
if is_torch_available():
    import torch

# 获取日志记录器对象
logger = logging.get_logger(__name__)

# 定义 GptqHfQuantizer 类,继承自 HfQuantizer 类
class GptqHfQuantizer(HfQuantizer):
    """
    GPTQ 方法的量化器 - 通过 `auto_gptq` 包支持模型的校准。如果用户加载未预量化的模型,则在幕后进行量化。
    """

    # 是否需要校准的标志,这里不需要校准
    requires_calibration = False  

    # 所需的包列表
    required_packages = ["optimum", "auto_gptq"]  

    # 最优量化器对象,初始化为 None
    optimum_quantizer = None  

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        
        # 动态导入 GPTQQuantizer 类
        from optimum.gptq import GPTQQuantizer  

        # 使用配置信息初始化最优量化器
        self.optimum_quantizer = GPTQQuantizer.from_dict(self.quantization_config.to_dict_optimum())

    def validate_environment(self, *args, **kwargs):
        # 检查 auto-gptq 的版本是否支持 CPU
        gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
        
        # 如果 auto-gptq 不支持 CPU 并且没有可用的 GPU,则抛出运行时错误
        if not gptq_supports_cpu and not torch.cuda.is_available():
            raise RuntimeError("GPU is required to quantize or run quantize model.")
        
        # 如果 optimum 和 auto-gptq 包不可用,则抛出导入错误
        elif not (is_optimum_available() and is_auto_gptq_available()):
            raise ImportError(
                "Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)"
            )
        
        # 如果 auto-gptq 的版本低于 0.4.2,则抛出导入错误
        elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
            raise ImportError(
                "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`"
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        # 如果 torch_dtype 为 None,则设置为 torch.float16
        if torch_dtype is None:
            torch_dtype = torch.float16
        
        # 如果 torch_dtype 不是 torch.float16,则建议设置为 torch.float16 以提高效率
        elif torch_dtype != torch.float16:
            logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
        
        # 返回更新后的 torch_dtype
        return torch_dtype
    # 处理模型在加载权重前的预处理操作
    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        # 检查模型主输入名称是否为 "input_ids",若不是则抛出运行时错误
        if model.__class__.main_input_name != "input_ids":
            raise RuntimeError("We can only quantize pure text model.")
        
        # 如果模型已经预量化,则使用最优量化器转换模型
        if self.pre_quantized:
            model = self.optimum_quantizer.convert_model(model)

    # 处理模型在加载权重后的后处理操作
    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        # 如果模型已经预量化,则使用最优量化器进行模型后初始化处理
        if self.pre_quantized:
            model = self.optimum_quantizer.post_init_model(model)
        else:
            # 如果未预量化且未设置量化配置的分词器,则使用模型的名称或路径作为量化配置的分词器
            if self.quantization_config.tokenizer is None:
                self.quantization_config.tokenizer = model.name_or_path
            
            # 使用最优量化器对模型进行量化,使用给定的分词器
            self.optimum_quantizer.quantize_model(model, self.quantization_config.tokenizer)
            # 将模型的配置信息更新为从最优量化器导出的量化配置
            model.config.quantization_config = GPTQConfig.from_dict(self.optimum_quantizer.to_dict())

    # 检查模型是否可训练的属性,始终返回 True
    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return True

    # 检查模型是否可序列化的属性,始终返回 True
    @property
    def is_serializable(self):
        return True

.\quantizers\quantizer_quanto.py

# 导入模块 importlib,用于动态导入模块
import importlib
# 导入类型检查标记 TYPE_CHECKING、Any、Dict、List、Optional、Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

# 导入版本比较模块 version,来自 packaging 包
from packaging import version

# 从当前包中导入 base 模块中的 HfQuantizer 类
from .base import HfQuantizer
# 从 quantizers_utils 模块中导入 get_module_from_name 函数
from .quantizers_utils import get_module_from_name

# 如果是类型检查状态,从 modeling_utils 模块中导入 PreTrainedModel 类
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

# 从 utils 模块中导入 is_accelerate_available、is_quanto_available、is_torch_available、logging 函数和类
from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging
# 从 utils.quantization_config 模块中导入 QuantoConfig 类
from ..utils.quantization_config import QuantoConfig

# 如果 torch 可用,导入 torch 模块
if is_torch_available():
    import torch

# 从 logging 模块中获取 logger 对象
logger = logging.get_logger(__name__)

# 定义 QuantoHfQuantizer 类,继承自 HfQuantizer 类
class QuantoHfQuantizer(HfQuantizer):
    """
    Quantizer for the quanto library
    """

    # 定义 required_packages 列表,指明需要的依赖包
    required_packages = ["quanto", "accelerate"]
    # 指明是否需要参数量化
    requires_parameters_quantization = True
    # 指明是否需要校准
    requires_calibration = False

    # 初始化方法,接收 quantization_config 参数和其他关键字参数
    def __init__(self, quantization_config: QuantoConfig, **kwargs):
        # 调用父类 HfQuantizer 的初始化方法
        super().__init__(quantization_config, **kwargs)
        # 调用 post_init 方法
        self.post_init()

    # 定义 post_init 方法,用于安全检查
    def post_init(self):
        # 如果 quantization_config.activations 不为空且未预量化
        if self.quantization_config.activations is not None and not self.pre_quantized:
            # 抛出值错误异常,提示不支持对激活进行量化
            raise ValueError(
                "We don't support quantizing the activations with transformers library."
                "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
            )

    # 定义 validate_environment 方法,用于验证环境是否支持 quanto 库和 accelerate 库
    def validate_environment(self, *args, **kwargs):
        # 如果 quanto 库不可用,抛出导入错误异常
        if not is_quanto_available():
            raise ImportError("Loading a quanto quantized model requires quanto library (`pip install quanto`)")
        # 如果 accelerate 库不可用,抛出导入错误异常
        if not is_accelerate_available():
            raise ImportError("Loading a quanto quantized model requires accelerate library (`pip install quanto`)")

    # 定义 update_device_map 方法,用于更新设备映射
    def update_device_map(self, device_map):
        # 如果 device_map 为 None,则初始化为 {'': 'cpu'}
        if device_map is None:
            device_map = {"": "cpu"}
            # 记录日志信息,提示设备映射未初始化,将其设置为 {'': 'cpu'}
            logger.info(
                "The device_map was not initialized. "
                "Setting device_map to {'':'cpu'}. "
                "If you want to use the model for inference, please set device_map ='auto'"
            )
        # 返回更新后的 device_map
        return device_map

    # 定义 update_torch_dtype 方法,用于更新 torch 数据类型
    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        # 如果 torch_dtype 为 None
        if torch_dtype is None:
            # 记录日志信息,提示在 from_pretrained 中未指定 torch_dtype,默认设置为 torch.float32
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            # 将 torch_dtype 设置为 torch.float32
            torch_dtype = torch.float32
        # 返回更新后的 torch_dtype
        return torch_dtype
    # 更新模型中缺失的键列表,返回更新后未缺失的键列表
    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        import quanto  # 导入quanto模块

        not_missing_keys = []  # 初始化未缺失的键列表
        # 遍历模型中的命名模块
        for name, module in model.named_modules():
            # 如果模块是quanto.QModuleMixin的实例
            if isinstance(module, quanto.QModuleMixin):
                # 遍历缺失的键列表
                for missing in missing_keys:
                    # 如果模块名称在缺失键中或者在以prefix开头的缺失键中
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")  # 排除以.weight结尾的键
                        and not missing.endswith(".bias")    # 排除以.bias结尾的键
                    ):
                        not_missing_keys.append(missing)  # 将该键添加到未缺失的键列表中
        # 返回更新后的未缺失的键列表
        return [k for k in missing_keys if k not in not_missing_keys]

    # 检查是否需要量化参数
    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Check if a parameter needs to be quantized.
        """
        import quanto  # 导入quanto模块

        device_map = kwargs.get("device_map", None)  # 获取device_map参数
        param_device = kwargs.get("param_device", None)  # 获取param_device参数
        # 如果模块将要被离线到cpu上,则不进行模型量化
        if device_map is not None and param_device is not None:
            device_map_values = set(device_map.values())  # 获取device_map的所有值集合
            if param_device == "cpu" and len(device_map_values) > 1:
                if not (device_map_values == {"cpu"} or device_map_values == {"cpu", "disk"}):
                    return False  # 如果条件满足,返回False,不进行模型量化

        module, tensor_name = get_module_from_name(model, param_name)  # 获取参数所在的模块和张量名
        # 只量化权重,不量化偏置
        if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name:
            # 如果权重已经量化,不需要使用`create_quantized_param`重新创建
            return not module.frozen
        else:
            return False  # 其他情况下不进行模型量化

    # 调整最大内存限制
    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # 将最大内存限制减少10%
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory  # 返回调整后的最大内存限制字典

    # 创建量化参数
    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .freeze() after setting it to the module.
        """
        from accelerate.utils import set_module_tensor_to_device  # 从accelerate.utils模块导入set_module_tensor_to_device函数

        set_module_tensor_to_device(model, param_name, target_device, param_value)  # 将参数设置到模块并移动到目标设备
        module, _ = get_module_from_name(model, param_name)  # 获取参数所在的模块
        module.freeze()  # 冻结模块,使其无法再修改
        module.weight.requires_grad = False  # 设置权重张量不需要梯度计算
    # 调整目标数据类型以匹配加速库版本的需求
    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        # 检查当前加速库版本是否大于 0.27.0
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.27.0"):
            # 导入加速库中的自定义数据类型
            from accelerate.utils import CustomDtype

            # 定义数据类型映射关系
            mapping = {
                "int8": torch.int8,
                "float8": CustomDtype.FP8,
                "int4": CustomDtype.INT4,
                "int2": CustomDtype.INT2,
            }
            # 根据量化配置中的权重类型选择目标数据类型
            target_dtype = mapping[self.quantization_config.weights]
            return target_dtype
        else:
            # 抛出数值错误,提示升级加速库版本
            raise ValueError(
                "You are using `device_map='auto'` on a quanto quantized model. To automatically compute"
                " the appropriate device map, you should upgrade your `accelerate` library,"
                "`pip install --upgrade accelerate` or install it from source."
            )

    # 在加载权重前处理模型
    def _process_model_before_weight_loading(
        self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
    ):
        # 导入必要的函数以及类
        from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers

        # 如果未设置不转换的模块列表,则根据模型获取不转换模块的键
        if self.quantization_config.modules_to_not_convert is None:
            self.modules_to_not_convert = get_keys_to_not_convert(model)
        else:
            self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        # 确保不转换的模块列表为一个列表类型
        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        # 将需要保持在 FP32 精度的模块添加到不转换的模块列表中
        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        # 使用自定义函数替换量化层并更新模型
        model, _ = replace_with_quanto_layers(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        # 更新模型的量化配置
        model.config.quantization_config = self.quantization_config

    # 在加载完权重后处理模型
    def _process_model_after_weight_loading(self, model):
        return model

    # 返回模型是否可训练的属性
    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    # 返回模型是否可序列化的属性
    @property
    def is_serializable(self):
        return False

.\quantizers\__init__.py

# 导入自动量化相关模块
from .auto import AutoHfQuantizer, AutoQuantizationConfig
# 导入基础量化器模块
from .base import HfQuantizer

.\safetensors_conversion.py

import json  # 导入json模块,用于处理JSON格式数据
import uuid  # 导入uuid模块,用于生成唯一标识符
from typing import Optional  # 导入Optional类型,用于可选的类型声明

import requests  # 导入requests模块,用于发送HTTP请求
from huggingface_hub import Discussion, HfApi, get_repo_discussions  # 导入huggingface_hub相关函数和类

from .utils import cached_file, logging  # 从当前包中导入cached_file和logging模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象


def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
    # 获取主提交的commit_id
    main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
    # 遍历当前模型repo中的所有讨论
    for discussion in get_repo_discussions(repo_id=model_id, token=token):
        # 判断讨论是否为打开的PR并且标题为pr_title
        if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
            # 获取与讨论相关的提交信息
            commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)

            # 检查主提交是否与PR的第二个提交相同
            if main_commit == commits[1].commit_id:
                return discussion  # 如果条件符合,返回此讨论对象
    return None  # 如果未找到符合条件的讨论,返回None


def spawn_conversion(token: str, private: bool, model_id: str):
    logger.info("Attempting to convert .bin model on the fly to safetensors.")

    safetensors_convert_space_url = "https://safetensors-convert.hf.space"
    sse_url = f"{safetensors_convert_space_url}/queue/join"
    sse_data_url = f"{safetensors_convert_space_url}/queue/data"

    # 指定fn_index以指示使用Space的run方法
    hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())}

    def start(_sse_connection, payload):
        # 迭代SSE连接的每一行数据
        for line in _sse_connection.iter_lines():
            line = line.decode()
            if line.startswith("data:"):
                resp = json.loads(line[5:])  # 解析收到的JSON数据
                logger.debug(f"Safetensors conversion status: {resp['msg']}")
                # 处理不同的转换状态
                if resp["msg"] == "queue_full":
                    raise ValueError("Queue is full! Please try again.")
                elif resp["msg"] == "send_data":
                    event_id = resp["event_id"]
                    # 发送数据到sse_data_url
                    response = requests.post(
                        sse_data_url,
                        stream=True,
                        params=hash_data,
                        json={"event_id": event_id, **payload, **hash_data},
                    )
                    response.raise_for_status()  # 检查响应状态
                elif resp["msg"] == "process_completed":
                    return  # 如果转换完成,结束函数

    with requests.get(sse_url, stream=True, params=hash_data) as sse_connection:
        data = {"data": [model_id, private, token]}
        try:
            logger.debug("Spawning safetensors automatic conversion.")
            start(sse_connection, data)  # 调用start函数开始转换
        except Exception as e:
            logger.warning(f"Error during conversion: {repr(e)}")  # 处理转换过程中的异常情况


def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
    private = api.model_info(model_id).private  # 获取模型信息中的private字段值

    logger.info("Attempting to create safetensors variant")
    pr_title = "Adding `safetensors` variant of this model"
    token = kwargs.get("token")

    # 这段代码查找当前repo中是否有关于safetensors的已打开的PR
    # 调用函数 `previous_pr`,获取先前创建的 pull request 对象
    pr = previous_pr(api, model_id, pr_title, token=token)

    # 如果 pr 为 None 或者(不是私有且 pr 的作者不是 "SFConvertBot"),则执行以下操作:
    if pr is None or (not private and pr.author != "SFConvertBot"):
        # 调用函数 `spawn_conversion`,启动转换过程
        spawn_conversion(token, private, model_id)
        # 再次获取先前创建的 pull request 对象
        pr = previous_pr(api, model_id, pr_title, token=token)
    else:
        # 记录日志,指示安全张量的 pull request 已存在
        logger.info("Safetensors PR exists")

    # 构建 SHA 引用,格式为 "refs/pr/{pr.num}"
    sha = f"refs/pr/{pr.num}"

    # 返回 SHA 引用
    return sha
# 自动转换函数,根据预训练模型名称或路径以及其他缓存文件参数来执行自动转换
def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
    # 使用给定的 token 创建 Hugging Face API 的实例
    api = HfApi(token=cached_file_kwargs.get("token"))
    
    # 获取转换 Pull Request 的参考 SHA 值
    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)

    # 如果没有找到 SHA 值,则返回 None
    if sha is None:
        return None, None
    
    # 将 SHA 值添加到缓存文件参数中的 revision 键中
    cached_file_kwargs["revision"] = sha
    
    # 从缓存文件参数中删除 _commit_hash 键
    del cached_file_kwargs["_commit_hash"]

    # 这是一个额外的 HEAD 调用,如果能从 PR 描述中推断出分片/非分片,可以删除这个调用
    # 检查指定的模型是否存在分片的 "model.safetensors.index.json" 文件
    sharded = api.file_exists(
        pretrained_model_name_or_path,
        "model.safetensors.index.json",
        revision=sha,
        token=cached_file_kwargs.get("token"),
    )
    
    # 根据是否存在分片文件,选择相应的文件名
    filename = "model.safetensors.index.json" if sharded else "model.safetensors"

    # 缓存解析后的归档文件
    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
    
    # 返回解析后的归档文件路径、SHA 值和是否分片的标志
    return resolved_archive_file, sha, sharded

.\sagemaker\trainer_sm.py

# 导入警告模块,用于在特定情况下发出警告
import warnings

# 从上级目录中导入 Trainer 类
from ..trainer import Trainer

# 从上级目录中的 utils 模块中导入 logging 工具
from ..utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义 SageMakerTrainer 类,继承自 Trainer 类
class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        # 发出警告,提示用户 SageMakerTrainer 类将在 Transformers v5 版本中被移除,建议使用 Trainer 类
        warnings.warn(
            "`SageMakerTrainer` is deprecated and will be removed in v5 of Transformers. You can use `Trainer` "
            "instead.",
            FutureWarning,
        )
        # 调用父类 Trainer 的初始化方法,传递参数 args 和其他关键字参数
        super().__init__(args=args, **kwargs)

.\sagemaker\training_args_sm.py

# 导入必要的模块和库
import importlib.util  # 导入用于动态加载模块的模块
import json  # 导入处理 JSON 数据的模块
import os  # 导入与操作系统交互的模块
import warnings  # 导入用于处理警告的模块
from dataclasses import dataclass, field  # 导入用于创建数据类的装饰器和字段定义

import torch  # 导入 PyTorch 库

from ..training_args import TrainingArguments  # 从上级目录中导入训练参数类
from ..utils import cached_property, is_sagemaker_dp_enabled, logging  # 从上级目录中导入缓存属性装饰器、SageMaker DP 启用状态检查函数和日志模块

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器对象


# TODO: 在 SageMakerTrainer 重构后应移动到 `utils` 模块中


def is_sagemaker_model_parallel_available():
    # 从环境变量中获取 SageMaker 的模型并行参数
    smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}")
    try:
        # 解析 JSON 数据并检查是否包含 "partitions" 字段,模型并行需要此字段
        smp_options = json.loads(smp_options)
        if "partitions" not in smp_options:
            return False
    except json.JSONDecodeError:
        return False

    # 从环境变量中获取 SageMaker 的框架参数
    mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        # 解析 JSON 数据并检查是否包含 "sagemaker_mpi_enabled" 字段
        mpi_options = json.loads(mpi_options)
        if not mpi_options.get("sagemaker_mpi_enabled", False):
            return False
    except json.JSONDecodeError:
        return False

    # 最后,检查是否存在 `smdistributed` 模块,以确认 SageMaker 是否支持模型并行
    return importlib.util.find_spec("smdistributed") is not None


# 如果 SageMaker 支持模型并行,则导入相应的模型并行库并进行初始化
if is_sagemaker_model_parallel_available():
    import smdistributed.modelparallel.torch as smp  # 导入 SageMaker 模型并行的 Torch 扩展库

    smp.init()  # 初始化 SageMaker 模型并行


@dataclass
class SageMakerTrainingArguments(TrainingArguments):
    mp_parameters: str = field(
        default="",
        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in SageMakerTrainer"},
    )

    def __post_init__(self):
        super().__post_init__()
        # 发出警告,提示 `SageMakerTrainingArguments` 将在 Transformers v5 中被移除,建议使用 `TrainingArguments` 替代
        warnings.warn(
            "`SageMakerTrainingArguments` is deprecated and will be removed in v5 of Transformers. You can use "
            "`TrainingArguments` instead.",
            FutureWarning,
        )

    @cached_property
    # 设置设备
    def _setup_devices(self) -> "torch.device":
        # 打印日志信息
        logger.info("PyTorch: setting up devices")
        # 检查是否启用了torch分布式,并且本地进程的local_rank为-1
        if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
            # 打印警告信息
            logger.warning(
                "torch.distributed process group is initialized, but local_rank == -1. "
                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
            )
        # 如果禁用了CUDA
        if self.no_cuda:
            # 将设备设置为CPU
            device = torch.device("cpu")
            # GPU数量设为0
            self._n_gpu = 0
        # 如果支持SageMaker模型并行
        elif is_sagemaker_model_parallel_available():
            local_rank = smp.local_rank()
            device = torch.device("cuda", local_rank)
            # GPU数量设为1
            self._n_gpu = 1
        # 如果启用了SageMaker分布式训练
        elif is_sagemaker_dp_enabled():
            # 导入SageMaker分布式训练模块
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
            # 初始化进程组
            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1
        # 如果local_rank为-1
        elif self.local_rank == -1:
            # 如果n_gpu大于1,将使用nn.DataParallel。
            # 如果只想使用指定的GPU子集,可以使用`CUDA_VISIBLE_DEVICES=0`
            # 显式设置CUDA到第一个(索引0)CUDA设备,否则`set_device`会触发缺少设备索引的错误。
            # 索引0考虑了环境中可用的GPU,因此`CUDA_VISIBLE_DEVICES=1,2`与`cuda:0`将使用该环境中的第一个GPU,即GPU#1
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            # 有时在此之前尚未运行postinit中的行,因此只需检查我们不是默认值。
            self._n_gpu = torch.cuda.device_count()
        else:
            # 在这里,我们将使用torch分布式。
            # 初始化分布式后端,负责同步节点/GPU
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        # 如果设备类型为cuda
        if device.type == "cuda":
            # 设置当前使用的设备
            torch.cuda.set_device(device)

        # 返回设备
        return device

    @property
    # 获取world_size属性
    def world_size(self):
        # 如果支持SageMaker模型并行
        if is_sagemaker_model_parallel_available():
            # 返回并行大小
            return smp.dp_size()

        # 返回基类的world_size
        return super().world_size

    @property
    # 获取place_model_on_device属性
    def place_model_on_device(self):
        # 如果不支持SageMaker模型并行
        return not is_sagemaker_model_parallel_available()

    @property
    # 获取_no_sync_in_gradient_accumulation属性
    def _no_sync_in_gradient_accumulation(self):
        return False

.\sagemaker\__init__.py

# 导入 SageMakerTrainer 类从 trainer_sm 模块中
from .trainer_sm import SageMakerTrainer
# 导入 SageMakerTrainingArguments 和 is_sagemaker_dp_enabled 从 training_args_sm 模块中
from .training_args_sm import SageMakerTrainingArguments, is_sagemaker_dp_enabled

.\testing_utils.py

# 导入必要的标准库和第三方库
import collections  # 提供额外的数据容器,如deque(双端队列)
import contextlib  # 提供用于管理上下文的工具
import doctest  # 提供用于运行文档测试的模块
import functools  # 提供函数式编程的工具,如partial函数应用
import importlib  # 提供用于动态加载模块的工具
import inspect  # 提供用于检查源代码的工具
import logging  # 提供用于记录日志消息的功能
import multiprocessing  # 提供用于多进程编程的工具
import os  # 提供与操作系统交互的功能
import re  # 提供支持正则表达式的工具
import shlex  # 提供用于解析和操作命令行字符串的工具
import shutil  # 提供高级文件操作功能的工具
import subprocess  # 提供用于创建子进程的功能
import sys  # 提供与Python解释器交互的功能
import tempfile  # 提供创建临时文件和目录的功能
import time  # 提供时间相关的功能
import unittest  # 提供用于编写和运行单元测试的工具
from collections import defaultdict  # 提供默认字典的功能
from collections.abc import Mapping  # 提供抽象基类,用于检查映射类型
from functools import wraps  # 提供用于创建装饰器的工具
from io import StringIO  # 提供内存中文本I/O的工具
from pathlib import Path  # 提供面向对象的路径操作功能
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union  # 提供类型提示支持
from unittest import mock  # 提供用于模拟测试的工具
from unittest.mock import patch  # 提供用于模拟测试的工具

import urllib3  # 提供HTTP客户端的功能

from transformers import logging as transformers_logging  # 导入transformers库中的logging模块

from .integrations import (  # 导入自定义模块中的一系列集成检查函数
    is_clearml_available,
    is_optuna_available,
    is_ray_available,
    is_sigopt_available,
    is_tensorboard_available,
    is_wandb_available,
)
from .integrations.deepspeed import is_deepspeed_available  # 导入自定义模块中的深度加速集成检查函数
from .utils import (  # 导入自定义模块中的一系列实用工具检查函数
    is_accelerate_available,
    is_apex_available,
    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_bs4_available,
    is_cv2_available,
    is_cython_available,
    is_decord_available,
    is_detectron2_available,
    is_essentia_available,
    is_faiss_available,
    is_flash_attn_2_available,
    is_flax_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_galore_torch_available,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
    is_jumanpp_available,
    is_keras_nlp_available,
    is_levenshtein_available,
    is_librosa_available,
    is_natten_available,
    is_nltk_available,
    is_onnx_available,
    is_optimum_available,
    is_pandas_available,
    is_peft_available,
    is_phonemizer_available,
    is_pretty_midi_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quanto_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
    is_soundfile_availble,
    is_spacy_available,
    is_sudachi_available,
    is_sudachi_projection_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_timm_available,
    is_tokenizers_available,
    is_torch_available,
)
    # 检查当前设备是否支持 Torch 的 BF16 数据类型
    is_torch_bf16_available_on_device,
    # 检查当前 CPU 是否支持 Torch 的 BF16 数据类型
    is_torch_bf16_cpu_available,
    # 检查当前 GPU 是否支持 Torch 的 BF16 数据类型
    is_torch_bf16_gpu_available,
    # 检查当前设备是否支持 Torch 的 FP16 数据类型
    is_torch_fp16_available_on_device,
    # 检查当前设备是否支持 Torch 的 NeuronCore 加速器
    is_torch_neuroncore_available,
    # 检查当前设备是否支持 Torch 的 NPU 加速器
    is_torch_npu_available,
    # 检查当前设备是否支持 Torch 的 SDPA 加速器
    is_torch_sdpa_available,
    # 检查当前设备是否支持 Torch 的 TensorRT FX 加速器
    is_torch_tensorrt_fx_available,
    # 检查当前设备是否支持 Torch 的 TF32 数据类型
    is_torch_tf32_available,
    # 检查当前设备是否支持 Torch 的 XLA 加速器
    is_torch_xla_available,
    # 检查当前设备是否支持 Torch 的 XPU 加速器
    is_torch_xpu_available,
    # 检查当前环境是否支持 Torch Audio 库
    is_torchaudio_available,
    # 检查当前环境是否支持 TorchDynamo 库
    is_torchdynamo_available,
    # 检查当前环境是否支持 TorchVision 库
    is_torchvision_available,
    # 检查当前环境是否支持 Torch 的 Vision 扩展
    is_vision_available,
    # 将字符串转换为布尔值(支持"true", "false", "yes", "no", "1", "0"等)
    strtobool,
# 如果加速功能可用,则从 accelerate.state 中导入 AcceleratorState 和 PartialState 类
if is_accelerate_available():
    from accelerate.state import AcceleratorState, PartialState


# 如果 pytest 可用,则从 _pytest.doctest 中导入以下模块
# Module: 用于表示 Python 模块的类
# _get_checker: 获取 doctest 的检查器
# _get_continue_on_failure: 获取 doctest 的继续失败选项
# _get_runner: 获取 doctest 的运行器
# _is_mocked: 检查是否模拟了对象
# _patch_unwrap_mock_aware: 解除 Mock 对象感知的补丁
# get_optionflags: 获取 doctest 的选项标志
from _pytest.doctest import (
    Module,
    _get_checker,
    _get_continue_on_failure,
    _get_runner,
    _is_mocked,
    _patch_unwrap_mock_aware,
    get_optionflags,
)

# 如果 pytest 不可用,则将 Module 和 DoctestItem 设置为 object 类型
else:
    Module = object
    DoctestItem = object


# 定义了一个小型模型的标识符字符串
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"

# 用于测试自动检测模型类型的标识符
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"

# 用于测试 Hub 的用户和端点
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"

# 仅在受控的 CI 实例中可用,用于测试用的令牌
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"


# 从环境变量中解析布尔类型的标志
def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # 如果 KEY 未设置,则使用默认值 `default`
        _value = default
    else:
        # 如果 KEY 已设置,则尝试将其转换为 True 或 False
        try:
            _value = strtobool(value)
        except ValueError:
            # 如果值不是 `yes` 或 `no`,则抛出异常
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value


# 从环境变量中解析整数类型的值
def parse_int_from_env(key, default=None):
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            # 如果值不是整数,则抛出异常
            raise ValueError(f"If set, {key} must be a int.")
    return _value


# 根据环境变量 `RUN_SLOW` 解析是否运行慢速测试的标志
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
# 根据环境变量 `RUN_PT_TF_CROSS_TESTS` 解析是否运行 PyTorch 和 TensorFlow 交叉测试的标志
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=True)
# 根据环境变量 `RUN_PT_FLAX_CROSS_TESTS` 解析是否运行 PyTorch 和 Flax 交叉测试的标志
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
# 根据环境变量 `RUN_CUSTOM_TOKENIZERS` 解析是否运行自定义分词器测试的标志
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
# 根据环境变量 `HUGGINGFACE_CO_STAGING` 解析是否运行在 Hugging Face CO 预发布环境中的标志
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
# 根据环境变量 `TF_GPU_MEMORY_LIMIT` 解析 TensorFlow GPU 内存限制的值
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
# 根据环境变量 `RUN_PIPELINE_TESTS` 解析是否运行管道测试的标志
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
# 根据环境变量 `RUN_TOOL_TESTS` 解析是否运行工具测试的标志
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
# 根据环境变量 `RUN_THIRD_PARTY_DEVICE_TESTS` 解析是否运行第三方设备测试的标志
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)


# 函数装饰器,用于标记 PT+TF 交叉测试
def is_pt_tf_cross_test(test_case):
    """
    Decorator marking a test as a test that control interactions between PyTorch and TensorFlow.

    PT+TF tests are skipped by default and we can run only them by setting RUN_PT_TF_CROSS_TESTS environment variable
    to a truthy value and selecting the is_pt_tf_cross_test pytest mark.

    """
    # 如果未设置环境变量 `RUN_PT_TF_CROSS_TESTS` 或者当前环境中没有安装 PyTorch 或 TensorFlow,
    # 则跳过 PT+TF 测试
    if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
        return unittest.skip("test is PT+TF test")(test_case)
    else:
        # 尝试导入 pytest 模块,避免在主库中硬编码依赖 pytest
        try:
            import pytest  
        # 如果导入失败,返回原始的 test_case
        except ImportError:
            return test_case
        # 如果导入成功,应用 pytest.mark.is_pt_tf_cross_test() 装饰器到 test_case 上
        else:
            return pytest.mark.is_pt_tf_cross_test()(test_case)
# 标记一个测试用例为控制 PyTorch 和 Flax 交互的测试的装饰器

PT+FLAX 测试默认情况下会被跳过,只有当设置了环境变量 RUN_PT_FLAX_CROSS_TESTS 为真值并且选择了 is_pt_flax_cross_test pytest 标记时才会运行。

def is_pt_flax_cross_test(test_case):
    if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
        # 如果不满足运行条件(未设置环境变量或者没有可用的 PyTorch 或 Flax),则跳过测试
        return unittest.skip("test is PT+FLAX test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_pt_flax_cross_test 标记来标记测试用例
            return pytest.mark.is_pt_flax_cross_test()(test_case)


# 标记一个测试用例为在 staging 环境下运行的测试的装饰器

这些测试将在 huggingface.co 的 staging 环境下运行,而不是真实的模型中心。

def is_staging_test(test_case):
    if not _run_staging:
        # 如果不运行 staging 测试,则跳过测试
        return unittest.skip("test is staging test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_staging_test 标记来标记测试用例
            return pytest.mark.is_staging_test()(test_case)


# 标记一个测试用例为 pipeline 测试的装饰器

如果未将 RUN_PIPELINE_TESTS 设置为真值,则这些测试将被跳过。

def is_pipeline_test(test_case):
    if not _run_pipeline_tests:
        # 如果不运行 pipeline 测试,则跳过测试
        return unittest.skip("test is pipeline test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_pipeline_test 标记来标记测试用例
            return pytest.mark.is_pipeline_test()(test_case)


# 标记一个测试用例为工具测试的装饰器

如果未将 RUN_TOOL_TESTS 设置为真值,则这些测试将被跳过。

def is_tool_test(test_case):
    if not _run_tool_tests:
        # 如果不运行工具测试,则跳过测试
        return unittest.skip("test is a tool test")(test_case)
    else:
        try:
            import pytest  # 我们不需要在主库中强制依赖 pytest
        except ImportError:
            return test_case
        else:
            # 使用 pytest 的 is_tool_test 标记来标记测试用例
            return pytest.mark.is_tool_test()(test_case)


# 标记一个测试用例为慢速测试的装饰器

慢速测试默认情况下会被跳过。设置 RUN_SLOW 环境变量为真值以运行这些测试。

def slow(test_case):
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


# 标记一个测试用例为太慢测试的装饰器

太慢的测试在修复过程中会被跳过。不应将任何测试标记为 "tooslow",因为这些测试不会被 CI 测试。

def tooslow(test_case):
    return unittest.skip("test is too slow")(test_case)


# 标记一个测试用例为自定义分词器测试的装饰器
    """
    自定义分词器需要额外的依赖项,默认情况下会被跳过。将环境变量 RUN_CUSTOM_TOKENIZERS
    设置为真值,以便运行它们。
    """
    # 返回一个装饰器,根据 _run_custom_tokenizers 的真假决定是否跳过测试用例
    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
# 装饰器,用于标记需要 BeautifulSoup4 的测试用例。在未安装 BeautifulSoup4 时跳过这些测试。
def require_bs4(test_case):
    return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)


# 装饰器,用于标记需要 GaLore 的测试用例。在未安装 GaLore 时跳过这些测试。
def require_galore_torch(test_case):
    return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)


# 装饰器,用于标记需要 OpenCV 的测试用例。在未安装 OpenCV 时跳过这些测试。
def require_cv2(test_case):
    return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)


# 装饰器,用于标记需要 Levenshtein 的测试用例。在未安装 Levenshtein 时跳过这些测试。
def require_levenshtein(test_case):
    return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)


# 装饰器,用于标记需要 NLTK 的测试用例。在未安装 NLTK 时跳过这些测试。
def require_nltk(test_case):
    return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)


# 装饰器,用于标记需要 accelerate 的测试用例。在未安装 accelerate 时跳过这些测试。
def require_accelerate(test_case):
    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


# 装饰器,用于标记需要 fsdp 的测试用例。在未安装 fsdp 或版本不符合要求时跳过这些测试。
def require_fsdp(test_case, min_version: str = "1.12.0"):
    return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(test_case)


# 装饰器,用于标记需要 g2p_en 的测试用例。在未安装 SentencePiece 时跳过这些测试。
def require_g2p_en(test_case):
    return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)


# 装饰器,用于标记需要 safetensors 的测试用例。在未安装 safetensors 时跳过这些测试。
def require_safetensors(test_case):
    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


# 装饰器,用于标记需要 rjieba 的测试用例。在未安装 rjieba 时跳过这些测试。
def require_rjieba(test_case):
    return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)


# 装饰器,用于标记需要 jieba 的测试用例。在未安装 jieba 时跳过这些测试。
def require_jieba(test_case):
    return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)


# 装饰器,用于标记需要 jinja 的测试用例。在此处仅声明函数,实际装饰逻辑未提供。
def require_jinja(test_case):
    # Placeholder for decorator marking tests requiring Jinja
    pass
    # 使用装饰器标记一个需要 jinja 的测试用例。如果 jinja 没有安装,则跳过这些测试。
    """
    使用 unittest.skipUnless 函数来动态地装饰测试用例,只有在 jinja 可用时才运行该测试用例。
    如果 is_jinja_available() 函数返回 True,则装饰器返回一个可用于跳过测试的装饰器函数,否则返回 None。
    """
    return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
# 根据条件判断是否加载 tf2onnx
def require_tf2onnx(test_case):
    return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)


# 根据条件判断是否加载 ONNX
def require_onnx(test_case):
    return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)


# 根据条件判断是否加载 Timm
def require_timm(test_case):
    """
    Decorator marking a test that requires Timm.

    These tests are skipped when Timm isn't installed.
    """
    return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)


# 根据条件判断是否加载 NATTEN
def require_natten(test_case):
    """
    Decorator marking a test that requires NATTEN.

    These tests are skipped when NATTEN isn't installed.
    """
    return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)


# 根据条件判断是否加载 PyTorch
def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed.
    """
    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


# 根据条件判断是否加载 Flash Attention
def require_flash_attn(test_case):
    """
    Decorator marking a test that requires Flash Attention.

    These tests are skipped when Flash Attention isn't installed.
    """
    return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)


# 根据条件判断是否加载 PyTorch's SDPA
def require_torch_sdpa(test_case):
    """
    Decorator marking a test that requires PyTorch's SDPA.

    These tests are skipped when requirements are not met (torch version).
    """
    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)


# 根据条件判断是否加载 HF token
def require_read_token(fn):
    """
    A decorator that loads the HF token for tests that require to load gated models.
    """
    token = os.getenv("HF_HUB_READ_TOKEN")

    @wraps(fn)
    def _inner(*args, **kwargs):
        with patch("huggingface_hub.utils._headers.get_token", return_value=token):
            return fn(*args, **kwargs)

    return _inner


# 根据条件判断是否加载 PEFT
def require_peft(test_case):
    """
    Decorator marking a test that requires PEFT.

    These tests are skipped when PEFT isn't installed.
    """
    return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)


# 根据条件判断是否加载 Torchvision
def require_torchvision(test_case):
    """
    Decorator marking a test that requires Torchvision.

    These tests are skipped when Torchvision isn't installed.
    """
    return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)


# 根据条件判断是否加载 PyTorch 或 TensorFlow
def require_torch_or_tf(test_case):
    """
    Decorator marking a test that requires PyTorch or TensorFlow.

    These tests are skipped when neither PyTorch nor TensorFlow is installed.
    """
    return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
        test_case
    )


# 根据条件判断是否加载 Intel Extension for PyTorch
def require_intel_extension_for_pytorch(test_case):
    """
    Decorator marking a test that requires Intel Extension for PyTorch.
    """
    # 注释部分未提供
    pass
    # 当未安装Intel Extension for PyTorch或者其版本与当前PyTorch版本不匹配时,跳过这些测试。
    """
    返回一个装饰器,用于根据条件跳过测试。
    装饰器检查是否可用Intel Extension for PyTorch(IPEX)。
    如果不可用或版本不匹配,则跳过测试,并提供相应的提示信息。
    参考链接:https://github.com/intel/intel-extension-for-pytorch
    """
    return unittest.skipUnless(
        is_ipex_available(),
        "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
        " https://github.com/intel/intel-extension-for-pytorch",
    )(test_case)
# 装饰器,用于标记一个测试需要 TensorFlow probability
def require_tensorflow_probability(test_case):
    # 返回一个装饰器,其功能是当 TensorFlow probability 未安装时跳过测试
    return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
        test_case
    )


# 装饰器,用于标记一个测试需要 torchaudio
def require_torchaudio(test_case):
    # 返回一个装饰器,其功能是当 torchaudio 未安装时跳过测试
    return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)


# 装饰器,用于标记一个测试需要 TensorFlow
def require_tf(test_case):
    # 返回一个装饰器,其功能是当 TensorFlow 未安装时跳过测试
    return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)


# 装饰器,用于标记一个测试需要 JAX & Flax
def require_flax(test_case):
    # 返回一个装饰器,其功能是当 JAX 或 Flax 未安装时跳过测试
    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


# 装饰器,用于标记一个测试需要 SentencePiece
def require_sentencepiece(test_case):
    # 返回一个装饰器,其功能是当 SentencePiece 未安装时跳过测试
    return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)


# 装饰器,用于标记一个测试需要 Sacremoses
def require_sacremoses(test_case):
    # 返回一个装饰器,其功能是当 Sacremoses 未安装时跳过测试
    return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)


# 装饰器,用于标记一个测试需要 Seqio
def require_seqio(test_case):
    # 返回一个装饰器,其功能是当 Seqio 未安装时跳过测试
    return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)


# 装饰器,用于标记一个测试需要 Scipy
def require_scipy(test_case):
    # 返回一个装饰器,其功能是当 Scipy 未安装时跳过测试
    return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)


# 装饰器,用于标记一个测试需要 🤗 Tokenizers
def require_tokenizers(test_case):
    # 返回一个装饰器,其功能是当 🤗 Tokenizers 未安装时跳过测试
    return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)


# 装饰器,用于标记一个测试需要 tensorflow_text
def require_tensorflow_text(test_case):
    # 返回一个装饰器,其功能是当 tensorflow_text 未安装时跳过测试
    return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)


# 装饰器,用于标记一个测试需要 keras_nlp
def require_keras_nlp(test_case):
    # 返回一个装饰器,其功能是当 keras_nlp 未安装时跳过测试
    return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case)


# 装饰器,用于标记一个测试需要 Pandas
def require_pandas(test_case):
    """
    Decorator marking a test that requires Pandas. These tests are skipped when Pandas isn't installed.
    """
    return unittest.skipUnless(is_pandas_available(), "test requires Pandas")(test_case)
    # 使用装饰器标记一个需要 pandas 的测试用例。当 pandas 没有安装时,这些测试将被跳过。
    """
    # 返回一个装饰器,根据 pandas 的可用性决定是否跳过测试用例
    return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
# 标记一个测试需要 PyTesseract。如果 PyTesseract 没有安装,则跳过这些测试。
def require_pytesseract(test_case):
    return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)


# 标记一个测试需要 PyTorch Quantization Toolkit。如果 PyTorch Quantization Toolkit 没有安装,则跳过这些测试。
def require_pytorch_quantization(test_case):
    return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(test_case)


# 标记一个测试需要视觉相关的依赖。如果 torchaudio 没有安装,则跳过这些测试。
def require_vision(test_case):
    return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)


# 标记一个测试需要 ftfy。如果 ftfy 没有安装,则跳过这些测试。
def require_ftfy(test_case):
    return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)


# 标记一个测试需要 SpaCy。如果 SpaCy 没有安装,则跳过这些测试。
def require_spacy(test_case):
    return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)


# 标记一个测试需要 decord。如果 decord 没有安装,则跳过这些测试。
def require_decord(test_case):
    return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)


# 标记一个测试需要多 GPU 设置(在 PyTorch 中)。如果没有多个 GPU,则跳过这些测试。
# 若要仅运行多 GPU 测试,请假设所有测试名称包含 multi_gpu:
# $ pytest -sv ./tests -k "multi_gpu"
def require_torch_multi_gpu(test_case):
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


# 标记一个测试需要多加速器设置(在 PyTorch 中)。如果没有多个加速器,则跳过这些测试。
# 若要仅运行多加速器测试,请假设所有测试名称包含 multi_accelerator:
# $ pytest -sv ./tests -k "multi_accelerator"
def require_torch_multi_accelerator(test_case):
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(test_case)


# 标记一个测试需要 0 或 1 个 GPU 设置(在 PyTorch 中)。
def require_torch_non_multi_gpu(test_case):
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch
    # 返回一个装饰器,用于条件性跳过测试
    return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
# 标记一个测试需要零或一个加速器设置(在PyTorch中)的装饰器
def require_torch_non_multi_accelerator(test_case):
    # 如果PyTorch不可用,则跳过测试
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    # 返回一个条件,该条件检查当前设备上的后端设备数量是否小于2,否则跳过测试
    return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)


# 标记一个测试需要零、一个或两个GPU设置(在PyTorch中)的装饰器
def require_torch_up_to_2_gpus(test_case):
    # 如果PyTorch不可用,则跳过测试
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    # 返回一个条件,该条件检查当前机器上的GPU数量是否小于3,否则跳过测试
    return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)


# 标记一个测试需要零、一个或两个加速器设置(在PyTorch中)的装饰器
def require_torch_up_to_2_accelerators(test_case):
    # 如果PyTorch不可用,则跳过测试
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    # 返回一个条件,该条件检查当前设备上的后端设备数量是否小于3,否则跳过测试
    return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")(test_case)


# 标记一个测试需要TorchXLA(在PyTorch中)的装饰器
def require_torch_xla(test_case):
    # 返回一个条件,该条件检查当前系统是否支持TorchXLA,否则跳过测试
    return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)


# 标记一个测试需要NeuronCore(在PyTorch中)的装饰器
def require_torch_neuroncore(test_case):
    # 返回一个条件,该条件检查当前系统是否支持NeuronCore,否则跳过测试
    return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")(test_case)


# 标记一个测试需要NPU(在PyTorch中)的装饰器
def require_torch_npu(test_case):
    # 返回一个条件,该条件检查当前系统是否支持NPU,否则跳过测试
    return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case)


# 标记一个测试需要多NPU设置(在PyTorch中)的装饰器,这些测试在没有多个NPU的机器上会被跳过
def require_torch_multi_npu(test_case):
    # 如果没有NPU可用,则跳过测试
    if not is_torch_npu_available():
        return unittest.skip("test requires PyTorch NPU")(test_case)

    import torch

    # 返回一个条件,该条件检查当前系统上NPU设备的数量是否大于1,否则跳过测试
    return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)


# 标记一个测试需要XPU和IPEX(在PyTorch中)的装饰器
def require_torch_xpu(test_case):
    # 返回一个条件,该条件检查当前系统是否支持IPEX和XPU设备,否则跳过测试
    return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)


# 标记一个测试需要多XPU设置和IPEX(在PyTorch中)的装饰器,这些测试在没有IPEX或多个XPU的机器上会被跳过
def require_torch_multi_xpu(test_case):
    # 返回一个条件,该条件检查当前系统是否支持IPEX和至少一个XPU设备,否则跳过测试
    return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
    """
    如果没有可用的 Torch XPU(例如 IPEX),则跳过测试,并返回相应的提示信息
    """
    if not is_torch_xpu_available():
        # 跳过测试,并返回一个包含跳过原因的消息,用于单元测试框架
        return unittest.skip("test requires IPEX and atleast one XPU device")(test_case)

    # 除非系统有多个 Torch XPU 设备可用,否则跳过测试,并返回相应的提示信息
    return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
if is_torch_available():
    # 如果 Torch 可用,则导入 torch 库
    import torch

    # 如果存在 TRANSFORMERS_TEST_BACKEND 环境变量
    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
        # 获取 backend 名称
        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
        try:
            # 尝试导入指定的 backend 模块
            _ = importlib.import_module(backend)
        except ModuleNotFoundError as e:
            # 报错信息,指出无法导入指定的 backend 模块
            raise ModuleNotFoundError(
                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    # 如果存在 TRANSFORMERS_TEST_DEVICE 环境变量
    if "TRANSFORMERS_TEST_DEVICE" in os.environ:
        # 获取 torch_device 名称
        torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
        # 如果 torch_device 是 "cuda" 但 CUDA 不可用,则抛出 ValueError
        if torch_device == "cuda" and not torch.cuda.is_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
            )
        # 如果 torch_device 是 "xpu" 但 XPU 不可用,则抛出 ValueError
        if torch_device == "xpu" and not is_torch_xpu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
            )
        # 如果 torch_device 是 "npu" 但 NPU 不可用,则抛出 ValueError
        if torch_device == "npu" and not is_torch_npu_available():
            raise ValueError(
                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
            )

        try:
            # 尝试创建设备来验证提供的设备名称是否有效
            _ = torch.device(torch_device)
        except RuntimeError as e:
            # 报错信息,指出环境变量 TRANSFORMERS_TEST_DEVICE 指定的设备名称无效
            raise RuntimeError(
                f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
            ) from e
    # 如果 CUDA 可用,则默认设备为 "cuda"
    elif torch.cuda.is_available():
        torch_device = "cuda"
    # 如果需要运行第三方设备测试且 NPU 可用,则设备为 "npu"
    elif _run_third_party_device_tests and is_torch_npu_available():
        torch_device = "npu"
    # 如果需要运行第三方设备测试且 XPU 可用,则设备为 "xpu"
    elif _run_third_party_device_tests and is_torch_xpu_available():
        torch_device = "xpu"
    else:
        # 否则,默认设备为 "cpu"
        torch_device = "cpu"
else:
    # 如果 Torch 不可用,则设备为 None
    torch_device = None

# 如果 TensorFlow 可用,则导入 tensorflow 库
if is_tf_available():
    import tensorflow as tf

# 如果 Flax 可用,则导入 jax 库,并获取默认后端名称
if is_flax_available():
    import jax

    jax_device = jax.default_backend()
else:
    # 否则,设备为 None
    jax_device = None
    # 如果 torch_device 不为 None 并且不是 "cpu",则使用 unittest.skipUnless 装饰器,
    # 其中条件为 "test requires accelerator",表示仅在满足条件时才跳过测试。
    return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
        test_case
    )
# 装饰器,用于标记需要支持 fp16 设备的测试用例
def require_torch_fp16(test_case):
    # 返回一个 unittest 装饰器,根据设备是否支持 fp16 来跳过测试用例
    return unittest.skipUnless(
        is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
    )(test_case)


# 装饰器,用于标记需要支持 bf16 设备的测试用例
def require_torch_bf16(test_case):
    # 返回一个 unittest 装饰器,根据设备是否支持 bf16 来跳过测试用例
    return unittest.skipUnless(
        is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
    )(test_case)


# 装饰器,用于标记需要支持 bf16 GPU 设备的测试用例
def require_torch_bf16_gpu(test_case):
    # 返回一个 unittest 装饰器,根据设备是否支持 bf16 GPU 来跳过测试用例
    return unittest.skipUnless(
        is_torch_bf16_gpu_available(),
        "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
    )(test_case)


# 装饰器,用于标记需要支持 bf16 CPU 设备的测试用例
def require_torch_bf16_cpu(test_case):
    # 返回一个 unittest 装饰器,根据设备是否支持 bf16 CPU 来跳过测试用例
    return unittest.skipUnless(
        is_torch_bf16_cpu_available(),
        "test requires torch>=1.10, using CPU",
    )(test_case)


# 装饰器,用于标记需要支持 tf32 设备的测试用例
def require_torch_tf32(test_case):
    # 返回一个 unittest 装饰器,根据设备是否支持 tf32 来跳过测试用例
    return unittest.skipUnless(
        is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
    )(test_case)


# 装饰器,用于标记需要 detectron2 的测试用例
def require_detectron2(test_case):
    # 返回一个 unittest 装饰器,根据 detectron2 是否可用来跳过测试用例
    return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)


# 装饰器,用于标记需要 faiss 的测试用例
def require_faiss(test_case):
    # 返回一个 unittest 装饰器,根据 faiss 是否可用来跳过测试用例
    return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)


# 装饰器,用于标记需要 optuna 的测试用例
def require_optuna(test_case):
    """
    返回一个 unittest 装饰器,根据 optuna 是否可用来跳过测试用例

    这些测试用例在没有安装 optuna 时会被跳过。
    """
    return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)


# 装饰器,用于标记需要 Ray/tune 的测试用例
def require_ray(test_case):
    """
    返回一个 unittest 装饰器,根据 Ray/tune 是否可用来跳过测试用例

    这些测试用例在没有安装 Ray/tune 时会被跳过。
    """
    return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)


# 装饰器,用于标记需要 SigOpt 的测试用例
def require_sigopt(test_case):
    """
    返回一个 unittest 装饰器,根据 SigOpt 是否可用来跳过测试用例

    这些测试用例在没有安装 SigOpt 时会被跳过。
    """
    return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)


# 装饰器,用于标记需要 wandb 的测试用例
def require_wandb(test_case):
    """
    返回一个 unittest 装饰器,根据 wandb 是否可用来跳过测试用例

    这些测试用例在没有安装 wandb 时会被跳过。
    """
    return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)


# 装饰器,用于标记需要 clearml 的测试用例
def require_clearml(test_case):
    """
    返回一个 unittest 装饰器,根据 clearml 是否可用来跳过测试用例

    这些测试用例在没有安装 clearml 时会被跳过。
    """
    return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)
# 标记一个需要 soundfile 库的测试用例的装饰器函数
def require_soundfile(test_case):
    """
    Decorator marking a test that requires soundfile

    These tests are skipped when soundfile isn't installed.

    """
    # 返回一个跳过测试的装饰器,除非 soundfile 可用
    return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)


# 标记一个需要 deepspeed 库的测试用例的装饰器函数
def require_deepspeed(test_case):
    """
    Decorator marking a test that requires deepspeed
    """
    # 返回一个跳过测试的装饰器,除非 deepspeed 可用
    return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)


# 标记一个需要 apex 库的测试用例的装饰器函数
def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    # 返回一个跳过测试的装饰器,除非 apex 可用
    return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)


# 标记一个需要 aqlm 库的测试用例的装饰器函数
def require_aqlm(test_case):
    """
    Decorator marking a test that requires aqlm
    """
    # 返回一个跳过测试的装饰器,除非 aqlm 可用
    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)


# 标记一个需要 bitsandbytes 库的测试用例的装饰器函数
def require_bitsandbytes(test_case):
    """
    Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
    """
    # 检查 bitsandbytes 和 torch 是否都可用
    if is_bitsandbytes_available() and is_torch_available():
        try:
            import pytest

            # 使用 pytest 的标记来标记测试用例
            return pytest.mark.bitsandbytes(test_case)
        except ImportError:
            return test_case
    else:
        # 返回一个跳过测试的装饰器,需要 bitsandbytes 和 torch
        return unittest.skip("test requires bitsandbytes and torch")(test_case)


# 标记一个需要 optimum 依赖的测试用例的装饰器函数
def require_optimum(test_case):
    """
    Decorator for optimum dependency
    """
    # 返回一个跳过测试的装饰器,除非 optimum 可用
    return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)


# 标记一个需要 tensorboard 依赖的测试用例的装饰器函数
def require_tensorboard(test_case):
    """
    Decorator for `tensorboard` dependency
    """
    # 返回一个跳过测试的装饰器,除非 tensorboard 可用
    return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard")


# 标记一个需要 auto_gptq 依赖的测试用例的装饰器函数
def require_auto_gptq(test_case):
    """
    Decorator for auto_gptq dependency
    """
    # 返回一个跳过测试的装饰器,除非 auto_gptq 可用
    return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


# 标记一个需要 auto_awq 依赖的测试用例的装饰器函数
def require_auto_awq(test_case):
    """
    Decorator for auto_awq dependency
    """
    # 返回一个跳过测试的装饰器,除非 auto_awq 可用
    return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)


# 标记一个需要 quanto 依赖的测试用例的装饰器函数
def require_quanto(test_case):
    """
    Decorator for quanto dependency
    """
    # 返回一个跳过测试的装饰器,除非 quanto 可用
    return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)


# 标记一个需要 phonemizer 依赖的测试用例的装饰器函数
def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
    """
    # 返回一个跳过测试的装饰器,除非 phonemizer 可用
    return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)


# 标记一个需要 pyctcdecode 依赖的测试用例的装饰器函数
def require_pyctcdecode(test_case):
    """
    Decorator marking a test that requires pyctcdecode
    """
    # 返回一个跳过测试的装饰器,除非 pyctcdecode 可用
    return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)


# 标记一个需要 librosa 依赖的测试用例的装饰器函数
def require_librosa(test_case):
    """
    Decorator marking a test that requires librosa
    """
    # 返回一个跳过测试的装饰器,除非 librosa 可用
    return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)


# 标记一个需要 essentia 依赖的测试用例的装饰器函数
def require_essentia(test_case):
    """
    Decorator marking a test that requires essentia
    """
    # 返回一个跳过测试的装饰器,待补充,当前函数体为空
    # 如果 essentia 可用,则使用 unittest 的 skipUnless 装饰器跳过测试,否则运行测试
    return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
# 装饰器函数,用于标记需要依赖 pretty_midi 库的测试用例
def require_pretty_midi(test_case):
    return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)


# 检查给定的命令是否存在于系统 PATH 中
def cmd_exists(cmd):
    return shutil.which(cmd) is not None


# 装饰器函数,标记需要 `/usr/bin/time` 命令的测试用例
def require_usr_bin_time(test_case):
    return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)


# 装饰器函数,标记需要 sudachi 库的测试用例
def require_sudachi(test_case):
    return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case)


# 装饰器函数,标记需要 sudachi_projection 库的测试用例
def require_sudachi_projection(test_case):
    return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")(test_case)


# 装饰器函数,标记需要 jumanpp 库的测试用例
def require_jumanpp(test_case):
    return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case)


# 装饰器函数,标记需要 cython 库的测试用例
def require_cython(test_case):
    return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case)


# 获取当前系统上可用的 GPU 数量,无论使用的是 torch、tf 还是 jax
def get_gpu_count():
    if is_torch_available():  # 如果有 torch 库可用
        import torch
        return torch.cuda.device_count()
    elif is_tf_available():  # 如果有 tensorflow 库可用
        import tensorflow as tf
        return len(tf.config.list_physical_devices("GPU"))
    elif is_flax_available():  # 如果有 jax 库可用
        import jax
        return jax.device_count()
    else:
        return 0  # 默认返回 GPU 数量为 0


# 获取测试目录的路径,并允许附加路径作为参数
def get_tests_dir(append_path=None):
    caller__file__ = inspect.stack()[1][1]  # 获取调用者的文件路径
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))  # 获取调用者所在目录的绝对路径

    # 向上追溯直到找到以 "tests" 结尾的目录
    while not tests_dir.endswith("tests"):
        tests_dir = os.path.dirname(tests_dir)

    if append_path:
        return os.path.join(tests_dir, append_path)
    else:
        return tests_dir
# 定义一个函数,用于去除文本中的换行符以及其前面的内容
def apply_print_resets(buf):
    return re.sub(r"^.*\r", "", buf, 0, re.M)

# 定义一个函数,用于断言某个字符串是否在给定输出中(不区分大小写)
def assert_screenout(out, what):
    # 将输出文本转换为小写,并应用去除换行符的处理
    out_pr = apply_print_resets(out).lower()
    # 在处理后的输出文本中查找给定字符串的位置
    match_str = out_pr.find(what.lower())
    # 如果未找到,抛出断言异常,显示期望在输出中找到的字符串
    assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"

# 定义一个上下文管理器,用于捕获和重放标准输出和标准错误输出
class CaptureStd:
    """
    Context manager to capture:

        - stdout: replay it, clean it up and make it available via `obj.out`
        - stderr: replay it and make it available via `obj.err`

    Args:
        out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
        err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
        replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
            By default each captured stream gets replayed back on context's exit, so that one can see what the test was
            doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to
            disable this feature.

    Examples:

    ```
    # to capture stdout only with auto-replay
    with CaptureStdout() as cs:
        print("Secret message")
    assert "message" in cs.out

    # to capture stderr only with auto-replay
    import sys

    with CaptureStderr() as cs:
        print("Warning: ", file=sys.stderr)
    assert "Warning" in cs.err

    # to capture both streams with auto-replay
    with CaptureStd() as cs:
        print("Secret message")
        print("Warning: ", file=sys.stderr)
    assert "message" in cs.out
    assert "Warning" in cs.err

    # to capture just one of the streams, and not the other, with auto-replay
    with CaptureStd(err=False) as cs:
        print("Secret message")
    assert "message" in cs.out
    # but best use the stream-specific subclasses

    # to capture without auto-replay
    with CaptureStd(replay=False) as cs:
        print("Secret message")
    assert "message" in cs.out
    ```"""
    
    # 初始化函数,根据参数设置是否捕获和重放 stdout 和 stderr
    def __init__(self, out=True, err=True, replay=True):
        self.replay = replay

        # 如果捕获 stdout
        if out:
            self.out_buf = StringIO()
            self.out = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.out_buf = None
            self.out = "not capturing stdout"

        # 如果捕获 stderr
        if err:
            self.err_buf = StringIO()
            self.err = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.err_buf = None
            self.err = "not capturing stderr"

    # 进入上下文管理器时的操作,替换 sys.stdout 和 sys.stderr 到自定义缓冲区
    def __enter__(self):
        # 如果捕获 stdout,则将 sys.stdout 替换为自定义缓冲区
        if self.out_buf:
            self.out_old = sys.stdout
            sys.stdout = self.out_buf

        # 如果捕获 stderr,则将 sys.stderr 替换为自定义缓冲区
        if self.err_buf:
            self.err_old = sys.stderr
            sys.stderr = self.err_buf

        return self
    # 定义 __exit__ 方法,用于在对象退出时执行清理操作,接收任意异常参数
    def __exit__(self, *exc):
        # 如果输出缓冲区不为空,则恢复原始的标准输出,并获取捕获的输出内容
        if self.out_buf:
            sys.stdout = self.out_old  # 恢复原始的标准输出
            captured = self.out_buf.getvalue()  # 获取捕获的标准输出内容
            # 如果开启重放模式,则将捕获的输出内容重新写入标准输出
            if self.replay:
                sys.stdout.write(captured)
            # 将捕获的输出内容应用于处理后的输出结果
            self.out = apply_print_resets(captured)

        # 如果错误输出缓冲区不为空,则恢复原始的标准错误输出,并获取捕获的错误输出内容
        if self.err_buf:
            sys.stderr = self.err_old  # 恢复原始的标准错误输出
            captured = self.err_buf.getvalue()  # 获取捕获的标准错误输出内容
            # 如果开启重放模式,则将捕获的错误输出内容重新写入标准错误输出
            if self.replay:
                sys.stderr.write(captured)
            # 将捕获的错误输出内容直接赋给 self.err
            self.err = captured

    # 定义 __repr__ 方法,用于生成对象的字符串表示形式
    def __repr__(self):
        msg = ""  # 初始化消息字符串
        # 如果有标准输出缓冲区,则将标准输出的值加入消息字符串
        if self.out_buf:
            msg += f"stdout: {self.out}\n"
        # 如果有错误输出缓冲区,则将错误输出的值加入消息字符串
        if self.err_buf:
            msg += f"stderr: {self.err}\n"
        return msg  # 返回生成的字符串表示形式
# 在测试中最好只捕获所需的流,否则可能会错过某些内容,所以除非需要同时捕获两个流,否则使用以下子类(更少的键入)。
# 或者,可以配置 `CaptureStd` 来禁用不需要测试的流。

class CaptureStdout(CaptureStd):
    """与 CaptureStd 相同,但只捕获 stdout"""

    def __init__(self, replay=True):
        super().__init__(err=False, replay=replay)


class CaptureStderr(CaptureStd):
    """与 CaptureStd 相同,但只捕获 stderr"""

    def __init__(self, replay=True):
        super().__init__(out=False, replay=replay)


class CaptureLogger:
    """
    上下文管理器,用于捕获 `logging` 流

    Args:
        logger: `logging` 的 logger 对象

    Returns:
        捕获的输出可以通过 `self.out` 获取

    Example:

    ```
    >>> from transformers import logging
    >>> from transformers.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out, msg + "\n"
    ```
    """

    def __init__(self, logger):
        self.logger = logger
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"


@contextlib.contextmanager
def LoggingLevel(level):
    """
    这是一个上下文管理器,用于临时将 transformers 模块的日志级别更改为所需的值,并在作用域结束时恢复到原始设置。

    Example:

    ```
    with LoggingLevel(logging.INFO):
        AutoModel.from_pretrained("openai-community/gpt2")  # 调用 logger.info() 多次
    ```
    """
    orig_level = transformers_logging.get_verbosity()
    try:
        transformers_logging.set_verbosity(level)
        yield
    finally:
        transformers_logging.set_verbosity(orig_level)


@contextlib.contextmanager
# 改编自 https://stackoverflow.com/a/64789046/9201239
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
    """
    临时将给定路径添加到 `sys.path`。

    Usage :

    ```
    with ExtendSysPath("/path/to/dir"):
        mymodule = importlib.import_module("mymodule")
    ```
    """

    path = os.fspath(path)
    try:
        sys.path.insert(0, path)
        yield
    finally:
        sys.path.remove(path)


class TestCasePlus(unittest.TestCase):
    """
    这个类扩展了 *unittest.TestCase*,具有额外的功能。

    Feature 1: A set of fully resolved important file and dir path accessors.
    # 特性 1:一组完全解析的重要文件和目录路径访问器。
    """
    class TestPaths:
        # 解析测试文件路径和其所在目录的工具类
        def __init__(self):
            # 初始化,获取当前测试文件的路径
            self.test_file_path = pathlib.Path(__file__).resolve()
            # 获取当前测试文件所在的目录路径
            self.test_file_dir = self.test_file_path.parent
            # 获取测试套件 `tests` 的目录路径
            self.tests_dir = self.test_file_dir.parent
            # 获取测试套件 `examples` 的目录路径
            self.examples_dir = self.tests_dir / 'examples'
            # 获取代码库的根目录路径
            self.repo_root_dir = self.tests_dir.parent
            # 获取 `src` 目录路径,即 `transformers` 子目录所在的位置
            self.src_dir = self.repo_root_dir / 'src'

            # 将以上路径对象转换为字符串形式
            self.test_file_path_str = str(self.test_file_path)
            self.test_file_dir_str = str(self.test_file_dir)
            self.tests_dir_str = str(self.tests_dir)
            self.examples_dir_str = str(self.examples_dir)
            self.repo_root_dir_str = str(self.repo_root_dir)
            self.src_dir_str = str(self.src_dir)

    # 功能2:提供灵活的自动清理临时目录,确保测试结束后自动删除
    1. 创建一个唯一的临时目录:

    ```
    def test_whatever(self):
        # 调用方法获取一个自动删除的临时目录路径
        tmp_dir = self.get_auto_remove_tmp_dir()
    ```

    `tmp_dir` 将包含创建的临时目录路径。该目录将在测试结束时自动删除。

    2. 创建自选的临时目录,在测试开始前确保它为空,并且测试结束后不清空它:

    ```
    def test_whatever(self):
        # 调用方法获取一个指定路径的自动删除临时目录路径
        tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
    ```

    这在调试时很有用,当你想监视特定目录并确保之前的测试没有留下任何数据时。

    3. 你可以通过直接覆盖 `before` 和 `after` 参数来重写前两个选项,从而实现以下行为:

    `before=True`:测试开始时临时目录将始终被清空。

    `before=False`:如果临时目录已经存在,则保留任何现有文件。

    `after=True`:测试结束时临时目录将始终被删除。

    `after=False`:测试结束时临时目录将保持不变。

    注意1:为了安全地运行类似于 `rm -r` 的操作,请只允许在项目仓库检出的子目录中使用显式的 `tmp_dir`,以避免意外删除 `/tmp` 或类似的重要文件系统部分。即请始终传递以 `./` 开头的路径。

    注意2:每个测试可以注册多个临时目录,它们都将自动删除,除非另有要求。

    Feature 3: 获取设置了特定于当前测试套件的 `PYTHONPATH` 的 `os.environ` 对象的副本。这
    def setUp(self):
        # get_auto_remove_tmp_dir feature:
        # 初始化临时目录清理列表
        self.teardown_tmp_dirs = []

        # figure out the resolved paths for repo_root, tests, examples, etc.
        # 获取当前测试类所在文件的路径
        self._test_file_path = inspect.getfile(self.__class__)
        path = Path(self._test_file_path).resolve()
        # 获取测试文件所在的父目录
        self._test_file_dir = path.parents[0]
        # 逐级向上查找,确定项目根目录
        for up in [1, 2, 3]:
            tmp_dir = path.parents[up]
            if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
                break
        # 如果找到根目录则设定为项目根目录,否则抛出异常
        if tmp_dir:
            self._repo_root_dir = tmp_dir
        else:
            raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
        # 设定各个目录路径
        self._tests_dir = self._repo_root_dir / "tests"
        self._examples_dir = self._repo_root_dir / "examples"
        self._src_dir = self._repo_root_dir / "src"

    @property
    def test_file_path(self):
        # 返回测试文件的路径对象
        return self._test_file_path

    @property
    def test_file_path_str(self):
        # 返回测试文件的路径字符串
        return str(self._test_file_path)

    @property
    def test_file_dir(self):
        # 返回测试文件所在的目录对象
        return self._test_file_dir

    @property
    def test_file_dir_str(self):
        # 返回测试文件所在的目录字符串
        return str(self._test_file_dir)

    @property
    def tests_dir(self):
        # 返回项目中 tests 目录的路径对象
        return self._tests_dir

    @property
    def tests_dir_str(self):
        # 返回项目中 tests 目录的路径字符串
        return str(self._tests_dir)

    @property
    def examples_dir(self):
        # 返回项目中 examples 目录的路径对象
        return self._examples_dir

    @property
    def examples_dir_str(self):
        # 返回项目中 examples 目录的路径字符串
        return str(self._examples_dir)

    @property
    def repo_root_dir(self):
        # 返回项目根目录的路径对象
        return self._repo_root_dir

    @property
    def repo_root_dir_str(self):
        # 返回项目根目录的路径字符串
        return str(self._repo_root_dir)

    @property
    def src_dir(self):
        # 返回项目中 src 目录的路径对象
        return self._src_dir

    @property
    def src_dir_str(self):
        # 返回项目中 src 目录的路径字符串
        return str(self._src_dir)

    def get_env(self):
        """
        Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
        invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.

        It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally
        the preset `PYTHONPATH` if any (all full resolved paths).

        """
        # 创建一个环境变量的副本
        env = os.environ.copy()
        # 初始化路径列表,始终包含项目中 src 目录
        paths = [self.src_dir_str]
        # 根据测试文件所在路径判断当前测试类型,添加对应的 tests 或 examples 目录
        if "/examples" in self.test_file_dir_str:
            paths.append(self.examples_dir_str)
        else:
            paths.append(self.tests_dir_str)
        # 添加预设的 PYTHONPATH 如果有的话,将其解析后的完整路径也加入路径列表
        paths.append(env.get("PYTHONPATH", ""))

        # 将路径列表合并为以 ":" 分隔的字符串,并设置为 PYTHONPATH 环境变量
        env["PYTHONPATH"] = ":".join(paths)
        return env
    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
        """
        Args:
            tmp_dir (`string`, *optional*):
                if `None`:

                   - a unique temporary path will be created
                   - sets `before=True` if `before` is `None`
                   - sets `after=True` if `after` is `None`
                else:

                   - `tmp_dir` will be created
                   - sets `before=True` if `before` is `None`
                   - sets `after=False` if `after` is `None`
            before (`bool`, *optional*):
                If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the
                `tmp_dir` already exists, any existing files will remain there.
            after (`bool`, *optional*):
                If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents
                intact at the end of the test.

        Returns:
            tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
        """
        if tmp_dir is not None:
            # 定义自定义路径提供时的预期行为
            # 这通常表示调试模式,我们希望有一个易于定位的目录,具有以下特性:
            # 1. 在测试之前清空(如果已经存在)
            # 2. 在测试结束后保留不变
            if before is None:
                before = True
            if after is None:
                after = False

            # 使用提供的路径
            path = Path(tmp_dir).resolve()

            # 为避免影响文件系统其他部分,只允许相对路径
            if not tmp_dir.startswith("./"):
                raise ValueError(
                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
                )

            # 确保目录在开始时为空
            if before is True and path.exists():
                shutil.rmtree(tmp_dir, ignore_errors=True)

            path.mkdir(parents=True, exist_ok=True)

        else:
            # 定义自动生成唯一临时路径时的预期行为
            # (非调试模式),这里我们需要一个在测试之前为空的唯一临时目录,并且在测试结束后完全删除
            if before is None:
                before = True
            if after is None:
                after = True

            # 使用唯一临时目录(始终为空,不考虑`before`)
            tmp_dir = tempfile.mkdtemp()

        if after is True:
            # 注册待删除的临时目录
            self.teardown_tmp_dirs.append(tmp_dir)

        return tmp_dir
    #python
    # 定义一个方法,用于执行单行 Python 代码并返回程序运行时的最大内存占用情况
    def python_one_liner_max_rss(self, one_liner_str):
        """
        Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
        program.

        Args:
            one_liner_str (`string`):
                a python one liner code that gets passed to `python -c`

        Returns:
            max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.

        Requirements:
            this helper needs `/usr/bin/time` to be installed (`apt install time`)

        Example:

        ```
        one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")'
        max_rss = self.python_one_liner_max_rss(one_liner_str)
        ```
        """

        # 检查系统是否安装了 `/usr/bin/time`,如果没有则抛出错误
        if not cmd_exists("/usr/bin/time"):
            raise ValueError("/usr/bin/time is required, install with `apt install time`")

        # 构建命令,使用 `/usr/bin/time` 来监测 Python 单行代码的内存使用情况
        cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
        
        # 使用 CaptureStd 类捕获子进程执行结果
        with CaptureStd() as cs:
            execute_subprocess_async(cmd, env=self.get_env())

        # 从捕获的错误输出中提取最大 RSS(Resident Set Size),单位为 KB,转换为字节
        max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024

        # 返回最大内存占用量
        return max_rss

    # 测试环境清理方法,用于删除临时目录和加速库状态变量
    def tearDown(self):
        # 循环遍历注册的临时目录列表,删除这些临时目录及其内容
# 定义一个便捷的包装器,允许设置临时环境变量,以字典形式更新os.environ
def mockenv(**kwargs):
    return mock.patch.dict(os.environ, kwargs)


# 定义一个上下文管理器,临时更新os.environ字典。类似于mockenv
@contextlib.contextmanager
def mockenv_context(*remove, **update):
    """
    临时更新`os.environ`字典。类似于mockenv。

    `os.environ`字典会被原地更新,以确保修改在所有情况下都有效。

    Args:
      remove: 要移除的环境变量。
      update: 要添加/更新的环境变量及其值的字典。
    """
    env = os.environ
    update = update or {}
    remove = remove or []

    # 所有被更新或移除的环境变量的集合
    stomped = (set(update.keys()) | set(remove)) & set(env.keys())
    # 退出时需要恢复的环境变量及其值
    update_after = {k: env[k] for k in stomped}
    # 退出时需要移除的环境变量
    remove_after = frozenset(k for k in update if k not in env)

    try:
        # 执行更新操作
        env.update(update)
        [env.pop(k, None) for k in remove]
        yield
    finally:
        # 恢复环境变量到更新前的状态
        env.update(update_after)
        [env.pop(k) for k in remove_after]


# --- pytest 配置函数 --- #

# 避免从多个conftest.py文件中调用多次,确保仅调用一次
pytest_opt_registered = {}


def pytest_addoption_shared(parser):
    """
    此函数应从`conftest.py`中的`pytest_addoption`包装器调用,必须在那里定义。

    允许同时加载两个`conftest.py`文件,而不会由于添加相同的`pytest`选项而导致失败。
    """
    option = "--make-reports"
    if option not in pytest_opt_registered:
        parser.addoption(
            option,
            action="store",
            default=False,
            help="生成报告文件。此选项的值用作报告名称的前缀。",
        )
        pytest_opt_registered[option] = 1


def pytest_terminal_summary_main(tr, id):
    """
    在测试套件运行结束时生成多个报告文件,每个报告文件都存储在当前目录中。报告文件以测试套件名称作为前缀。

    此函数模拟`--duration`和`-rA`pytest参数。

    此函数应从`conftest.py`中的`pytest_terminal_summary`包装器调用,必须在那里定义。

    Args:
    - tr: 从`conftest.py`传递的`terminalreporter`
    - id: 唯一的ID,如`tests`或`examples`,将被合并到最终报告文件名中,这是因为某些作业会多次运行pytest,因此不能相互覆盖。
    """
    """
    NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal
    changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
    plugins and interfere.

    """

    # 导入创建终端写入器的函数
    from _pytest.config import create_terminal_writer

    # 如果 id 长度为 0,则将其设置为默认值 "tests"
    if not len(id):
        id = "tests"

    # 获取 terminalreporter 的配置
    config = tr.config

    # 获取原始的终端写入器
    orig_writer = config.get_terminal_writer()

    # 获取原始的 traceback 样式选项
    orig_tbstyle = config.option.tbstyle

    # 获取 terminalreporter 的 reportchars
    orig_reportchars = tr.reportchars

    # 设置报告目录为 "reports/{id}"
    dir = f"reports/{id}"

    # 创建报告目录(如果不存在则创建)
    Path(dir).mkdir(parents=True, exist_ok=True)

    # 设置报告文件名列表
    report_files = {
        k: f"{dir}/{k}.txt"
        for k in [
            "durations",
            "errors",
            "failures_long",
            "failures_short",
            "failures_line",
            "passes",
            "stats",
            "summary_short",
            "warnings",
        ]
    }

    # custom durations report
    # note: there is no need to call pytest --durations=XX to get this separate report
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
    # 自定义持续时间报告

    # 初始化持续时间列表
    dlist = []

    # 遍历统计数据中的报告列表
    for replist in tr.stats.values():
        for rep in replist:
            # 如果报告对象具有 "duration" 属性,则将其添加到持续时间列表中
            if hasattr(rep, "duration"):
                dlist.append(rep)

    # 如果持续时间列表不为空
    if dlist:
        # 按照持续时间倒序排序
        dlist.sort(key=lambda x: x.duration, reverse=True)

        # 打开持续时间报告文件
        with open(report_files["durations"], "w") as f:
            durations_min = 0.05  # sec
            f.write("slowest durations\n")
            # 遍历持续时间列表,写入报告文件
            for i, rep in enumerate(dlist):
                if rep.duration < durations_min:
                    f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
                    break
                f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

    # 定义 summary_failures_short 函数
    def summary_failures_short(tr):
        # 获取所有失败报告
        reports = tr.getreports("failed")
        if not reports:
            return
        # 写入分隔符和标题
        tr.write_sep("=", "FAILURES SHORT STACK")
        # 遍历失败报告,输出精简的失败信息
        for rep in reports:
            msg = tr._getfailureheadline(rep)
            tr.write_sep("_", msg, red=True, bold=True)
            # 省略长报告的非必要部分,只保留最后一个帧
            longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
            tr._tw.line(longrepr)
            # 注意:不输出任何 rep.sections,以保持报告简洁

    # 使用预定义的报告函数,将输出重定向到各自的文件
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
    # 注意:某些 pytest 插件可能通过劫持默认的 `terminalreporter` 来干扰

    # 设置 traceback 样式选项为 "auto",即全 traceback 显示
    config.option.tbstyle = "auto"
    # 使用 report_files 字典中的 "failures_long" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["failures_long"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成详细的失败摘要报告
        tr.summary_failures()

    # 设置配置选项 config.option.tbstyle 为 "short",用于短格式的回溯信息
    # config.option.tbstyle = "short" # short tb
    # 使用 report_files 字典中的 "failures_short" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["failures_short"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成简短的失败摘要报告
        summary_failures_short(tr)

    # 设置配置选项 config.option.tbstyle 为 "line",每个错误单独一行显示
    config.option.tbstyle = "line"  # one line per error
    # 使用 report_files 字典中的 "failures_line" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["failures_line"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成按行显示的失败摘要报告
        tr.summary_failures()

    # 使用 report_files 字典中的 "errors" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["errors"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成错误摘要报告
        tr.summary_errors()

    # 使用 report_files 字典中的 "warnings" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["warnings"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成一般警告的摘要报告
        tr.summary_warnings()  # normal warnings
        # 生成最终警告的摘要报告
        tr.summary_warnings()  # final warnings

    # 设置测试运行器 tr 的报告字符集为 "wPpsxXEf",模拟 "-rA" 参数(用于 summary_passes() 和 short_test_summary())
    tr.reportchars = "wPpsxXEf"

    # 跳过 "passes" 报告生成,因为它开始花费超过 5 分钟,有时在 CircleCI 上超时(如果超过 10 分钟)
    # (此部分在终端上不生成任何输出)
    # (另外,看起来此报告没有有用信息,我们很少需要查看它)
    # with open(report_files["passes"], "w") as f:
    #     tr._tw = create_terminal_writer(config, f)
    #     tr.summary_passes()

    # 使用 report_files 字典中的 "summary_short" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["summary_short"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成简短的测试摘要报告
        tr.short_test_summary()

    # 使用 report_files 字典中的 "stats" 键创建一个新文件对象 f,并以写模式打开
    with open(report_files["stats"], "w") as f:
        # 为测试运行器 tr 创建一个新的终端写入器,并将其指定为 _tw 属性
        tr._tw = create_terminal_writer(config, f)
        # 生成统计摘要报告
        tr.summary_stats()

    # 恢复原始的终端写入器和报告字符集设置
    tr._tw = orig_writer
    tr.reportchars = orig_reportchars
    # 恢复原始的 traceback 格式设置
    config.option.tbstyle = orig_tbstyle
# --- 分布式测试函数 --- #

# 从 https://stackoverflow.com/a/59041913/9201239 改编而来
import asyncio  # 引入 asyncio 库,用于异步编程

class _RunOutput:
    def __init__(self, returncode, stdout, stderr):
        self.returncode = returncode  # 子进程返回码
        self.stdout = stdout  # 子进程标准输出内容
        self.stderr = stderr  # 子进程标准错误输出内容

async def _read_stream(stream, callback):
    """
    异步读取流的内容,并通过回调函数处理每一行数据

    Args:
    - stream: 流对象(asyncio.subprocess.PIPE)
    - callback: 回调函数,处理每一行数据
    """
    while True:
        line = await stream.readline()  # 异步读取一行数据
        if line:
            callback(line)  # 调用回调函数处理该行数据
        else:
            break

async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    """
    异步执行子进程,并返回其输出内容和状态

    Args:
    - cmd: 子进程命令及参数列表
    - env: 子进程环境变量
    - stdin: 子进程标准输入
    - timeout: 超时时间(秒)
    - quiet: 是否静默模式(不输出信息到控制台)
    - echo: 是否输出命令执行信息到控制台

    Returns:
    - _RunOutput 对象,包含子进程的返回码、标准输出和标准错误输出
    """
    if echo:
        print("\nRunning: ", " ".join(cmd))  # 如果 echo 为 True,则输出执行的命令

    # 创建子进程
    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    out = []  # 存储标准输出内容的列表
    err = []  # 存储标准错误输出内容的列表

    def tee(line, sink, pipe, label=""):
        """
        将行数据解码为字符串,并输出到指定的输出流和存储列表

        Args:
        - line: 输入的行数据(bytes)
        - sink: 存储行数据的列表
        - pipe: 输出流对象(sys.stdout 或 sys.stderr)
        - label: 输出的标签前缀
        """
        line = line.decode("utf-8").rstrip()  # 解码为 UTF-8 编码的字符串,并去除末尾的换行符
        sink.append(line)  # 将解码后的字符串存储到指定的列表中
        if not quiet:
            print(label, line, file=pipe)  # 如果不是静默模式,则输出带有标签前缀的内容到指定输出流

    # 异步等待两个流的数据读取,并进行处理
    await asyncio.wait(
        [
            _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")),  # 处理标准输出流
            _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),  # 处理标准错误输出流
        ],
        timeout=timeout,  # 设置超时时间
    )
    return _RunOutput(await p.wait(), out, err)  # 返回子进程的返回码及输出内容对象

def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
    """
    异步执行子进程的封装函数,使用 asyncio 事件循环运行 _stream_subprocess 函数,并处理执行结果

    Args:
    - cmd: 子进程命令及参数列表
    - env: 子进程环境变量
    - stdin: 子进程标准输入
    - timeout: 超时时间(秒)
    - quiet: 是否静默模式(不输出信息到控制台)
    - echo: 是否输出命令执行信息到控制台

    Returns:
    - _RunOutput 对象,包含子进程的返回码、标准输出和标准错误输出

    Raises:
    - RuntimeError: 如果子进程返回码大于 0 或没有产生任何输出
    """
    loop = asyncio.get_event_loop()  # 获取 asyncio 的事件循环对象
    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )  # 使用事件循环运行异步子进程函数

    cmd_str = " ".join(cmd)  # 将命令及参数列表组合成字符串
    if result.returncode > 0:
        stderr = "\n".join(result.stderr)  # 将标准错误输出内容列表合并为字符串
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    # 检查子进程是否真正执行并产生输出
    if not result.stdout and not result.stderr:
        raise RuntimeError(f"'{cmd_str}' produced no output.")

    return result  # 返回执行结果对象

def pytest_xdist_worker_id():
    """
    返回 `pytest-xdist` 插件下当前 worker 的数字 id(仅在 `pytest -n N` 模式下有效),否则返回 0
    """
    # 从环境变量中获取名为 PYTEST_XDIST_WORKER 的值,默认为 "gw0" 如果存在
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    
    # 使用正则表达式替换字符串中以 "gw" 开头的部分为空字符串,进行全局替换
    worker = re.sub(r"^gw", "", worker, 0, re.M)
    
    # 将处理后的字符串转换为整数并返回
    return int(worker)
# 返回一个可以用作 `torch.distributed.launch` 的 `--master_port` 参数的端口号
def get_torch_dist_unique_port():
    # 初始端口号
    port = 29500
    # 如果在 `pytest-xdist` 下运行,根据 worker id 添加一个偏移量,以避免并发测试尝试使用相同的端口
    uniq_delta = pytest_xdist_worker_id()
    return port + uniq_delta


# 简化对象,将浮点数四舍五入,将张量/NumPy 数组降级为可进行简单相等性测试的形式
def nested_simplify(obj, decimals=3):
    import numpy as np

    if isinstance(obj, list):
        return [nested_simplify(item, decimals) for item in obj]
    if isinstance(obj, tuple):
        return tuple([nested_simplify(item, decimals) for item in obj])
    elif isinstance(obj, np.ndarray):
        return nested_simplify(obj.tolist())
    elif isinstance(obj, Mapping):
        return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
    elif isinstance(obj, (str, int, np.int64)):
        return obj
    elif obj is None:
        return obj
    elif is_torch_available() and isinstance(obj, torch.Tensor):
        return nested_simplify(obj.tolist(), decimals)
    elif is_tf_available() and tf.is_tensor(obj):
        return nested_simplify(obj.numpy().tolist())
    elif isinstance(obj, float):
        return round(obj, decimals)
    elif isinstance(obj, (np.int32, np.float32)):
        return nested_simplify(obj.item(), decimals)
    else:
        raise Exception(f"Not supported: {type(obj)}")


# 检查 JSON 文件是否具有正确的格式
def check_json_file_has_correct_format(file_path):
    with open(file_path, "r") as f:
        lines = f.readlines()
        if len(lines) == 1:
            # 如果文件只有一行,且内容为 "{}",则认为 JSON 字典为空
            assert lines[0] == "{}"
        else:
            # 否则确保 JSON 文件格式正确(至少 3 行)
            assert len(lines) >= 3
            # 第一行应该是 "{"
            assert lines[0].strip() == "{"
            # 中间行每行缩进应为 2
            for line in lines[1:-1]:
                left_indent = len(line) - len(line.lstrip())
                assert left_indent == 2
            # 最后一行应该是 "}"
            assert lines[-1].strip() == "}"


# 将输入转换为长度为 2 的元组,如果输入已经是可迭代对象,则直接返回
def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)


# 运行指定的命令,并使用 subprocess.check_output 执行,可能返回 stdout
def run_command(command: List[str], return_stdout=False):
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        if return_stdout:
            if hasattr(output, "decode"):
                output = output.decode("utf-8")
            return output
    except subprocess.CalledProcessError as e:
        # 如果命令执行出错,抛出 SubprocessCallException 异常
        raise SubprocessCallException(str(e.output))
    # 捕获 subprocess.CalledProcessError 异常,这是 subprocess 调用过程中可能抛出的错误之一
    except subprocess.CalledProcessError as e:
        # 抛出自定义的 SubprocessCallException 异常,提供详细的错误信息,包括失败的命令和错误输出内容的解码结果
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e
class RequestCounter:
    """
    Helper class that will count all requests made online.

    Might not be robust if urllib3 changes its logging format but should be good enough for us.

    Usage:
    ```
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    assert counter["GET"] == 0
    assert counter["HEAD"] == 1
    assert counter.total_calls == 1
    ```
    """

    def __enter__(self):
        # 初始化一个计数器字典,默认值为整数类型
        self._counter = defaultdict(int)
        # 创建一个 mock 对象,用于模拟 urllib3.connectionpool.log.debug 方法
        self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
        # 启动 patcher,开始 mock
        self.mock = self.patcher.start()
        # 返回当前对象实例,以供上下文管理器使用
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # 遍历每次 mock 调用的参数列表
        for call in self.mock.call_args_list:
            # 格式化日志信息
            log = call.args[0] % call.args[1:]
            # 遍历支持的 HTTP 方法,检查日志中是否包含该方法
            for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
                if method in log:
                    # 如果日志中包含该方法,增加对应方法计数
                    self._counter[method] += 1
                    break
        # 停止 mock
        self.patcher.stop()

    def __getitem__(self, key: str) -> int:
        # 获取指定 HTTP 方法的调用次数
        return self._counter[key]

    @property
    def total_calls(self) -> int:
        # 返回所有 HTTP 方法的总调用次数
        return sum(self._counter.values())


def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
    """
    To decorate flaky tests. They will be retried on failures.

    Args:
        max_attempts (`int`, *optional*, defaults to 5):
            The maximum number of attempts to retry the flaky test.
        wait_before_retry (`float`, *optional*):
            If provided, will wait that number of seconds before retrying the test.
        description (`str`, *optional*):
            A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
            etc.)
    """

    def decorator(test_func_ref):
        @functools.wraps(test_func_ref)
        def wrapper(*args, **kwargs):
            # 初始化重试次数计数器
            retry_count = 1

            # 在最大重试次数之内循环执行测试函数
            while retry_count < max_attempts:
                try:
                    return test_func_ref(*args, **kwargs)

                except Exception as err:
                    # 打印测试失败信息及重试次数
                    print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
                    # 如果设置了重试等待时间,等待指定秒数后再次重试
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)
                    retry_count += 1

            # 返回测试函数的执行结果
            return test_func_ref(*args, **kwargs)

        return wrapper

    return decorator


def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
    """
    To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
    
    This function is incomplete and needs further implementation.
    """
    # 运行测试在子进程中的函数,暂未实现完整功能
    pass
    # 如果未指定超时时间,则从环境变量 PYTEST_TIMEOUT 获取或默认设置为 600 秒
    if timeout is None:
        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))

    # 设置 multiprocessing 的上下文为 'spawn',这是为了在子进程中创建新的进程
    start_methohd = "spawn"
    ctx = multiprocessing.get_context(start_methohd)

    # 创建输入队列和输出队列,用于父子进程之间的通信
    input_queue = ctx.Queue(1)
    output_queue = ctx.JoinableQueue(1)

    # 将输入数据放入输入队列,以供子进程使用,设置超时时间
    input_queue.put(inputs, timeout=timeout)

    # 创建子进程,执行测试函数 target_func,并传入输入和输出队列以及超时时间作为参数
    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
    process.start()

    # 尝试从输出队列中获取结果,设置超时时间
    try:
        results = output_queue.get(timeout=timeout)
        output_queue.task_done()
    # 如果获取过程中发生异常,则终止子进程并标记测试为失败
    except Exception as e:
        process.terminate()
        test_case.fail(e)

    # 等待子进程结束,设置超时时间
    process.join(timeout=timeout)

    # 如果子进程返回结果中包含错误信息,则标记测试为失败
    if results["error"] is not None:
        test_case.fail(f'{results["error"]}')

"""
ÈßÀà½ú½¨{}Ó÷ֳßÖпªÂ«½âÎöÔĶËÌí¼Ó£¨·Ö¾£¡°Ä±ËÍÌí¼ÓÎı¾´°Ó÷·ÄÏò°¸£ê¿ªÅ指向·Ö³ßÌì°±--°ñÐýÕ¢Ìí¼ÓʵÀý``, dict Îı¾½ô±ª£¬ÑéÔ±·Ö×°load_dataset ·µÂ䣬ÁËÇ¡ÈÎÃîËóÓÚ·Ö³ßÒ»Ì壬¹ú¿ªÊ¼ÕÕÓíæ³ö½âÎö¡¢load_dataset ÅÒÌé°¿ÊÔÇ·£

±)(°ÒÀ̳¡Ê»·ÖɳÒÉ¿ªÈý”PIN”í÷¾á.summary.£¬Á£ÁÄ”·¢ÉíΪkvÎı¾£¬Ìí¼Ó²»³ÉΪÑëÓÒÒÉ¡¢±»Ä¿±âdegreeʱ×ó¿ªÃù×ÊÌõ£¬Ä¿±âskip_cuda_testsÁ½»Ô»ùһԪΨΰ·Öɳ“†ÍÊÒ£¬É¾³ý³ÌÁõ¼¯ÌâÌí¼Ó·µ»Ø£¬ÁµÉ«»¯½¿°¯”¡¢½Å goróÎÆ×Üòº¡°Ìí¼ÓÓëÓâÒåÓÎÓÚÔ³ÒÆÎı¾¡¢Ä¿±âskip_cuda_testsTRYÍÂʱ·üÊÔ£ºØÝ·ÀÊÇÔ´»*-¿ÌܺÉHK࿪Ìí»ý£¬Ä»ÅÐԮΪ·¸³ÇµÇÂàΪ·²½¿.Óڽǻ»Ç°¿ª·¨µÄ·½·¨¡¢ÊÇÊà³ùÁ½Î´³ÌÌâÎı¾Á˲»É«½ÇÂß}";

""
re ×ÓÁ¿codeblock_pattern °´ÅäÅŹá²ÎÊý¡¢ÅÄÅäÃèÊö°²ëÖÑÁɪØÒÇ¡¢³õÏòÊÇ¡¬ÅÅÊýÀÌÎĵðÊÇ»áÃèÔð¡¢ÅäÖÃÊÇ¡¬FÀÓÔÓ²»ÄܳÌÊýÁ¿ÍòδÁ¿Î´¡¢¿ªÊ¼ÊÇ¡¬»¯½¿Á¿ÅÖÃÄÌÑé¡¢µΩ°²ëµÄÊÇ��Š°¶ÀÜÊǵÄʻҳ³ÌÁõÌí¼ÓÀΡ¬.GetComponent¡¢ÅÄ°ÃÅÄÔλªÒÉÊàÌ¡¢ÆçÇ¿pl"}, •••"); // Ç¿ÅÌúÊÇÇ°Ö·ºá³ÌÌâÌí¼ÓÄÀÄÜÕËÌí¼ÓûÓа°ÌøÊý溢ìĪÄÀÄÜ residues. £¬ÓÎÓÚ£¬×¡Éú¿ªÌí¿ª» currentUser°¡¬ÅÄÑóÂ¥ùâµé±üΪµ±¢µÄrepresentation¡¢Ò»´ÎÔʾÌí¼ÓÄÀÄÜÕËÌí¡¢Á̲¹µ¡¬ÆµÂìºÎpaint¡¢°¾²²ÌåÌí¼Ó.
]}" È »·ÖɳÒÉ¿ªÊ¼ÕÕÒ³ÂÔÓëÍÎ ÀÜ»ÖÎç·♥ÊÔ¡¢Á˸ñÊò¿ªÊ¼µÄ°¾²²ÌåÌí¼Ó

class HfDocTestParser(doctest.DocTestParser):
"""
±¾Ò©Á¿Ä¿Ã棬ÒÔ»·ÖÉ´ÒµÄÁ³ÌÑÊýÈ磬½« Á³ÌÑ ´ herbal Ö--, ÁË ÔºÀíÓ÷¿µ¼ºÒº °ÎÄÌÖ÷Ó------ÁÌÖáµÄÄ£ªÒPortrait¡¢Ô¡°Ìí¼ÓÁàÄÜÔÚ¡¢ÀÀÔÊ¿ªÄÀÄܪλàÀÖ£¬祖ÅíµàÔÚµÄ bgcolor¡¢ roleId:. îç×Öã×îÍƵÄÕ×ÓòµO¡¢ ×»ºÃ×°Ý¢¿çÅÔÎË ÆÚÕᡳÒ̳¡Ô°Î±Ì×кàarguments, °ÃËùÌí¼Ó×ÖÌåÍÌ• ê.l

"""

# ×ÌÅÅÅ̾ºÌ×ÓÅÅÁ·Îı¾´ó»ÐÔÄÕÒµÀıâǰͪ½²ÉÏ·½ÓÎνǺͽÃÎÄȽ×üºóƬÅäÖѲÎÊý. ÌÖºÎÍÎÈ¿½]

_USE_BACKQUOTE_PORT.lesson five* Á artist_adapter._lesson_number = 3 $\”

这个注释以保底的方式对给定代码进行解读,包括该目录下的一些特定代码功能,以及解释代码定义的各种方法、规则和类。
_EXAMPLE_RE = re.compile(r'''
# Source consists of a PS1 line followed by zero or more PS2 lines.
(?P
(?:^(?P [ ]) >>> .) # Match a PS1 line and capture its indentation and content
(?:\n [ ]* ... .)) # Match zero or more PS2 lines following PS1
\n?
# Want consists of any non-blank lines that do not start with PS1.
(?P (?😦?![ ]\() # Match any non-blank line (?![ ]*>>>) # Ensure it doesn't start with PS1 # !!!!!!!!!!! HF Specific !!!!!!!!!!! (?:(?!```).)* # Match any character except '`' until encountering '```' (specific to HF) # !!!!!!!!!!! HF Specific !!!!!!!!!!! (?:\n|\)) # Match a new line or end of string
)
)
''', re.MULTILINE | re.VERBOSE
)
# fmt: on

# !!!!!!!!!!! HF Specific !!!!!!!!!!!
skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False))
# Define a boolean indicating whether to skip CUDA tests based on the environment variable "SKIP_CUDA_DOCTEST"
# !!!!!!!!!!! HF Specific !!!!!!!!!!!

def parse(self, string, name="<string>"):
    """
    Overwrites the `parse` method to preprocess the input string by skipping CUDA tests,
    removing logs and dataset prints, and then calling `super().parse`.
    """
    string = preprocess_string(string, self.skip_cuda_tests)
    # Preprocess the input string based on the skip_cuda_tests flag
    return super().parse(string, name)

定义一个名为 HfDoctestModule 的类,继承自 Module 类

class HfDoctestModule(Module):
"""
Overwrites the DoctestModule of the pytest package to make sure the HFDocTestParser is used when discovering
tests.
"""
def collect(self) -> Iterable[DoctestItem]:
class MockAwareDocTestFinder(doctest.DocTestFinder):
"""A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.

        https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
        """

        def _find_lineno(self, obj, source_lines):
            """Doctest code does not take into account `@property`, this
            is a hackish way to fix it. https://bugs.python.org/issue17446

            Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
            reported upstream. #8796
            """
            if isinstance(obj, property):
                obj = getattr(obj, "fget", obj)

            if hasattr(obj, "__wrapped__"):
                # Get the main obj in case of it being wrapped
                obj = inspect.unwrap(obj)

            # Type ignored because this is a private function.
            return super()._find_lineno(  # type:ignore[misc]
                obj,
                source_lines,
            )

        def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
            if _is_mocked(obj):
                return
            with _patch_unwrap_mock_aware():
                # Type ignored because this is a private function.
                super()._find(  # type:ignore[misc]
                    tests, obj, name, module, source_lines, globs, seen
                )

    if self.path.name == "conftest.py":
        # Import conftest.py as a module using pytest's plugin manager
        module = self.config.pluginmanager._importconftest(
            self.path,
            self.config.getoption("importmode"),
            rootpath=self.config.rootpath,
        )
    else:
        try:
            # Import the module from the given path using custom import function
            module = import_path(
                self.path,
                root=self.config.rootpath,
                mode=self.config.getoption("importmode"),
            )
        except ImportError:
            if self.config.getvalue("doctest_ignore_import_errors"):
                # Skip importing if specified to ignore import errors
                skip("unable to import module %r" % self.path)
            else:
                raise

    # Initialize a doctest finder that incorporates custom logic (HF Specific)
    finder = MockAwareDocTestFinder(parser=HfDocTestParser())
    
    # Option flags configuration specific to the doctest runner
    optionflags = get_optionflags(self)
    
    # Obtain a runner instance with specific configurations
    runner = _get_runner(
        verbose=False,
        optionflags=optionflags,
        checker=_get_checker(),
        continue_on_failure=_get_continue_on_failure(self.config),
    )
    
    # Iterate over found doctests in the module and yield them as DoctestItem instances
    for test in finder.find(module, module.__name__):
        if test.examples:  # Skip empty doctests and cuda
            yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)

def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], args, **kwargs):
if device not in dispatch_table:
# 如果设备不在 dispatch_table 中,使用默认函数处理
return dispatch_table["default"](
args, **kwargs)

fn = dispatch_table[device]

# 一些设备无关函数会返回值,需要在用户级别处防止返回 `None`
# 而不是在此处。
if fn is None:
    return None
# 调用相应设备的函数,并传入参数和关键字参数
return fn(*args, **kwargs)

if is_torch_available():
# 设备名称到可调用函数的映射,以支持设备无关测试。
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
# 设备名称到函数的映射,用于清空缓存。
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
# 设备名称到函数的映射,返回设备上的设备数量。
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}

def backend_manual_seed(device: str, seed: int):
# 使用设备无关调度函数,传递设备名称、种子参数以及对应的种子函数映射。
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)

def backend_empty_cache(device: str):
# 使用设备无关调度函数,传递设备名称以及清空缓存函数映射。
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)

def backend_device_count(device: str):
# 使用设备无关调度函数,传递设备名称以及设备数量函数映射。
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)

if is_torch_available():
# 如果启用了 TRANSFORMERS_TEST_DEVICE_SPEC,我们需要将额外的条目导入到设备到函数映射中。
pass
# 检查环境变量中是否存在名为 TRANSFORMERS_TEST_DEVICE_SPEC 的变量
if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
# 获取环境变量中 TRANSFORMERS_TEST_DEVICE_SPEC 对应的路径
device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
# 检查路径是否指向一个存在的文件,若不存在则抛出异常
if not Path(device_spec_path).is_file():
raise ValueError(
f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}"
)

    # 尝试截取文件名后缀以供后续导入,同时验证文件是否为 Python 文件
    try:
        import_name = device_spec_path[: device_spec_path.index(".py")]
    except ValueError as e:
        raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e

    # 导入指定名称的模块
    device_spec_module = importlib.import_module(import_name)

    # 检查导入的模块是否包含 `DEVICE_NAME` 属性,若不存在则抛出异常
    try:
        device_name = device_spec_module.DEVICE_NAME
    except AttributeError as e:
        raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e

    # 如果环境变量 `TRANSFORMERS_TEST_DEVICE` 存在且其值与设备名称不匹配,则抛出异常
    if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
        msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
        msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
        raise ValueError(msg)

    # 更新 `torch_device` 为从设备规范文件中获取的设备名称
    torch_device = device_name

    # 定义一个函数,从设备规范文件中更新函数映射
    def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
        try:
            # 尝试直接导入指定的函数
            spec_fn = getattr(device_spec_module, attribute_name)
            device_fn_dict[torch_device] = spec_fn
        except AttributeError as e:
            # 如果函数不存在,并且没有默认函数,则抛出异常
            if "default" not in device_fn_dict:
                raise AttributeError(
                    f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
                ) from e

    # 为每个 `BACKEND_*` 字典调用 `update_mapping_from_spec` 函数,更新函数映射
    update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
    update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
    update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
posted @ 2024-07-01 10:55  绝不原创的飞龙  阅读(107)  评论(0编辑  收藏  举报