Transformers-源码解析-六-

Transformers 源码解析(六)

.\image_utils.py

# 导入必要的库和模块
import base64  # 用于 base64 编解码
import os  # 系统操作相关功能
from io import BytesIO  # 提供字节流操作
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union  # 类型提示相关模块

import numpy as np  # 数组操作库
import requests  # 发送 HTTP 请求的库
from packaging import version  # 版本管理相关功能

from .utils import (  # 导入自定义工具函数
    ExplicitEnum,
    is_jax_tensor,
    is_tf_tensor,
    is_torch_available,
    is_torch_tensor,
    is_vision_available,
    logging,
    requires_backends,
    to_numpy,
)
from .utils.constants import (  # 导入常量
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
)

# 如果视觉库可用,导入图像处理相关库
if is_vision_available():
    import PIL.Image  # Python Imaging Library,用于图像处理
    import PIL.ImageOps  # PIL 的图像处理操作

    # 根据 PIL 版本选择不同的图像重采样方法
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PILImageResampling = PIL.Image.Resampling
    else:
        PILImageResampling = PIL.Image

# 如果在类型检查模式下,检查是否有 Torch 可用,若可用则导入 Torch
if TYPE_CHECKING:
    if is_torch_available():
        import torch  # PyTorch 深度学习库

# 获取日志记录器
logger = logging.get_logger(__name__)

# 定义图像输入类型,可以是 PIL 图像、numpy 数组、Torch 张量的列表
ImageInput = Union[
    "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
]  # noqa


class ChannelDimension(ExplicitEnum):
    FIRST = "channels_first"  # 通道维度在前
    LAST = "channels_last"  # 通道维度在后


class AnnotationFormat(ExplicitEnum):
    COCO_DETECTION = "coco_detection"  # COCO 检测注释格式
    COCO_PANOPTIC = "coco_panoptic"  # COCO 全景注释格式


# Note: `AnnotionFormat` is a deprecated misspelling of `AnnotationFormat`, kept only for backward
# compatibility (see the deprecation warning emitted in `validate_annotations` below).
class AnnotionFormat(ExplicitEnum):
    COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value  # COCO 检测注释格式
    COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value  # COCO 全景注释格式


AnnotationType = Dict[str, Union[int, str, List[Dict]]]  # 注释类型,字典形式


def is_pil_image(img):
    return is_vision_available() and isinstance(img, PIL.Image.Image)


def is_valid_image(img):
    return (
        (is_vision_available() and isinstance(img, PIL.Image.Image))  # 图像是 PIL 图像
        or isinstance(img, np.ndarray)  # 图像是 numpy 数组
        or is_torch_tensor(img)  # 图像是 Torch 张量
        or is_tf_tensor(img)  # 图像是 TensorFlow 张量
        or is_jax_tensor(img)  # 图像是 JAX 张量
    )


def valid_images(imgs):
    # 如果是图像列表或元组,则检查每个图像是否有效
    if isinstance(imgs, (list, tuple)):
        for img in imgs:
            if not valid_images(img):
                return False
    # 如果不是图像列表或元组,则检查单个图像或批量张量是否有效
    elif not is_valid_image(imgs):
        return False
    return True


def is_batched(img):
    if isinstance(img, (list, tuple)):
        return is_valid_image(img[0])  # 如果是列表或元组,且第一个元素是有效图像,则认为是批量数据
    return False
# 检查图像是否已经被重新缩放到 [0, 1] 范围内
def is_scaled_image(image: np.ndarray) -> bool:
    if image.dtype == np.uint8:
        return False

    # 可能图像的像素值在 [0, 255] 范围内,但是数据类型是浮点型
    return np.min(image) >= 0 and np.max(image) <= 1


# 确保输入是一个图像列表。如果输入是单个图像,将其转换为长度为 1 的列表。
# 如果输入是图像批次,将其转换为图像列表。
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    if is_batched(images):
        return images

    # 如果输入是单个 PIL 图像,则创建长度为 1 的列表
    if isinstance(images, PIL.Image.Image):
        # PIL 图像永远不会是批次
        return [images]

    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            # 图像批次
            images = list(images)
        elif images.ndim == expected_ndims:
            # 单个图像
            images = [images]
        else:
            raise ValueError(
                f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
                f" {images.ndim} dimensions."
            )
        return images
    raise ValueError(
        "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
        f"jax.ndarray, but got {type(images)}."
    )
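
A minimal usage sketch of `make_list_of_images` (NumPy arrays with illustrative shapes; the assertions mirror the branches above):

import numpy as np
from transformers.image_utils import make_list_of_images

single = np.zeros((3, 224, 224), dtype=np.uint8)    # one 3-D image (expected_ndims == 3)
batch = np.zeros((8, 3, 224, 224), dtype=np.uint8)  # a 4-D batch of 8 images

assert len(make_list_of_images(single)) == 1  # a single image is wrapped in a list of length 1
assert len(make_list_of_images(batch)) == 8   # a batch is split into a list of 3-D arrays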


# 将输入图像转换为 numpy 数组
def to_numpy_array(img) -> np.ndarray:
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    if is_vision_available() and isinstance(img, PIL.Image.Image):
        return np.array(img)
    return to_numpy(img)


# 推断图像的通道维度格式
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.
    """
    num_channels = num_channels if num_channels is not None else (1, 3)
    num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
    # 检查图像数组的指定维度是否在给定的通道数列表中
    if image.shape[first_dim] in num_channels:
        # 如果第一维度的大小存在于通道数列表中,则返回首维度作为通道维度
        return ChannelDimension.FIRST
    elif image.shape[last_dim] in num_channels:
        # 如果最后一维度的大小存在于通道数列表中,则返回末尾维度作为通道维度
        return ChannelDimension.LAST
    # 如果未能确定通道维度的格式,则引发值错误异常
    raise ValueError("Unable to infer channel dimension format")
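
A short sketch of how the inference behaves for unambiguous shapes (the shapes are illustrative):

import numpy as np
from transformers.image_utils import ChannelDimension, infer_channel_dimension_format

chw = np.zeros((3, 32, 64))  # channels-first layout
hwc = np.zeros((32, 64, 3))  # channels-last layout

assert infer_channel_dimension_format(chw) == ChannelDimension.FIRST
assert infer_channel_dimension_format(hwc) == ChannelDimension.LAST
# For ambiguous shapes such as (3, 32, 3) the first matching axis wins (FIRST);
# pass `num_channels` explicitly to disambiguate.
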
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image.
    """
    # 如果未指定数据格式,从图像推断通道维度格式
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    # 如果数据格式为第一维度优先,则返回倒数第三维度的索引
    if input_data_format == ChannelDimension.FIRST:
        return image.ndim - 3
    # 如果数据格式为最后一维度优先,则返回倒数第一维度的索引
    elif input_data_format == ChannelDimension.LAST:
        return image.ndim - 1
    # 抛出异常,不支持的数据格式
    raise ValueError(f"Unsupported data format: {input_data_format}")


def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.
    """
    # 如果未指定通道维度,从图像推断通道维度格式
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    # 如果通道维度为第一维度优先,则返回倒数第二和倒数第一维度的尺寸
    if channel_dim == ChannelDimension.FIRST:
        return image.shape[-2], image.shape[-1]
    # 如果通道维度为最后一维度优先,则返回倒数第三和倒数第二维度的尺寸
    elif channel_dim == ChannelDimension.LAST:
        return image.shape[-3], image.shape[-2]
    # 抛出异常,不支持的数据格式
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")


def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    Checks if the given annotation is a valid COCO detection annotation.

    Args:
        annotation (`Dict[str, Union[List, Tuple]]`):
            The annotation dictionary to validate.

    Returns:
        `True` if the annotation is valid, `False` otherwise.
    """
    # 检查注释是否为有效的 COCO 检测注释
    if (
        isinstance(annotation, dict)
        and "image_id" in annotation
        and "annotations" in annotation
        and isinstance(annotation["annotations"], (list, tuple))
        and (
            # 一个图像可能没有注释
            len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
        )
    ):
        return True
    return False


def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    Checks if the given annotation is a valid COCO panoptic segmentation annotation.

    Args:
        annotation (`Dict[str, Union[List, Tuple]]`):
            The annotation dictionary to validate.

    Returns:
        `True` if the annotation is valid, `False` otherwise.
    """
    # 检查注释是否为有效的 COCO 全景分割注释
    if (
        isinstance(annotation, dict)
        and "image_id" in annotation
        and "segments_info" in annotation
        and "file_name" in annotation
        and isinstance(annotation["segments_info"], (list, tuple))
        and (
            # 一个图像可能没有分割信息
            len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
        )
    ):
        return True
    return False


def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """
    Checks if all annotations in the given iterable are valid COCO detection annotations.

    Args:
        annotations (`Iterable[Dict[str, Union[List, Tuple]]]`):
            The iterable of annotations to validate.

    Returns:
        `True` if all annotations are valid, `False` otherwise.
    """
    # 检查给定可迭代对象中的所有注释是否都是有效的 COCO 检测注释
    return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
# 验证所有 COCO Panoptic 注释的有效性
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    # 使用 `is_valid_annotation_coco_panoptic` 函数检查每个注释项是否有效,全部有效则返回 True
    return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
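
For reference, minimal annotation dictionaries that satisfy the two per-image validators above might look like this (field values are illustrative, not taken from a real dataset):

from transformers.image_utils import (
    is_valid_annotation_coco_detection,
    is_valid_annotation_coco_panoptic,
)

detection_ann = {
    "image_id": 42,
    "annotations": [{"bbox": [10, 20, 30, 40], "category_id": 1, "area": 1200}],
}
panoptic_ann = {
    "image_id": 42,
    "file_name": "000000000042.png",
    "segments_info": [{"id": 1, "category_id": 7, "area": 5000}],
}

assert is_valid_annotation_coco_detection(detection_ann)
assert is_valid_annotation_coco_panoptic(panoptic_ann)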


def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    将 `image` 加载为 PIL 图像。

    Args:
        image (`str` or `PIL.Image.Image`):
            要转换为 PIL 图像格式的图像。
        timeout (`float`, *optional*):
            URL 请求的超时值(秒)。

    Returns:
        `PIL.Image.Image`: 一个 PIL 图像。
    """
    # 确保加载图像所需的后端库已加载
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            # 如果图像是通过 HTTP 或 HTTPS 协议访问的 URL,则使用 `requests` 获取图像流,并打开为 PIL 图像
            image = PIL.Image.open(requests.get(image, stream=True, timeout=timeout).raw)
        elif os.path.isfile(image):
            # 如果图像路径是一个文件,则直接打开为 PIL 图像
            image = PIL.Image.open(image)
        else:
            if image.startswith("data:image/"):
                # 如果图像以 data:image/ 开头,则取出 base64 编码的部分
                image = image.split(",")[1]

            # 尝试作为 base64 字符串加载图像
            try:
                b64 = base64.b64decode(image, validate=True)
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                )
    elif isinstance(image, PIL.Image.Image):
        # 如果图像已经是 PIL 图像,则保持不变
        image = image
    else:
        raise ValueError(
            "Incorrect format used for image. Should be an url linked to an image, a base64 string, a local path, or a PIL image."
        )
    # 根据 EXIF 信息对图像进行自动旋转
    image = PIL.ImageOps.exif_transpose(image)
    # 将图像转换为 RGB 模式(如果不是的话)
    image = image.convert("RGB")
    return image
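
A usage sketch for `load_image`; the URL below is the COCO sample image commonly used in the transformers documentation and is only illustrative:

from transformers.image_utils import load_image

image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg", timeout=10.0)
print(image.mode, image.size)  # always "RGB" after the convert() call above

# Local paths and base64-encoded strings go through the other branches:
# image = load_image("/path/to/local/image.png")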


def validate_preprocess_arguments(
    do_rescale: Optional[bool] = None,
    rescale_factor: Optional[float] = None,
    do_normalize: Optional[bool] = None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    do_pad: Optional[bool] = None,
    size_divisibility: Optional[int] = None,
    do_center_crop: Optional[bool] = None,
    crop_size: Optional[Dict[str, int]] = None,
    do_resize: Optional[bool] = None,
    size: Optional[Dict[str, int]] = None,
    resample: Optional["PILImageResampling"] = None,
):
    """
    检查 `ImageProcessor` 的 `preprocess` 方法中常用参数的有效性。
    如果发现参数不兼容,则抛出 `ValueError` 异常。
    许多不兼容性是与模型相关的。`do_pad` 有时需要 `size_divisor`,有时需要 `size_divisibility`,有时需要 `size`。
    新增的模型和处理器应尽量遵循现有参数的使用规则。

    """
    # 如果需要进行重新缩放,并且未指定缩放因子,则抛出数值错误异常
    if do_rescale and rescale_factor is None:
        raise ValueError("rescale_factor must be specified if do_rescale is True.")

    # 如果需要进行填充,并且未指定尺寸可被整除的值,则抛出数值错误异常
    # 在这里,size_divisibility可能被作为size的值传递
    if do_pad and size_divisibility is None:
        raise ValueError(
            "Depending on model, size_divisibility, size_divisor, pad_size or size must be specified if do_pad is True."
        )

    # 如果需要进行归一化,并且未指定图像均值和标准差,则抛出数值错误异常
    if do_normalize and (image_mean is None or image_std is None):
        raise ValueError("image_mean and image_std must both be specified if do_normalize is True.")

    # 如果需要进行中心裁剪,并且未指定裁剪尺寸,则抛出数值错误异常
    if do_center_crop and crop_size is None:
        raise ValueError("crop_size must be specified if do_center_crop is True.")

    # 如果需要进行调整大小,并且未指定大小或重采样方法,则抛出数值错误异常
    if do_resize and (size is None or resample is None):
        raise ValueError("size and resample must be specified if do_resize is True.")
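
A quick sketch of how the validator is typically exercised; consistent argument combinations pass silently, inconsistent ones raise:

from transformers.image_utils import validate_preprocess_arguments

validate_preprocess_arguments(do_rescale=True, rescale_factor=1 / 255)  # OK

try:
    validate_preprocess_arguments(do_normalize=True, image_mean=None, image_std=None)
except ValueError as err:
    print(err)  # image_mean and image_std must both be specified if do_normalize is True.
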
# 在未来,如果我们有了 TensorFlow 模型,可以在这里添加 TF 的实现。
class ImageFeatureExtractionMixin:
    """
    包含用于准备图像特征的工具函数的 Mixin。
    """

    def _ensure_format_supported(self, image):
        """
        确保图像格式受支持,如果不受支持则引发 ValueError 异常。

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                要检查的图像对象。
        """
        if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
            raise ValueError(
                f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
                "`torch.Tensor` are."
            )

    def to_pil_image(self, image, rescale=None):
        """
        将 `image` 转换为 PIL Image 格式。可选地重新缩放,并在需要时将通道维度放回到最后一个轴。

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                要转换为 PIL Image 格式的图像对象。
            rescale (`bool`, *optional*):
                是否应用缩放因子(使像素值成为介于0到255之间的整数)。如果图像类型是浮点类型,则默认为 `True`。
        """
        self._ensure_format_supported(image)

        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # 如果数组是浮点类型,则默认 rescale 为 True。
                rescale = isinstance(image.flat[0], np.floating)
            # 如果通道被移动到第一个维度,我们将其放回到最后。
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def convert_rgb(self, image):
        """
        将 `PIL.Image.Image` 转换为 RGB 格式。

        Args:
            image (`PIL.Image.Image`):
                要转换的图像对象。
        """
        self._ensure_format_supported(image)
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
        """
        缩放 numpy 图像按比例 `scale`。

        Args:
            image (`numpy.ndarray`):
                要缩放的图像数组。
            scale (Union[float, int]):
                缩放因子。

        Returns:
            `numpy.ndarray`: 缩放后的图像数组。
        """
        self._ensure_format_supported(image)
        return image * scale
    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first.
        """
        # 确保传入的图像格式受支持
        self._ensure_format_supported(image)

        # 如果图像是 PIL Image 对象,则转换为 numpy 数组
        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        # 如果图像是 torch Tensor,则转换为 numpy 数组
        if is_torch_tensor(image):
            image = image.numpy()

        # 如果 rescale 未指定,则根据图像的数据类型判断是否需要进行重新缩放
        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        # 如果需要重新缩放,则将图像像素值缩放到 [0, 1] 范围内
        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        # 如果需要将通道维度放在第一维,则进行维度变换
        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image

    def expand_dims(self, image):
        """
        Expands 2-dimensional `image` to 3 dimensions.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand.
        """
        # 确保传入的图像格式受支持
        self._ensure_format_supported(image)

        # 如果图像是 PIL Image 对象,则直接返回,不做任何维度扩展操作
        if isinstance(image, PIL.Image.Image):
            return image

        # 如果图像是 torch Tensor,则在第0维上增加一个维度
        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            # 如果图像是 numpy 数组,则在第0维上增加一个维度
            image = np.expand_dims(image, axis=0)
        return image
    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
        if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.
        """
        # Ensure that the image format is supported for normalization
        self._ensure_format_supported(image)

        # Convert PIL Image to NumPy array and optionally rescale if required
        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image, rescale=True)
        # If image is not PIL, check if rescaling is needed and handle accordingly
        elif rescale:
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Ensure mean and std are in the correct format based on image type
        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                std = torch.tensor(std)

        # Normalize the image based on its dimensions and channel structure
        if image.ndim == 3 and image.shape[0] in [1, 3]:  # RGB or grayscale image
            return (image - mean[:, None, None]) / std[:, None, None]
        else:  # Handle other image types
            return (image - mean) / std

    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
        `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
                be first.
        """
        # Ensure that the image format is supported for channel flipping
        self._ensure_format_supported(image)

        # Convert PIL Image to NumPy array for manipulation
        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        # Reverse the order of color channels (RGB <-> BGR)
        return image[::-1, :, :]
    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
        counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
                rotating.
            angle (float or int):
                The rotation angle in degrees. Positive angles are counter-clockwise.
            resample (int, optional):
                An optional resampling filter. Default is `PIL.Image.NEAREST`.
            expand (bool or int, optional):
                Optional expansion flag. If true, the output image size is expanded to contain the entire rotated image.
                If integer, it specifies the desired size of the output image (tuple). Default is 0.
            center (tuple of int, optional):
                Optional center of rotation. Default is None, which means the center is calculated as the center of the image.
            translate (tuple of int, optional):
                Optional translation offset. Default is None.
            fillcolor (tuple or int, optional):
                Optional background color given as a single integer value or a tuple of three integers.

        Returns:
            `PIL.Image.Image`: A rotated `PIL.Image.Image` instance.

        """
        # 如果未指定 resample 参数,则使用默认的 NEAREST 模式
        resample = resample if resample is not None else PIL.Image.NEAREST

        # 确保图像格式受支持,调用对象内部方法进行检查
        self._ensure_format_supported(image)

        # 如果输入的 image 不是 PIL.Image.Image 对象,则将其转换为 PIL.Image.Image 对象
        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        # 调用 PIL 库中的 rotate 方法进行图像旋转,并返回旋转后的图像副本
        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )
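
To show how these mixin helpers chain together, here is a minimal sketch that converts a PIL image to a channel-first float array and normalizes it with the ImageNet statistics imported at the top of this file (the 224x224 size is arbitrary):

import PIL.Image
from transformers.image_utils import ImageFeatureExtractionMixin
from transformers.utils.constants import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD

mixin = ImageFeatureExtractionMixin()
image = PIL.Image.new("RGB", (224, 224), color=(128, 64, 32))

array = mixin.to_numpy_array(image)  # float32, shape (3, 224, 224), values in [0, 1]
normed = mixin.normalize(array, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD)
print(normed.shape, normed.dtype)
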
# 根据给定的注释格式升级注释格式对象
def promote_annotation_format(annotation_format: Union[AnnotionFormat, AnnotationFormat]) -> AnnotationFormat:
    # 当 `AnnotionFormat` 完全废弃后,此行代码可以移除
    return AnnotationFormat(annotation_format.value)


# 验证注释的有效性
def validate_annotations(
    annotation_format: AnnotationFormat,
    supported_annotation_formats: Tuple[AnnotationFormat, ...],
    annotations: List[Dict],
) -> None:
    # 如果注释格式是旧的 `AnnotionFormat` 类型,则发出警告,并升级为 `AnnotationFormat`
    if isinstance(annotation_format, AnnotionFormat):
        logger.warning_once(
            f"`{annotation_format.__class__.__name__}` is deprecated and will be removed in v4.38. "
            f"Please use `{AnnotationFormat.__name__}` instead."
        )
        annotation_format = promote_annotation_format(annotation_format)

    # 检查注释格式是否在支持的注释格式列表中
    if annotation_format not in supported_annotation_formats:
        raise ValueError(f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}")

    # 如果注释格式为 `AnnotationFormat.COCO_DETECTION`,则验证 COCO 检测注释的有效性
    if annotation_format is AnnotationFormat.COCO_DETECTION:
        if not valid_coco_detection_annotations(annotations):
            raise ValueError(
                "Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
                "being a list of annotations in the COCO format."
            )

    # 如果注释格式为 `AnnotationFormat.COCO_PANOPTIC`,则验证 COCO 全景注释的有效性
    if annotation_format is AnnotationFormat.COCO_PANOPTIC:
        if not valid_coco_panoptic_annotations(annotations):
            raise ValueError(
                "Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
                "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
                "the latter being a list of annotations in the COCO format."
            )


# 验证关键字参数的有效性,并发出警告对于未使用或无法识别的关键字参数
def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
    unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
    if unused_keys:
        unused_key_str = ", ".join(unused_keys)
        # TODO: 这里是否应该发出警告而不是仅仅记录日志?
        logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
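
A small sketch of the kwargs check; unknown keys are only logged as a warning, nothing is raised:

from transformers.image_utils import validate_kwargs

validate_kwargs(
    valid_processor_keys=["do_resize", "size", "do_rescale"],
    captured_kwargs=["do_resize", "sharpness"],  # "sharpness" triggers the warning above
)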

.\integrations\aqlm.py

# 版权声明和许可信息
# 版权所有 2024 年 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证 2.0 版本(“许可证”)获得许可;
# 除非符合许可证的规定,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发软件,
# 无论是明示的还是暗示的条件,包括但不限于适销性和特定用途的适用性。
# 有关详细信息,请参阅许可证。
"AQLM (Additive Quantization of Language Model) integration file"

# 导入所需的模块和函数
from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available

# 如果 torch 可用,则导入 torch.nn 模块
if is_torch_available():
    import torch.nn as nn

# 替换模型中的线性层为 AQLM 量化层的公共方法
def replace_with_aqlm_linear(
    model,
    quantization_config=None,
    linear_weights_not_to_quantize=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`AqlmConfig`):
            The quantization config object that contains the quantization parameters.
        linear_weights_not_to_quantize (`list[str]`, *optional*):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
            converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """

    # 检查是否安装了 AQLM
    if not is_aqlm_available():
        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")

    # 检查是否安装了 Accelerate
    if not is_accelerate_available():
        raise ValueError("AQLM requires Accelerate to be installed: `pip install accelerate`")

    # 如果未提供 linear_weights_not_to_quantize 参数,则初始化为空列表
    if linear_weights_not_to_quantize is None:
        linear_weights_not_to_quantize = []

    # 导入所需的函数和类
    from accelerate import init_empty_weights
    from aqlm import QuantizedLinear
    # 遍历模型的每个子模块的名称和模块本身
    for name, module in model.named_children():
        # 如果当前键名为 None,则初始化为空列表
        if current_key_name is None:
            current_key_name = []
        # 将当前模块名称添加到当前键名列表中
        current_key_name.append(name)

        # 如果当前模块是线性层(nn.Linear)
        if isinstance(module, nn.Linear):
            # 构造当前模块权重的完整路径,以便检查是否不需要量化
            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
                # 使用 init_empty_weights 上下文管理器初始化空权重
                with init_empty_weights():
                    # 获取输入和输出特征数
                    in_features = module.in_features
                    out_features = module.out_features

                    # 替换当前模块为量化后的 QuantizedLinear 模块
                    model._modules[name] = QuantizedLinear(
                        in_features,
                        out_features,
                        bias=module.bias is not None,
                        in_group_size=quantization_config.in_group_size,
                        out_group_size=quantization_config.out_group_size,
                        num_codebooks=quantization_config.num_codebooks,
                        nbits_per_codebook=quantization_config.nbits_per_codebook,
                    )
                    # 标记模块已被替换
                    has_been_replaced = True

                    # 存储原始模块类以备稍后可能需要对权重进行转置
                    model._modules[name].source_cls = type(module)
                    # 将 requires_grad 设置为 False,避免意外错误
                    model._modules[name].requires_grad_(False)

        # 如果当前模块有子模块
        if len(list(module.children())) > 0:
            # 递归调用 replace_with_aqlm_linear 函数替换子模块中的线性层
            _, has_been_replaced = replace_with_aqlm_linear(
                module,
                quantization_config=quantization_config,
                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        
        # 递归结束后,移除当前键名列表中的最后一个键名
        current_key_name.pop(-1)

    # 返回替换后的模型及替换状态标志
    return model, has_been_replaced
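
A hedged usage sketch (requires `aqlm` and `accelerate` to be installed, otherwise the checks above raise). The `SimpleNamespace` below is a stand-in for the real `AqlmConfig`; it only needs to expose the four attributes read by the function:

from types import SimpleNamespace

import torch.nn as nn

from transformers.integrations.aqlm import replace_with_aqlm_linear

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
quantization_config = SimpleNamespace(in_group_size=8, out_group_size=1, num_codebooks=1, nbits_per_codebook=16)

model, replaced = replace_with_aqlm_linear(
    model,
    quantization_config=quantization_config,
    linear_weights_not_to_quantize=["2.weight"],  # keep the last Linear ("2") in full precision
)
print(replaced)  # True if at least one nn.Linear was swapped for aqlm.QuantizedLinear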

.\integrations\awq.py

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AWQ (Activation aware Weight Quantization) integration file"
from ..activations import ACT2FN  # 导入激活函数映射
from ..modeling_utils import PreTrainedModel  # 导入预训练模型工具函数
from ..utils import is_auto_awq_available, is_torch_available  # 导入 AWQ 自动可用性检查和 Torch 可用性检查
from ..utils.quantization_config import (  # 导入量化配置
    AwqBackendPackingMethod,
    AwqConfig,
    AWQLinearVersion,
    ExllamaVersion,
)

if is_torch_available():  # 如果 Torch 可用
    import torch  # 导入 PyTorch
    import torch.nn as nn  # 导入 PyTorch 神经网络模块

# AWQ_FUSED_MAPPINGS 定义了不同模型类型的层映射字典,用于 AWQ 线性层替换
AWQ_FUSED_MAPPINGS = {
    "mistral": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
    },
    "mixtral": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["w1", "w3", "w2"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
        "rope_theta": 1000000.0,
    },
    "llama": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
    },
    "llava": {
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
    },
}

def replace_with_awq_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
    current_key_name=None,
    has_been_replaced=False,
) -> bool:
    """
    Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    During the module replacement, we also infer the backend to use through the `quantization_config` object.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`AwqConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    # 如果未指定不转换的模块列表,则初始化为空列表
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    # 获取量化配置中的后端信息
    backend = quantization_config.backend

    # 检查是否存在自动 AWQ 支持
    if not is_auto_awq_available():
        raise ValueError(
            "AWQ(`autoawq` 或 `llmawq`)不可用。请使用 `pip install autoawq` 安装或查看安装指南:https://github.com/mit-han-lab/llm-awq"
        )

    # 根据量化配置选择合适的量化线性层类
    if backend == AwqBackendPackingMethod.AUTOAWQ:
        if quantization_config.version == AWQLinearVersion.GEMM:
            # 导入 GEMM 版本的量化线性层类
            from awq.modules.linear.gemm import WQLinear_GEMM

            target_cls = WQLinear_GEMM
        elif quantization_config.version == AWQLinearVersion.GEMV:
            # 导入 GEMV 版本的量化线性层类
            from awq.modules.linear.gemv import WQLinear_GEMV

            target_cls = WQLinear_GEMV
        elif quantization_config.version == AWQLinearVersion.EXLLAMA:
            if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
                # 导入 Exllama 版本一的量化线性层类
                from awq.modules.linear.exllama import WQLinear_Exllama

                target_cls = WQLinear_Exllama
            elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
                # 导入 Exllama 版本二的量化线性层类
                from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2

                target_cls = WQLinear_ExllamaV2
            else:
                raise ValueError(f"未知的 Exllama 版本: {quantization_config.exllama_config['version']}")
        else:
            raise ValueError(f"未知的 AWQ 版本: {quantization_config.version}")
    else:
        # 若未选择 AUTOAWQ 后端,则使用默认的量化线性层类
        from awq.quantize.qmodule import WQLinear

        target_cls = WQLinear
    # 遍历模型的所有子模块,获取每个子模块的名称和实例
    for name, module in model.named_children():
        # 如果当前键名为 None,则初始化为一个空列表
        if current_key_name is None:
            current_key_name = []
        # 将当前子模块名称添加到当前键名列表中
        current_key_name.append(name)

        # 如果当前模块是 nn.Linear 类型,并且其名称不在不转换的模块列表中
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # 检查当前键名组合不在不转换模块列表中的任何键中
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                # 获取线性层的输入和输出特征数
                in_features = module.in_features
                out_features = module.out_features

                # 将模型中的当前线性层替换为目标类的实例化对象
                model._modules[name] = target_cls(
                    w_bit=quantization_config.bits,
                    group_size=quantization_config.group_size,
                    in_features=in_features,
                    out_features=out_features,
                    bias=module.bias is not None,
                    dev=module.weight.device,
                )
                # 标记该模块已被替换
                has_been_replaced = True

                # 强制设置该模块的 requires_grad 为 False,以避免意外错误
                model._modules[name].requires_grad_(False)

        # 如果当前模块还有子模块,则递归调用此函数来替换子模块
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_awq_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
            )
        
        # 移除当前键名列表中的最后一个键名,为递归调用做准备
        current_key_name.pop(-1)
    
    # 返回替换后的模型和是否有模块被替换的标志
    return model, has_been_replaced
# 返回模型中需要融合的模块映射,根据给定的量化配置和模型
def get_modules_to_fuse(model, quantization_config):
    """
    Returns the fusing mapping given the quantization config and the model

    Args:
        model (`~PreTrainedModel`):
            The model to fuse - note this model should have been converted into AWQ format beforehand.
        quantization_config (`~transformers.quantization_config.AWQConfig`):
            The quantization configuration to use.
    """
    # 如果模型不是PreTrainedModel的实例,则抛出错误
    if not isinstance(model, PreTrainedModel):
        raise ValueError(f"The model should be an instance of `PreTrainedModel`, got {model.__class__.__name__}")

    # 总是默认使用 `quantization_config.modules_to_fuse`
    if quantization_config.modules_to_fuse is not None:
        current_fused_mapping = quantization_config.modules_to_fuse
        current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
    # 如果 `quantization_config.modules_to_fuse` 为None,则根据模型类型在 `AWQ_FUSED_MAPPINGS` 中查找对应的映射
    elif model.config.model_type in AWQ_FUSED_MAPPINGS:
        current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

        # 处理多模态模型的情况(如Llava),区分 `model.config` 和 `model.config.text_config`
        if not hasattr(model.config, "text_config"):
            config = model.config
        else:
            config = model.config.text_config

        # 单独处理 `hidden_size`、`num_attention_heads` 和 `num_key_value_heads` 字段
        hidden_size = config.hidden_size
        num_attention_heads = config.num_attention_heads
        num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)

        # 填充 `current_fused_mapping` 中的预期值
        current_fused_mapping["hidden_size"] = hidden_size
        current_fused_mapping["num_attention_heads"] = num_attention_heads
        current_fused_mapping["num_key_value_heads"] = num_key_value_heads
        current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
    # 如果都没有找到合适的融合映射,则抛出错误
    else:
        raise ValueError(
            "Fusing mapping not found either on the quantization config or the supported `AWQ_FUSED_MAPPINGS`. Please pass a `fused_mapping` argument"
            " in the `quantization_config` or raise an issue on transformers https://github.com/huggingface/transformers to add its support."
        )
    return current_fused_mapping
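
When a model type is missing from `AWQ_FUSED_MAPPINGS`, the mapping can be supplied on the quantization config instead. A sketch of such a config (the values follow the `mistral` entry above and are illustrative):

from transformers import AwqConfig

quantization_config = AwqConfig(
    bits=4,
    do_fuse=True,
    fuse_max_seq_len=2048,
    modules_to_fuse={
        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
        "mlp": ["gate_proj", "up_proj", "down_proj"],
        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
        "use_alibi": False,
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
    },
)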


# 可选地融合模型中的一些模块以加速推断
def fuse_awq_modules(model, quantization_config):
    """
    Optionally fuse some modules in the model to speedup inference.

    Args:
        model (`~PreTrainedModel`):
            The model to fuse - note this model should have been converted into AWQ format beforehand.
        quantization_config (`Union[AwqConfig, dict]`):
            The quantization configuration to use.
    """
    # 如果 `quantization_config` 是字典,则将其转换为 AwqConfig 对象以便获取 `backend` 等字段
    # 否则这些字段将不可用
    # https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495
    if isinstance(quantization_config, dict):
        quantization_config = AwqConfig.from_dict(quantization_config)
    # 获取量化配置中的后端信息
    backend = quantization_config.backend

    # 获取需要融合的模块列表
    modules_to_fuse = get_modules_to_fuse(model, quantization_config)
    
    # 获取不需要转换的模块列表(如果有的话)
    modules_to_not_convert = getattr(quantization_config, "modules_to_not_convert", None)

    # 检查是否使用自动 AWQ 后端
    if backend == AwqBackendPackingMethod.AUTOAWQ:
        # 导入 AWQ 后端的融合模块
        from awq.modules.fused.attn import QuantAttentionFused
        from awq.modules.fused.mlp import QuantFusedMLP
        from awq.modules.fused.norm import FasterTransformerRMSNorm
    else:
        # 抛出数值错误,只支持 AutoAWQ 后端的融合
        raise ValueError("Fusing is only supported for the AutoAWQ backend")

    # 遍历模型的所有命名模块
    for name, module in model.named_modules():
        # 如果存在不需要转换的模块列表
        if modules_to_not_convert is not None:
            # 检查当前模块名是否在不需要转换的模块列表中的任意一个
            if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
                # 如果是,则跳过当前模块的处理
                continue

        # 替换模型中的 LayerNorm 层
        _fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)

        # 替换模型中的 MLP 层
        _fuse_awq_mlp(model, name, modules_to_fuse["mlp"], module, QuantFusedMLP)

        # 替换模型中的 Attention 层
        _fuse_awq_attention_layers(model, module, modules_to_fuse, name, QuantAttentionFused)

    # 返回融合后的模型
    return model
# 融合 LayerNorm 层到目标类中,使用自动 AWQ(Automatic Weight Quantization)
def _fuse_awq_layernorm(fuse_module_names, module, target_cls):
    """
    Fuse the LayerNorm layers into a target class using autoawq

    Args:
        fuse_module_names (`List[str]`):
            The list of module names to fuse
        module (`nn.Module`):
            The pytorch parent module that has layernorm modules to fuse
        target_cls (`~autoawq.FasterTransformerRMSNorm`):
            The `FasterTransformerRMSNorm` class as it only supports that class
            for now.
    """
    # 遍历要融合的模块名列表
    for module_name in fuse_module_names:
        # 检查父模块是否具有指定的模块名
        if hasattr(module, module_name):
            # 获取旧的 LayerNorm 模块
            old_module = getattr(module, module_name)
            # 创建新的 target_cls 类实例来替换旧模块
            module._modules[module_name] = target_cls(
                old_module.weight,
                old_module.variance_epsilon,
            ).to(old_module.weight.device)
            # 删除旧模块,释放内存
            del old_module


# 融合 MLP 层到目标类中,使用自动 AWQ
def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_cls):
    """
    Fuse the MLP layers into a target class using autoawq

    Args:
        model (`~PreTrainedModel`):
            The input pretrained model
        current_module_name (`str`):
            The current submodule name
        fuse_module_names (`List[str]`):
            The list of module names to fuse. For the MLP layers it has to be an array
            of length 3 that consists of the 3 MLP layers in the order (gate (dense layer post-attention) / up / down layers)
        module (`nn.Module`):
            The pytorch parent module that has layernorm modules to fuse
        target_cls (`~autoawq.QuantFusedMLP`):
            The `QuantFusedMLP` class as it only supports that class
            for now.
    """
    # 如果没有要融合的模块名,直接返回
    if len(fuse_module_names) == 0:
        return

    # 检查父模块是否具有第一个要融合的模块名
    if hasattr(module, fuse_module_names[0]):
        # 获取三个 MLP 层的引用
        gate_proj = getattr(module, fuse_module_names[0])
        up_proj = getattr(module, fuse_module_names[1])
        down_proj = getattr(module, fuse_module_names[2])

        # 记录 gate_proj 的设备信息
        previous_device = gate_proj.qweight.device

        # 处理模型具有 `text_config` 属性的情况
        hidden_act = (
            model.config.hidden_act
            if not hasattr(model.config, "text_config")
            else model.config.text_config.hidden_act
        )
        # 根据 hidden_act 获取激活函数
        activation_fn = ACT2FN[hidden_act]
        # 创建新的 QuantFusedMLP 实例
        new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

        # 分离当前模块的父子名称
        parent_name, child_name = current_module_name.rsplit(".", 1)
        # 获取父模块
        parent = model.get_submodule(parent_name)
        # 将新模块设置为子模块的属性,并转移到之前的设备上
        setattr(parent, child_name, new_module.to(previous_device))

        # 删除临时变量,释放内存
        del gate_proj, up_proj, down_proj


def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name, target_cls):
    """
    Fuse the Attention layers into a target class using autoawq
    """
    # 导入需要的模块:WQLinear_GEMM 和 WQLinear_GEMV,来自 awq.modules.linear
    from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

    # 如果 modules_to_fuse 字典中的 attention 列表为空,直接返回
    if len(modules_to_fuse["attention"]) == 0:
        return
    # 检查模块是否具有指定名称的属性
    if hasattr(module, modules_to_fuse["attention"][0]):
        # 获取模块中指定注意力组件的引用
        q_proj = getattr(module, modules_to_fuse["attention"][0])

        # 根据不同的注意力组件类型选择相应的线性层类和连接维度
        if isinstance(q_proj, WQLinear_GEMV):
            linear_target_cls = WQLinear_GEMV
            cat_dim = 0
        elif isinstance(q_proj, WQLinear_GEMM):
            linear_target_cls = WQLinear_GEMM
            cat_dim = 1
        else:
            # 如果遇到不支持的 q_proj 类型,则抛出异常
            raise ValueError(f"Unsupported q_proj type: {type(q_proj)}")

        # 记录 q_proj 的设备信息
        previous_device = q_proj.qweight.device

        # 获取其他相关的注意力组件
        k_proj = getattr(module, modules_to_fuse["attention"][1])
        v_proj = getattr(module, modules_to_fuse["attention"][2])
        o_proj = getattr(module, modules_to_fuse["attention"][3])

        # 合并偏置项,如果存在的话
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        # 创建新的量化线性层,整合 QKV 权重、量化零点和缩放因子
        qkv_layer = linear_target_cls(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device,
        )

        # 合并 QKV 层的权重、量化零点和缩放因子
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=cat_dim)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=cat_dim)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=cat_dim)

        # 如果是 WQLinear_GEMV 类型的量化线性层,设置其特有的 split_k_iters 属性
        if isinstance(qkv_layer, WQLinear_GEMV):
            qkv_layer.split_k_iters = q_proj.split_k_iters

        # 设置合并后注意力层的偏置项
        qkv_layer.bias = bias

        # 创建融合后的注意力层,使用指定的参数初始化
        fused_attention_layer = target_cls(
            modules_to_fuse["hidden_size"],
            modules_to_fuse["num_attention_heads"],
            modules_to_fuse["num_key_value_heads"],
            qkv_layer,
            o_proj,
            previous_device,
            modules_to_fuse["max_seq_len"],
            use_alibi=modules_to_fuse["use_alibi"],
            rope_theta=modules_to_fuse.get("rope_theta", 10000.0),  # 设置默认的 rope_theta 值为 10000.0
        )

        # 标记融合后的注意力层是 HF Transformers 的一部分
        fused_attention_layer.is_hf_transformers = True

        # 将融合后的注意力层设置为模型中对应的子模块
        parent_name, child_name = current_module_name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, fused_attention_layer.to(previous_device))

        # 清理不再需要的变量引用,释放内存
        del q_proj, k_proj, v_proj, o_proj
def post_init_awq_exllama_modules(model, exllama_config):
    """
    Runs post init for Exllama layers which performs:
        - Weights unpacking, reordering and repacking
        - Devices scratch space allocation
    """

    # 检查配置中的 Exllama 版本是否为版本一
    if exllama_config["version"] == ExllamaVersion.ONE:
        # 如果是版本一,则导入版本一的初始化函数,并对模型进行处理
        from awq.modules.linear.exllama import exllama_post_init
        model = exllama_post_init(model)
    # 检查配置中的 Exllama 版本是否为版本二
    elif exllama_config["version"] == ExllamaVersion.TWO:
        # 如果是版本二,则导入版本二的初始化函数,并根据配置参数对模型进行处理
        from awq.modules.linear.exllamav2 import exllamav2_post_init
        model = exllamav2_post_init(
            model,
            max_input_len=exllama_config["max_input_len"],
            max_batch_size=exllama_config["max_batch_size"],
        )
    else:
        # 如果配置中的 Exllama 版本既不是一也不是二,则抛出异常
        raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")

    # 返回经过 Exllama 模块处理后的模型
    return model
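
For reference, the config dictionary consumed above might look like this for the Exllama v2 path (values illustrative; `ExllamaVersion` is the enum imported at the top of this file):

from transformers.utils.quantization_config import ExllamaVersion

exllama_config = {
    "version": ExllamaVersion.TWO,
    "max_input_len": 2048,
    "max_batch_size": 8,
}
# model = post_init_awq_exllama_modules(model, exllama_config)  # requires autoawq built with Exllama kernels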

.\integrations\bitsandbytes.py

import importlib.metadata  # 导入元数据模块,用于获取包的版本信息
import warnings  # 导入警告模块,用于处理警告信息
from copy import deepcopy  # 导入深拷贝函数,用于复制对象
from inspect import signature  # 导入签名模块,用于获取函数的参数签名信息

from packaging import version  # 导入版本模块,用于处理版本号

from ..utils import is_accelerate_available, is_bitsandbytes_available, logging  # 导入自定义工具函数和日志模块


if is_bitsandbytes_available():
    import bitsandbytes as bnb  # 如果bitsandbytes可用,导入bitsandbytes库
    import torch  # 导入PyTorch库
    import torch.nn as nn  # 导入PyTorch的神经网络模块

    from ..pytorch_utils import Conv1D  # 导入自定义的Conv1D模块

if is_accelerate_available():
    from accelerate import init_empty_weights  # 如果accelerate可用,导入初始化空权重函数
    from accelerate.utils import find_tied_parameters  # 导入查找绑定参数的函数

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, quantized_stats=None):
    """
    A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
    `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
    function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
    class `Int8Params` from `bitsandbytes`.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device).
        quantized_stats (`dict[str, Any]`, *optional*):
            Dict with items for either 4-bit or 8-bit serialization
    """
    # 如果张量名包含点号,递归访问模块的子模块直到找到张量名
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    # 如果张量名不在参数或缓冲区中,抛出值错误异常
    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers  # 标记张量名是否在缓冲区中
    old_value = getattr(module, tensor_name)  # 获取模块中的旧张量值

    # 如果旧张量值在meta设备上,但目标设备不是meta,且没有提供值,则抛出值错误异常
    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

    prequantized_loading = quantized_stats is not None  # 标记是否为预量化加载

    # 如果是缓冲区或bitsandbytes库不可用,则不是4位或8位量化
    if is_buffer or not is_bitsandbytes_available():
        is_8bit = False
        is_4bit = False
    else:
        # 检查是否是4位参数并且模块参数是4位
        is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
        # 检查是否是8位参数
        is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
    # 检查是否为8位或4位量化模型
    if is_8bit or is_4bit:
        # 获取模块中指定张量名称的参数
        param = module._parameters[tensor_name]
        # 如果参数不在CUDA设备上,则需要进行数据迁移
        if param.device.type != "cuda":
            # 根据值的类型和情况,将旧值转移到指定设备上或者转换为CPU上的张量
            if value is None:
                new_value = old_value.to(device)
            elif isinstance(value, torch.Tensor):
                new_value = value.to("cpu")
            else:
                new_value = torch.tensor(value, device="cpu")

            # 如果模块源类型是Conv1D,并且不是预量化加载情况下,需要转置权重矩阵以支持Conv1D替代nn.Linear的模型
            if issubclass(module.source_cls, Conv1D) and not prequantized_loading:
                new_value = new_value.T

            # 将旧值的属性作为关键字参数传递给新值
            kwargs = old_value.__dict__

            # 检查新值的dtype是否与参数量化状态兼容
            if prequantized_loading != (new_value.dtype in (torch.int8, torch.uint8)):
                raise ValueError(
                    f"Value dtype `{new_value.dtype}` is not compatible with parameter quantization status."
                )

            # 如果是8位量化模型
            if is_8bit:
                # 检查bitsandbytes库版本是否支持int8的序列化
                is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
                    "0.37.2"
                )
                # 如果新值的dtype是int8或uint8且bitsandbytes版本不支持int8序列化,则抛出错误
                if new_value.dtype in (torch.int8, torch.uint8) and not is_8bit_serializable:
                    raise ValueError(
                        "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
                        "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
                    )
                # 使用bitsandbytes库中的Int8Params将新值转换为int8参数,设置不需要梯度,并应用到指定设备上
                new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
                # 如果是预量化加载情况下,将quantized_stats中的SCB属性设置到新值的SCB属性上
                if prequantized_loading:
                    setattr(new_value, "SCB", quantized_stats["SCB"].to(device))
            # 如果是4位量化模型
            elif is_4bit:
                # 如果是预量化加载情况下,检查bitsandbytes库版本是否支持4位的序列化
                if prequantized_loading:
                    is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
                        "0.41.3"
                    )
                    # 如果新值的dtype是int8或uint8且bitsandbytes版本不支持4位序列化,则抛出错误
                    if new_value.dtype in (torch.int8, torch.uint8) and not is_4bit_serializable:
                        raise ValueError(
                            "Detected 4-bit weights but the version of bitsandbytes is not compatible with 4-bit serialization. "
                            "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
                        )
                    # 使用bitsandbytes库中的Params4bit.from_prequantized方法从预量化数据创建4位参数,设置不需要梯度,并应用到指定设备上
                    new_value = bnb.nn.Params4bit.from_prequantized(
                        data=new_value,
                        quantized_stats=quantized_stats,
                        requires_grad=False,
                        device=device,
                        **kwargs,
                    )
                else:
                    # 使用bitsandbytes库中的Params4bit将新值转换为4位参数,设置不需要梯度,并应用到指定设备上
                    new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
            # 将模块中指定张量名称的参数更新为新值
            module._parameters[tensor_name] = new_value
    else:
        # 如果value为None,则将old_value转移到指定设备(device)
        if value is None:
            new_value = old_value.to(device)
        # 如果value是torch.Tensor类型,则将其移动到指定设备(device)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device)
        # 否则,将value转换为torch.tensor,并移动到指定设备(device)
        else:
            new_value = torch.tensor(value, device=device)

        # 如果是缓冲区(buffer),则更新module的_buffers字典
        if is_buffer:
            module._buffers[tensor_name] = new_value
        # 否则,将new_value封装为nn.Parameter,并将其存储在module的_parameters字典中
        else:
            new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
            module._parameters[tensor_name] = new_value
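
A minimal sketch of the non-quantized code path of this helper (no `bitsandbytes` needed for this branch): it behaves like a dotted-path aware, in-place `.to(device)` for a parameter:

import torch
import torch.nn as nn

from transformers.integrations.bitsandbytes import set_module_quantized_tensor_to_device

layer = nn.Linear(4, 4)
new_weight = torch.randn(4, 4)

set_module_quantized_tensor_to_device(layer, "weight", "cpu", value=new_weight)
assert torch.equal(layer.weight.data, new_weight)
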
# 定义一个私有方法,用于递归替换模块的功能。返回替换后的模型和一个布尔值,指示替换是否成功。
def _replace_with_bnb_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    return model, has_been_replaced
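
# 注:上面的函数体在本次摘录中被截断,只保留了签名、docstring 与最终的返回语句。
# 下面给出一个示意性的递归替换草图,仅用于说明思路,并非库中的原始实现;
# init_empty_weights、bnb.nn.Linear8bitLt / Linear4bit 的具体参数以实际安装的 accelerate / bitsandbytes 版本为准。
def _replace_with_bnb_linear_sketch(
    model, modules_to_not_convert, prefix="", quantization_config=None, has_been_replaced=False
):
    import torch.nn as nn
    import bitsandbytes as bnb                  # 假设已正确安装 bitsandbytes
    from accelerate import init_empty_weights   # 在 meta 设备上创建新模块,避免占用显存

    for name, module in model.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        # 只替换 nn.Linear,且其名称不在排除列表(例如 lm_head)中
        if isinstance(module, nn.Linear) and not any(key in full_name for key in modules_to_not_convert):
            with init_empty_weights():
                if quantization_config is not None and getattr(quantization_config, "load_in_8bit", False):
                    new_module = bnb.nn.Linear8bitLt(
                        module.in_features, module.out_features,
                        bias=module.bias is not None, has_fp16_weights=False,
                    )
                else:
                    new_module = bnb.nn.Linear4bit(
                        module.in_features, module.out_features, bias=module.bias is not None,
                    )
            setattr(model, name, new_module)
            has_been_replaced = True
        else:
            # 对非 Linear 的子模块继续递归
            _, has_been_replaced = _replace_with_bnb_linear_sketch(
                module, modules_to_not_convert, full_name, quantization_config, has_been_replaced
            )
    return model, has_been_replaced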


# 定义一个函数,用于将所有 `torch.nn.Linear` 模块替换为 `bnb.nn.Linear8bit` 模块。
# 这样可以实现使用混合 int8 精度运行模型,如论文 `LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale` 所述。
# 在运行此函数之前,请确保已正确安装支持正确 CUDA 版本的 `bitsandbytes` 库。
# `pip install -i https://test.pypi.org/simple/bitsandbytes`
def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    """
    A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
    bitsandbytes`

    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
    be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
    CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
    matrix multiplication into two streams: (1) a systematic feature outlier stream matrix multiplied in fp16
    (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
    predictive degradation is possible for very large models (>=176B parameters).

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
            for numerical stability reasons.
        current_key_name (`List[`str`]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key (part of
            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
            `disk`).
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
    # 调用私有方法 `_replace_with_bnb_linear` 进行实际的替换操作
    model, has_been_replaced = _replace_with_bnb_linear(
        model, modules_to_not_convert, current_key_name, quantization_config
    )

    # 如果没有替换成功,则记录警告信息,提示可能出现了问题
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    # 返回替换后的模型
    return model


# 为了向后兼容而保留的封装函数
# 引发 FutureWarning 警告,提示 `replace_8bit_linear` 将来会被弃用,建议使用 `replace_with_bnb_linear` 替代
def replace_8bit_linear(*args, **kwargs):
    warnings.warn(
        "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
        FutureWarning,
    )
    # 调用 `replace_with_bnb_linear` 函数并返回其结果
    return replace_with_bnb_linear(*args, **kwargs)


# 为了向后兼容性而设立的函数
# 引发 FutureWarning 警告,提示 `set_module_8bit_tensor_to_device` 将来会被弃用,建议使用 `set_module_quantized_tensor_to_device` 替代
def set_module_8bit_tensor_to_device(*args, **kwargs):
    warnings.warn(
        "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
        FutureWarning,
    )
    # 调用 `set_module_quantized_tensor_to_device` 函数并返回其结果
    return set_module_quantized_tensor_to_device(*args, **kwargs)


def get_keys_to_not_convert(model):
    r"""
    获取模块的键列表,用于指定不转换为 int8 的模块。例如对于 CausalLM 模块,
    我们可能希望保持 lm_head 以完整精度,以确保数值稳定性。对于其他架构,
    我们可能希望保持模型的 tied weights。该函数将返回一个不需要转换为 int8 的模块键列表。

    Parameters:
    model (`torch.nn.Module`):
        输入的模型
    """
    # 复制模型并绑定权重,然后检查是否包含绑定的权重
    tied_model = deepcopy(model)  # 这个操作在 `init_empty_weights` 上下文管理器内部不会有额外开销
    tied_model.tie_weights()

    tied_params = find_tied_parameters(tied_model)
    # 兼容 Accelerate < 0.18
    if isinstance(tied_params, dict):
        tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys())
    else:
        tied_keys = sum(tied_params, [])
    has_tied_params = len(tied_keys) > 0

    # 如果没有绑定的权重,我们希望保持 lm_head(output_embedding)以完整精度
    if not has_tied_params:
        output_emb = model.get_output_embeddings()
        if output_emb is not None:
            list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
            return list_last_module

    # 否则,没有绑定的权重,也没有定义输出嵌入,简单地保持最后一个模块以完整精度
    list_modules = list(model.named_parameters())
    list_last_module = [list_modules[-1][0]]
    # 将最后一个模块与绑定的权重一起添加到列表中
    intersection = set(list_last_module) - set(tied_keys)
    list_untouched = list(set(tied_keys)) + list(intersection)

    # 从键中移除 ".weight" 和 ".bias"
    names_to_remove = [".weight", ".bias"]
    filtered_module_names = []
    for name in list_untouched:
        for name_to_remove in names_to_remove:
            if name_to_remove in name:
                name = name.replace(name_to_remove, "")
        filtered_module_names.append(name)

    return filtered_module_names
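
# 示意用法(非库内原文):先计算不应量化的模块键,再交给 replace_with_bnb_linear;
# 实际加载 8bit 模型时这两步由 from_pretrained 内部自动完成,这里仅演示两个函数的配合,模型名为占位示例。
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")
keys_to_skip = get_keys_to_not_convert(model)            # 对绑定权重的 CausalLM,通常包含词嵌入 / lm_head 相关键
model = replace_with_bnb_linear(
    model,
    modules_to_not_convert=keys_to_skip,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)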

.\integrations\deepspeed.py

"""
Integration with Deepspeed
"""
# 引入必要的模块和函数
import copy  # 导入深拷贝函数
import importlib.metadata as importlib_metadata  # 导入元数据模块
import importlib.util  # 导入模块加载工具
import weakref  # 导入弱引用模块
from functools import partialmethod  # 导入partialmethod函数

# 导入依赖版本检查和工具函数
from ..dependency_versions_check import dep_version_check  
from ..utils import is_accelerate_available, is_torch_available, logging  

# 如果torch可用,则导入torch模块
if is_torch_available():
    import torch  

# 获取日志记录器对象
logger = logging.get_logger(__name__)


# 检查是否存在DeepSpeed库
def is_deepspeed_available():
    package_exists = importlib.util.find_spec("deepspeed") is not None

    # 检查确保导入的是DeepSpeed库,而非其他内容,同时尝试获取其版本信息和作者信息进行验证
    if package_exists:
        try:
            _ = importlib_metadata.metadata("deepspeed")
            return True
        except importlib_metadata.PackageNotFoundError:
            return False


# 如果同时安装了accelerate和deepspeed,则导入DeepSpeedConfig类
if is_accelerate_available() and is_deepspeed_available():
    from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
else:
    # 如果accelerate不可用,则继承自dummy `object`,以确保可以导入本文件
    from builtins import object as DeepSpeedConfig


class HfDeepSpeedConfig(DeepSpeedConfig):
    """
    This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.

    A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
    things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
    it's important that this object remains alive while the program is still running.

    [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
    with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
    the DeepSpeed configuration is not modified in any way.

    Args:
        config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.

    """
    def __init__(self, config_file_or_dict):
        # 设置全局的弱引用对象
        set_hf_deepspeed_config(self)
        # 检查加速库 "accelerate" 的依赖版本
        dep_version_check("accelerate")
        # 检查深度加速库 "deepspeed" 的依赖版本
        dep_version_check("deepspeed")
        # 调用父类的初始化方法,传入配置文件或字典参数
        super().__init__(config_file_or_dict)
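
# 示意用法(非库内原文):不通过 Trainer 使用 DeepSpeed ZeRO-3 时,需要在 from_pretrained 之前
# 构造 HfDeepSpeedConfig 并保持其引用存活,这样加载权重时才能检测到 ZeRO-3 并按分片方式初始化参数;
# 下面的配置字典与模型名仅为最小示例。
from transformers import AutoModel
from transformers.integrations import HfDeepSpeedConfig

ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
}
dschf = HfDeepSpeedConfig(ds_config)       # 必须在 from_pretrained 之前创建,并在整个使用期间保持存活
model = AutoModel.from_pretrained("gpt2")  # 加载时会检测到全局的 ZeRO-3 配置
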
class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
    """
    The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the
    same lifespan as the latter.
    """

    def __init__(self, config_file_or_dict):
        # 调用父类的初始化方法,传入配置文件或字典
        super().__init__(config_file_or_dict)
        # 初始化私有变量 _dtype 为 None
        self._dtype = None
        # 初始化 mismatches 列表为空
        self.mismatches = []

    def dtype(self):
        # 如果 _dtype 为 None,则抛出数值错误异常
        if self._dtype is None:
            raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
        # 返回 _dtype 的值
        return self._dtype

    def is_auto(self, ds_key_long):
        # 获取指定长键名 ds_key_long 对应的值
        val = self.get_value(ds_key_long)
        # 如果值为 None,则返回 False
        if val is None:
            return False
        else:
            # 否则返回值是否为 "auto" 的布尔结果
            return val == "auto"

    def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
        """
        A utility method that massages the config file and can optionally verify that the values match.

        1. Replace "auto" values with `TrainingArguments` value.

        2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer
        config values and if mismatched add the entry to `self.mismatches` - will assert during
        `trainer_config_finalize` for one or more mismatches.
        """
        # 查找指定长键名 ds_key_long 对应的配置节点和键名
        config, ds_key = self.find_config_node(ds_key_long)
        # 如果配置节点不存在,则直接返回
        if config is None:
            return

        # 如果配置值为 "auto",则用 hf_val 替换它
        if config.get(ds_key) == "auto":
            config[ds_key] = hf_val
            return

        # 如果不需要匹配,则直接返回
        if not must_match:
            return

        # 否则,获取当前配置值和传入的 hf_val 进行比较
        ds_val = config.get(ds_key)
        # 如果值存在且与 hf_val 不匹配,则将不匹配信息添加到 self.mismatches 列表中
        if ds_val is not None and ds_val != hf_val:
            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")

    # 定义 fill_only 方法为 fill_match 的偏函数,关闭 must_match 参数
    fill_only = partialmethod(fill_match, must_match=False)
    # 以下代码属于 HfTrainerDeepSpeedConfig.trainer_config_finalize(self, args, model, num_training_steps) 的方法体,
    # 其方法签名在本次摘录中被省略;该方法在拿到模型与训练总步数后被 Trainer 调用
    # 处理 `auto` 值的配置键,并依赖于模型的隐藏大小
    hidden_size_based_keys = [
        "zero_optimization.reduce_bucket_size",
        "zero_optimization.stage3_prefetch_bucket_size",
        "zero_optimization.stage3_param_persistence_threshold",
    ]
    # 筛选出需要使用 `auto` 值的配置键列表
    hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)]

    # 如果存在需要使用 `auto` 值的配置键
    if len(hidden_size_auto_keys) > 0:
        # 检查模型配置中是否有 `hidden_size` 属性
        if hasattr(model.config, "hidden_size"):
            hidden_size = model.config.hidden_size
        # 如果没有 `hidden_size` 属性,但有 `hidden_sizes` 属性,则选择最大的隐藏大小
        elif hasattr(model.config, "hidden_sizes"):
            hidden_size = max(model.config.hidden_sizes)
        else:
            # 如果模型配置文件既没有 `hidden_size` 也没有 `hidden_sizes` 条目,则引发错误
            raise ValueError(
                "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, "
                "therefore it's not possible to automatically fill out the following `auto` entries "
                f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
                "`auto` values for these keys with an integer value of your choice."
            )

        # 使用隐藏大小填充指定的配置键
        self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
        if self.is_zero3():
            # 如果是 Zero3 模式,根据模型配置自动分配优化配置值
            self.fill_only(
                "zero_optimization.stage3_prefetch_bucket_size",
                0.9 * hidden_size * hidden_size,
            )
            self.fill_only(
                "zero_optimization.stage3_param_persistence_threshold",
                10 * hidden_size,
            )

    # 填充调度器相关的参数值,匹配训练总步数和预热步数
    self.fill_match(
        "scheduler.params.total_num_steps",
        num_training_steps,
        "num_training_steps (calculated)",
    )
    self.fill_match(
        "scheduler.params.warmup_num_steps",
        args.get_warmup_steps(num_training_steps),
        "warmup_steps",
    )

    # 如果存在配置值不匹配的情况,引发 ValueError 异常
    if len(self.mismatches) > 0:
        mismatches = "\n".join(self.mismatches)
        raise ValueError(
            "Please correct the following DeepSpeed config values that mismatch TrainingArguments"
            f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
        )
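
# 示意(非库内原文):用一个极小的配置演示上述 "auto" 填充与 mismatch 记录的语义,
# 假设 accelerate 与 deepspeed 均已安装;键名与取值仅为演示。
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",   # "auto" -> 会被 TrainingArguments 的对应值覆盖
    "gradient_accumulation_steps": 8,           # 已是具体数值 -> 若与 TrainingArguments 不一致则记入 mismatches
}
hf_ds_cfg = HfTrainerDeepSpeedConfig(ds_config)
hf_ds_cfg.fill_match("train_micro_batch_size_per_gpu", 16, "per_device_train_batch_size")
hf_ds_cfg.fill_match("gradient_accumulation_steps", 4, "gradient_accumulation_steps")
print(hf_ds_cfg.config["train_micro_batch_size_per_gpu"])  # -> 16,"auto" 已被替换
print(hf_ds_cfg.mismatches)                                # -> 记录了 8 与 4 的不一致,finalize 阶段会统一抛错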
# 将 `_hf_deepspeed_config_weak_ref` 保存为模块级全局变量,以便在 `TrainingArguments` 的整个生命周期内都能访问到 DeepSpeed 配置。
_hf_deepspeed_config_weak_ref = None


def set_hf_deepspeed_config(hf_deepspeed_config_obj):
    # 这是一个特殊的弱引用全局对象,允许我们从没有简单方式获取 Deepspeed 配置的 API 中获取 Deepspeed 配置。
    # 当 `HfDeepSpeedConfig` 销毁时(即 `TrainingArguments` 销毁时),它会自动消失。
    global _hf_deepspeed_config_weak_ref
    _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)


def unset_hf_deepspeed_config():
    # 有助于单元测试确保全局状态不会泄露 - 从 `tearDown` 方法中调用。
    global _hf_deepspeed_config_weak_ref
    _hf_deepspeed_config_weak_ref = None


def is_deepspeed_zero3_enabled():
    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
        return _hf_deepspeed_config_weak_ref().is_zero3()
    else:
        return False


def deepspeed_config():
    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
        return _hf_deepspeed_config_weak_ref().config
    else:
        return None
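
# 示意用法(非库内原文):库内其他模块常用这两个查询函数来区分 ZeRO-3 的代码路径,大致形如:
if is_deepspeed_zero3_enabled():
    # ZeRO-3 下参数被切分到各个进程,不能直接按完整张量访问权重
    cfg = deepspeed_config()
    print("ZeRO stage:", cfg["zero_optimization"]["stage"])
else:
    print("DeepSpeed ZeRO-3 未启用")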


def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
    """
    A convenience wrapper that deals with optimizer and lr scheduler configuration.
    """
    from accelerate.utils import DummyOptim, DummyScheduler

    config = hf_deepspeed_config.config

    # 如果在配置中发现了 `optimizer` 字段
    if "optimizer" in config:
        # 如果传递了 `--adafactor` 参数,则抛出值错误异常,因为 DeepSpeed 配置中只能配置一个优化器。
        if args.adafactor:
            raise ValueError(
                "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
                "Only one optimizer can be configured."
            )
        # 创建一个虚拟的优化器对象 `DummyOptim`,用于占位模型参数
        optimizer = DummyOptim(params=model_parameters)
    else:
        # 如果 DeepSpeed 配置开启了 Offload
        if hf_deepspeed_config.is_offload():
            logger.info(
                "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the"
                " custom optimizer has both CPU and GPU implementation (except LAMB)"
            )

        # 默认情况下,trainer 使用 AdamW 优化器
        # 创建一个优化器对象,根据需要可以使用其它优化器,但会使 `zero_allow_untested_optimizer` 无效。
        optimizer = trainer.create_optimizer()
        config["zero_allow_untested_optimizer"] = True

    lr_scheduler = None
    # 检查配置中是否存在 "scheduler" 键
    if "scheduler" in config:
        # 如果存在,则创建一个 DummyScheduler 的实例,使用给定的 optimizer
        lr_scheduler = DummyScheduler(optimizer)
    else:
        # 如果不存在 "scheduler" 键,则进入 else 分支
        # 检查 optimizer 是否是 DummyOptim 的实例
        if isinstance(optimizer, DummyOptim):

            # 定义一个内部函数 _lr_scheduler_callable,用于创建一个新的 lr_scheduler
            def _lr_scheduler_callable(optimizer):
                # 首先创建 trainer 的浅拷贝,以防后续修改影响原始的 trainer
                trainer_copy = copy.copy(trainer)
                # 在调用 _lr_scheduler_callable 时,trainer.lr_scheduler 已经被设置
                # 将其更新为 None,以便可以重新创建新的 scheduler
                trainer_copy.lr_scheduler = None
                # 使用 trainer_copy 创建一个新的 scheduler,并返回
                lr_scheduler = trainer_copy.create_scheduler(
                    num_training_steps=num_training_steps, optimizer=optimizer
                )
                return lr_scheduler

            # 创建一个 DummyScheduler 的实例,同时指定 lr_scheduler_callable 为 _lr_scheduler_callable 函数
            lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
        else:
            # 如果 optimizer 不是 DummyOptim 的实例,则调用 trainer 的 create_scheduler 方法创建一个新的 scheduler
            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    # 返回 optimizer 和相应的 lr_scheduler
    return optimizer, lr_scheduler
# 初始化 DeepSpeed,根据 Trainer 的参数更新 DeepSpeed 配置。
# 如果指定了 resume_from_checkpoint,则尝试从先前保存的检查点恢复。
# Args:
#   trainer: Trainer 对象
#   num_training_steps: 每个单 GPU 的训练步数
#   inference: 是否启动推断模式(无优化器和学习率调度器)
# Returns:
#   optimizer, lr_scheduler:优化器和学习率调度器实例

def deepspeed_init(trainer, num_training_steps, inference=False):
    from deepspeed.utils import logger as ds_logger

    model = trainer.model  # 获取 Trainer 对象中的模型
    args = trainer.args  # 获取 Trainer 对象中的参数

    hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config  # 获取 DeepSpeed 插件的配置

    # 更新 DeepSpeed 配置的 trainer 部分,包括 args、model 和 num_training_steps
    hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)

    # 设置 DeepSpeed 日志级别与 Trainer 一致
    ds_logger.setLevel(args.get_process_log_level())

    if inference:
        # 推断模式下仅支持 ZeRO Stage 3
        if not hf_deepspeed_config.is_zero3():
            raise ValueError("ZeRO inference 只适用于 ZeRO Stage 3,请调整配置")

        # 清除 optimizer 和 lr_scheduler 配置,因为推断模式下不需要
        hf_deepspeed_config.del_config_sub_tree("optimizer")
        hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
        optimizer, lr_scheduler = None, None
        model_parameters = None
    else:
        trainer.optimizer = None  # 重要:在重新初始化时将 optimizer 设为 None
        # 获取所有需要梯度更新的模型参数
        model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
        # 使用 DeepSpeed 提供的优化器和学习率调度器函数初始化 optimizer 和 lr_scheduler
        optimizer, lr_scheduler = deepspeed_optim_sched(
            trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
        )

    # 保留以便快速调试:
    # from pprint import pprint; pprint(config)

    return optimizer, lr_scheduler
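
# 示意用法(非库内原文):用户侧只需在 TrainingArguments 中指定 deepspeed 配置,
# Trainer 在 train() 过程中会内部调用上面的 deepspeed_init 完成优化器 / 调度器的构建;路径与参数仅为占位。
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    deepspeed="ds_config_zero3.json",   # DeepSpeed 配置文件路径(占位)
)
# trainer = Trainer(model=model, args=args, train_dataset=train_ds)
# trainer.train()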


# 加载 DeepSpeed 引擎的检查点
# 检查指定的 checkpoint_path 是否包含 DeepSpeed 的检查点文件
# Args:
#   deepspeed_engine: DeepSpeed 引擎对象
#   checkpoint_path: 检查点路径
#   load_module_strict: 是否严格加载模块(默认为 True)
def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
    # 用户可能试图从 model_path 恢复,该路径不一定包含 DeepSpeed 检查点,
    # 例如,示例只检查目录是否存在,并假定是恢复检查点而不是本地预训练权重。
    # 因此,这里我们检查路径是否包含类似 DeepSpeed 检查点的内容。
    import glob  # 导入 glob 模块,用于文件路径的通配符匹配
    
    deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
    # 使用 glob 模块匹配符合模式 `{checkpoint_path}/global_step*` 的文件路径,并按字母顺序排序后存储在列表中
    
    if len(deepspeed_checkpoint_dirs) > 0:
        logger.info(f"Attempting to resume from {checkpoint_path}")
        # 如果找到符合条件的检查点目录,则记录信息尝试从指定路径 {checkpoint_path} 恢复训练
    
        load_path, _ = deepspeed_engine.load_checkpoint(
            checkpoint_path,
            load_module_strict=load_module_strict,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        # 调用 deepspeed_engine 的 load_checkpoint 方法加载检查点文件,更新优化器和学习率调度器的状态
    
        if load_path is None:
            raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
            # 如果加载路径为 None,则抛出值错误,指示未能从指定的检查点路径 {checkpoint_path} 恢复训练
    else:
        raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")
        # 如果未找到符合条件的检查点目录,则抛出值错误,指示在指定路径 {checkpoint_path} 下找不到有效的检查点

.\integrations\integration_utils.py

"""
Integrations with other Python libraries.
"""
# functools 模块提供了创建和使用偏函数(partial function)的工具
import functools
# importlib.metadata 提供了从安装的包中读取元数据的功能
import importlib.metadata
# importlib.util 提供了高级导入支持
import importlib.util
# json 是 Python 的 JSON 编解码器
import json
# numbers 包含 Python 中的数字抽象基类
import numbers
# os 提供了与操作系统交互的功能
import os
# pickle 实现了基于 Python 对象的序列化和反序列化
import pickle
# shutil 提供了高级文件操作
import shutil
# sys 提供了访问与解释器交互的变量和函数
import sys
# tempfile 提供了生成临时文件和目录的功能
import tempfile
# dataclasses 提供了用于定义数据类的工具
from dataclasses import asdict, fields
# pathlib 提供了面向对象的文件系统路径操作
from pathlib import Path
# typing 提供了类型提示支持
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union

# numpy 是科学计算的核心库
import numpy as np
# packaging.version 提供了版本管理功能
import packaging.version

# 导入本地模块中的 __version__,作为当前模块的版本号
from .. import __version__ as version
# 导入本地模块中的一些实用函数
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 如果 Torch 可用,则导入 Torch 模块
if is_torch_available():
    import torch

# 检查是否安装了 comet_ml,并且未禁用 COMET_MODE
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        # 检查是否设置了 COMET_API_KEY
        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            # 如果未设置 COMET_API_KEY,则发出警告
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

# 检查是否安装了 neptune 或 neptune-client
_has_neptune = (
    importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)

# 如果是类型检查模式且安装了 Neptune,记录 Neptune 的版本信息
if TYPE_CHECKING and _has_neptune:
    try:
        _neptune_version = importlib.metadata.version("neptune")
        logger.info(f"Neptune version {_neptune_version} available.")
    except importlib.metadata.PackageNotFoundError:
        try:
            _neptune_version = importlib.metadata.version("neptune-client")
            logger.info(f"Neptune-client version {_neptune_version} available.")
        except importlib.metadata.PackageNotFoundError:
            _has_neptune = False

# 导入本地模块中的一些回调函数和实用类
from ..trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
from ..training_args import ParallelMode  # noqa: E402
from ..utils import ENV_VARS_TRUE_VALUES, is_torch_xla_available  # noqa: E402

# Integration functions:

def is_wandb_available():
    # 检查是否安装了 WandB;只有当 WANDB_DISABLED 被设置为真值(如 "1"、"TRUE")时才视为禁用
    # 检查环境变量中是否定义了"WANDB_DISABLED"且其值在指定的真值列表中
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        # 如果条件成立,输出警告日志,提醒使用者停止使用"WANDB_DISABLED"环境变量,因为它将在v5版本中被移除,并建议使用"--report_to"标志来控制日志结果的集成方式(例如使用"--report_to none"来禁用集成)。
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging result (for instance --report_to none)."
        )
        # 返回False,表示WANDB(Weights and Biases)被禁用
        return False
    # 检查是否能够找到"wandb"模块的规范(spec)
    return importlib.util.find_spec("wandb") is not None
# 检查是否安装了 ClearML(formerly Trains),返回 True 或 False
def is_clearml_available():
    return importlib.util.find_spec("clearml") is not None


# 检查是否安装了 Comet,返回 _has_comet 的值
def is_comet_available():
    return _has_comet


# 检查是否安装了 TensorBoard 或 TensorBoardX,返回 True 或 False
def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


# 检查是否安装了 Optuna,返回 True 或 False
def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


# 检查是否安装了 Ray,返回 True 或 False
def is_ray_available():
    return importlib.util.find_spec("ray") is not None


# 检查是否安装了 Ray Tune,返回 True 或 False
def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


# 检查是否安装了 SigOpt,返回 True 或 False
def is_sigopt_available():
    return importlib.util.find_spec("sigopt") is not None


# 检查是否安装了 Azure ML,返回 True 或 False
def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


# 检查是否启用了 MLflow 并安装了相关依赖,返回 True 或 False
def is_mlflow_available():
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
        return False
    return importlib.util.find_spec("mlflow") is not None


# 检查是否同时安装了 Dagshub 和 MLflow,返回 True 或 False
def is_dagshub_available():
    return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]


# 检查是否安装了 Neptune,返回 _has_neptune 的值
def is_neptune_available():
    return _has_neptune


# 检查是否安装了 CodeCarbon,返回 True 或 False
def is_codecarbon_available():
    return importlib.util.find_spec("codecarbon") is not None


# 检查是否安装了 Flytekit,返回 True 或 False
def is_flytekit_available():
    return importlib.util.find_spec("flytekit") is not None


# 检查是否安装了 Flyte Deck Standard,返回 True 或 False
def is_flyte_deck_standard_available():
    if not is_flytekit_available():
        return False
    return importlib.util.find_spec("flytekitplugins.deck") is not None


# 检查是否安装了 DVC Live,返回 True 或 False
def is_dvclive_available():
    return importlib.util.find_spec("dvclive") is not None


# 根据提供的试验对象(trial)返回超参数字典,可能从 Optuna、Ray Tune、SigOpt 或 W&B 中获取
def hp_params(trial):
    if is_optuna_available():
        import optuna
        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial
    if is_sigopt_available():
        if isinstance(trial, dict):
            return trial
    if is_wandb_available():
        if isinstance(trial, dict):
            return trial
    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


# 使用 Optuna 进行超参数搜索,返回 BestRun 对象
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna
    # 检查当前进程索引是否为0,只有主进程执行以下逻辑
    if trainer.args.process_index == 0:

        # 定义内部函数 _objective,用于定义优化目标函数
        def _objective(trial, checkpoint_dir=None):
            # 初始化 checkpoint 为 None
            checkpoint = None
            # 如果提供了 checkpoint_dir,则查找其中以 PREFIX_CHECKPOINT_DIR 开头的子目录
            if checkpoint_dir:
                for subdir in os.listdir(checkpoint_dir):
                    if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                        checkpoint = os.path.join(checkpoint_dir, subdir)
            
            # 将 trainer 的 objective 属性设为 None
            trainer.objective = None
            
            # 如果运行环境的 world_size 大于1,则进入分布式训练模式
            if trainer.args.world_size > 1:
                # 检查并确保当前并行模式为 ParallelMode.DISTRIBUTED
                if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                    raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
                
                # 初始化分布式超参搜索
                trainer._hp_search_setup(trial)
                
                # 使用 torch.distributed 广播序列化后的 trainer.args 到所有进程
                torch.distributed.broadcast_object_list(pickle.dumps(trainer.args), src=0)
                
                # 开始训练,从 checkpoint 恢复
                trainer.train(resume_from_checkpoint=checkpoint)
            else:
                # 单机模式下,开始训练,从 checkpoint 恢复,传入 trial 对象
                trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
            
            # 如果在训练过程中没有进行评估,则执行评估过程
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
            
            # 返回训练的目标值
            return trainer.objective

        # 从 kwargs 中弹出 timeout 和 n_jobs 参数,分别设置默认值为 None 和 1
        timeout = kwargs.pop("timeout", None)
        n_jobs = kwargs.pop("n_jobs", 1)
        
        # 如果 direction 是 list 类型,则设置 directions 为 direction,否则为 None
        directions = direction if isinstance(direction, list) else None
        direction = None if directions is not None else direction
        
        # 创建一个新的 Optuna Study 对象 study,根据参数设置方向和其他参数
        study = optuna.create_study(direction=direction, directions=directions, **kwargs)
        
        # 使用 study 对象优化 _objective 函数,执行 n_trials 次优化,支持并行数为 n_jobs
        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
        
        # 如果 study 不是多目标优化,则返回最佳运行结果
        if not study._is_multi_objective():
            best_trial = study.best_trial
            return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
        else:
            # 如果是多目标优化,则返回多个最佳运行结果列表
            best_trials = study.best_trials
            return [BestRun(str(best.number), best.values, best.params) for best in best_trials]
    
    else:
        # 对于非主进程,执行以下逻辑,循环 n_trials 次
        for i in range(n_trials):
            # 将 trainer 的 objective 属性设为 None
            trainer.objective = None
            
            # 序列化 trainer.args 到 args_main_rank 列表
            args_main_rank = list(pickle.dumps(trainer.args))
            
            # 检查并确保当前并行模式为 ParallelMode.DISTRIBUTED
            if trainer.args.parallel_mode != ParallelMode.DISTRIBUTED:
                raise RuntimeError("only support DDP optuna HPO for ParallelMode.DISTRIBUTED currently.")
            
            # 使用 torch.distributed 广播 args_main_rank 到所有进程
            torch.distributed.broadcast_object_list(args_main_rank, src=0)
            
            # 将 args_main_rank 反序列化为 args 对象
            args = pickle.loads(bytes(args_main_rank))
            
            # 遍历 args 的属性,将除了 "local_rank" 外的键值对设置为 trainer.args 的属性
            for key, value in asdict(args).items():
                if key != "local_rank":
                    setattr(trainer.args, key, value)
            
            # 开始训练,从 checkpoint=None 恢复
            trainer.train(resume_from_checkpoint=None)
            
            # 如果在训练过程中没有进行评估,则执行评估过程
            if getattr(trainer, "objective", None) is None:
                metrics = trainer.evaluate()
                trainer.objective = trainer.compute_objective(metrics)
        
        # 非主进程返回 None
        return None
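
# 示意用法(非库内原文):run_hp_search_optuna 一般不直接调用,而是通过 Trainer.hyperparameter_search
# 以 optuna 作为后端触发;下面的搜索空间仅为假设示例,trainer 需事先构造好。
def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
    }

best_run = trainer.hyperparameter_search(
    hp_space=optuna_hp_space,
    backend="optuna",
    direction="minimize",
    n_trials=10,
)
print(best_run.hyperparameters)
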
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    # 导入必要的库
    import ray
    import ray.train

    def _objective(trial: dict, local_trainer):
        try:
            # 尝试导入笔记本进度回调类
            from transformers.utils.notebook import NotebookProgressCallback

            # 如果存在 NotebookProgressCallback,则从 local_trainer 中移除并添加 ProgressCallback
            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            # 如果模块未找到,则忽略
            pass

        # 将 local_trainer 的 objective 属性设置为 None
        local_trainer.objective = None

        # 获取 ray.train 的检查点对象
        checkpoint = ray.train.get_checkpoint()
        if checkpoint:
            # 如果有检查点,说明是恢复训练状态
            # 重置 local_trainer 的 objective 属性为 "objective",以解决训练完成后额外触发检查点的问题
            local_trainer.objective = "objective"

            # 获取检查点目录下的第一个检查点路径
            with checkpoint.as_directory() as checkpoint_dir:
                checkpoint_path = next(Path(checkpoint_dir).glob(f"{PREFIX_CHECKPOINT_DIR}*")).as_posix()
                # 从检查点路径恢复训练
                local_trainer.train(resume_from_checkpoint=checkpoint_path, trial=trial)
        else:
            # 如果没有检查点,则直接开始训练
            local_trainer.train(trial=trial)

        # 如果训练过程中未进行评估
        if getattr(local_trainer, "objective", None) is None:
            # 进行评估,并计算目标指标
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)

            # 更新 metrics,并标记为完成
            metrics.update({"objective": local_trainer.objective, "done": True})

            # 使用临时目录保存检查点,并创建 ray.train.Checkpoint 对象
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                local_trainer._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
                checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
                # 报告评估结果和检查点
                ray.train.report(metrics, checkpoint=checkpoint)

    # 如果 trainer 的内存追踪未跳过内存指标的记录,则警告并设置为跳过
    if not trainer._memory_tracker.skip_memory_metrics:
        from ..trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # 在进行 ray 超参数搜索期间,模型和 TensorBoard writer 无法序列化,因此需要移除它们
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # 设置默认的 `resources_per_trial`。
    # 检查是否在 `kwargs` 参数中存在 `resources_per_trial` 键
    if "resources_per_trial" not in kwargs:
        # 如果不存在,则设置默认值为每个试验分配 1 个 CPU 和(如果可用)1 个 GPU
        kwargs["resources_per_trial"] = {"cpu": 1}
        # 如果训练器有 GPU,则将 GPU 数量设置为 1
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        # 生成资源信息字符串,用于日志记录
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        # 记录日志,说明未传递 `resources_per_trial` 参数,使用默认值
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )

    # 确保每个训练器实例只使用根据试验分配的 GPU
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # 设置默认的进度报告器 `progress_reporter`
    if "progress_reporter" not in kwargs:
        # 导入所需的 CLIReporter 类
        from ray.tune import CLIReporter
        # 如果未指定 `progress_reporter`,则设置为 CLIReporter,并指定度量列为 "objective"
        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])

    # 如果 `kwargs` 中包含 `scheduler` 参数
    if "scheduler" in kwargs:
        # 导入可能需要中间报告的调度器类
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # 检查调度器是否需要中间报告,并且检查是否开启了评估
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            # 抛出运行时错误,要求开启评估以便调度器能够使用中间结果
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    # 使用 `ray.tune.with_parameters` 将 `_objective` 和本地训练器 `trainer` 组合成可调用的函数
    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    # 使用 `functools.wraps` 装饰 `trainable` 函数,以保留其元数据
    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        # 检查是否存在 datasets 模块
        if is_datasets_available():
            # 导入 datasets.load 模块
            import datasets.load

            # 初始化动态模块的路径
            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # 从路径加载动态模块
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            # 将加载的模块添加到系统模块列表中
            sys.modules[spec.name] = datasets_modules
            # 执行加载的模块
            spec.loader.exec_module(datasets_modules)
        
        # 返回通过 tune.with_parameters 调用的 trainable 函数的结果
        return trainable(*args, **kwargs)

    # 检查 trainable 函数是否具有特殊属性 __mixins__
    if hasattr(trainable, "__mixins__"):
        # 如果有,将 dynamic_modules_import_trainable 函数的 __mixins__ 属性设置为 trainable 函数的 __mixins__
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    # 运行 ray.tune 的分布式调参任务
    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),  # 使用 trainer.hp_space(None) 获取超参数空间配置
        num_samples=n_trials,  # 设置试验的样本数为 n_trials
        **kwargs,  # 其他传递的关键字参数
    )
    
    # 获取最佳试验的信息,基于指定的度量标准和方向
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
    
    # 构造最佳运行的对象,并传递相关信息
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config, analysis)
    
    # 如果存在 _tb_writer 对象,则将其作为回调添加到 trainer 中
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    
    # 返回表示最佳运行的对象
    return best_run
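
# 示意用法(非库内原文):以 Ray Tune 作为后端时,可以通过 kwargs 传入本函数会读取的
# resources_per_trial、scheduler 等参数;搜索空间与资源配置仅为假设示例,假设 ray[tune] 已安装。
from ray import tune

def ray_hp_space(trial):
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    }

best_run = trainer.hyperparameter_search(
    hp_space=ray_hp_space,
    backend="ray",
    direction="minimize",
    n_trials=8,
    resources_per_trial={"cpu": 2, "gpu": 1},   # 覆盖上面默认的 1 CPU(+1 GPU)设置
)
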
def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    # 导入检查Wandb是否可用的函数
    from ..integrations import is_wandb_available

    # 如果Wandb不可用,则抛出 ImportError 异常
    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    # 导入Wandb
    import wandb

    # 检查是否已经添加了 WandbCallback 到 trainer 的回调列表中
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    # 如果没有添加,则将 WandbCallback 添加到 trainer 的回调列表中
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())

    # 设置 trainer 的 report_to 属性为 ["wandb"],表明报告结果到 Wandb
    trainer.args.report_to = ["wandb"]

    # 初始化最佳试验的信息
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}

    # 从 kwargs 中获取 sweep_id、project、name、entity 等参数
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    # 从 kwargs 中获取 metric 参数,默认为 "eval/loss"
    metric = kwargs.pop("metric", "eval/loss")

    # 从 trainer 获取超参数空间配置
    sweep_config = trainer.hp_space(None)

    # 设置超参数空间配置的优化目标和指标名称
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric

    # 如果提供了 name 参数,则设置超参数空间配置的名称
    if name:
        sweep_config["name"] = name
    # 定义一个名为 _objective 的函数
    def _objective():
        # 如果 wandb.run 存在,则使用当前运行的 wandb.run,否则初始化一个新的 wandb.run
        run = wandb.run if wandb.run else wandb.init()
        # 将训练器的试验名称设置为当前运行的名称
        trainer.state.trial_name = run.name
        # 更新配置,包括 "assignments": {} 和指定的度量标准 metric
        run.config.update({"assignments": {}, "metric": metric})
        # 获取当前的配置
        config = wandb.config

        # 将训练器的 objective 属性设置为 None
        trainer.objective = None

        # 开始训练过程,resume_from_checkpoint=None,并传递配置项作为试验的参数
        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        
        # 如果在训练循环中没有进行任何评估
        if getattr(trainer, "objective", None) is None:
            # 执行评估并获取指标 metrics
            metrics = trainer.evaluate()
            # 计算训练器的 objective 属性
            trainer.objective = trainer.compute_objective(metrics)
            # 重新编写日志格式的指标
            format_metrics = rewrite_logs(metrics)
            # 如果指定的度量标准不在重新编写后的指标中,则发出警告
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
                    f" metrics are {format_metrics.keys()}"
                )
        
        # 初始化 best_score 为 False
        best_score = False
        # 如果 best_trial["run_id"] 不为 None
        if best_trial["run_id"] is not None:
            # 根据 direction 的设置,比较当前训练器的 objective 属性和最佳试验的 objective 属性
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        # 如果 best_score 为 True 或者 best_trial["run_id"] 为 None
        if best_score or best_trial["run_id"] is None:
            # 更新最佳试验的 run_id、objective 和 hyperparameters
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        # 返回训练器的 objective 属性作为函数 _objective 的结果
        return trainer.objective

    # 如果 sweep_id 不存在,则使用给定的 sweep_config 创建一个新的 wandb sweep,并指定项目和实体
    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    # 输出当前的 wandb sweep id
    logger.info(f"wandb sweep id - {sweep_id}")
    # 使用 wandb agent 在指定的 sweep_id 上运行函数 _objective,并设置运行的次数为 n_trials
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    # 返回包含最佳运行的 run_id、objective 和 hyperparameters 的 BestRun 对象
    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
# 定义函数,返回所有可用的报告集成列表
def get_available_reporting_integrations():
    # 初始化空列表,用于存储可用的集成
    integrations = []
    # 检查 Azure ML 是否可用,并且 MLflow 不可用时,添加 "azure_ml" 到集成列表
    if is_azureml_available() and not is_mlflow_available():
        integrations.append("azure_ml")
    # 如果 Comet ML 可用,添加 "comet_ml" 到集成列表
    if is_comet_available():
        integrations.append("comet_ml")
    # 如果 DagsHub 可用,添加 "dagshub" 到集成列表
    if is_dagshub_available():
        integrations.append("dagshub")
    # 如果 DVC Live 可用,添加 "dvclive" 到集成列表
    if is_dvclive_available():
        integrations.append("dvclive")
    # 如果 MLflow 可用,添加 "mlflow" 到集成列表
    if is_mlflow_available():
        integrations.append("mlflow")
    # 如果 Neptune 可用,添加 "neptune" 到集成列表
    if is_neptune_available():
        integrations.append("neptune")
    # 如果 TensorBoard 可用,添加 "tensorboard" 到集成列表
    if is_tensorboard_available():
        integrations.append("tensorboard")
    # 如果 Weights & Biases 可用,添加 "wandb" 到集成列表
    if is_wandb_available():
        integrations.append("wandb")
    # 如果 CodeCarbon 可用,添加 "codecarbon" 到集成列表
    if is_codecarbon_available():
        integrations.append("codecarbon")
    # 如果 ClearML 可用,添加 "clearml" 到集成列表
    if is_clearml_available():
        integrations.append("clearml")
    # 返回所有已添加的集成列表
    return integrations


# 定义函数,重写输入字典的键名规则并返回新字典
def rewrite_logs(d):
    # 初始化空字典,用于存储重写后的键值对
    new_d = {}
    # 设置评估前缀和测试前缀
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    # 遍历输入字典的键值对
    for k, v in d.items():
        # 如果键以评估前缀开头,将键重写为 "eval/去除前缀后的键",并将原值赋给新字典对应键
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        # 如果键以测试前缀开头,将键重写为 "test/去除前缀后的键",并将原值赋给新字典对应键
        elif k.startswith(test_prefix):
            new_d["test/" + k[test_prefix_len:]] = v
        # 否则,将键重写为 "train/原键",并将原值赋给新字典对应键
        else:
            new_d["train/" + k] = v
    # 返回重写后的新字典
    return new_d
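
# 示意:rewrite_logs 的输入输出对应关系
logs = {"eval_loss": 0.41, "test_accuracy": 0.87, "loss": 1.23}
print(rewrite_logs(logs))
# -> {"eval/loss": 0.41, "test/accuracy": 0.87, "train/loss": 1.23}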


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        # 检查 TensorBoard 是否可用
        has_tensorboard = is_tensorboard_available()
        # 如果 TensorBoard 不可用,抛出运行时错误
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                " install tensorboardX."
            )
        # 如果 TensorBoard 可用
        if has_tensorboard:
            try:
                # 尝试导入 PyTorch 的 SummaryWriter
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401
                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    # 如果导入失败,尝试导入 tensorboardX 的 SummaryWriter
                    from tensorboardX import SummaryWriter
                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    # 如果都导入失败,设为 None
                    self._SummaryWriter = None
        else:
            # 如果 TensorBoard 不可用,设为 None
            self._SummaryWriter = None
        # 设置回调对象的写入器
        self.tb_writer = tb_writer

    # 初始化 TensorBoard 的 SummaryWriter
    def _init_summary_writer(self, args, log_dir=None):
        # 如果未提供日志目录,使用参数 args 的 logging_dir
        log_dir = log_dir or args.logging_dir
        # 如果 SummaryWriter 存在
        if self._SummaryWriter is not None:
            # 初始化回调对象的写入器
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)
    # 如果不是全局进程的第一个进程,则直接返回,不执行后续操作
    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        # 初始化日志目录为 None
        log_dir = None

        # 如果是超参数搜索状态,根据试验名称组合日志目录路径
        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        # 如果 tb_writer 为空,则初始化摘要写入器
        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        # 如果 tb_writer 不为空,则添加参数 args 的 JSON 字符串到 TensorBoard
        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            # 如果 kwargs 中包含 "model",则获取模型配置信息并添加到 TensorBoard
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)

    # 如果不是全局进程的第一个进程,则直接返回,不执行后续操作
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        # 如果 tb_writer 为空,则初始化摘要写入器
        if self.tb_writer is None:
            self._init_summary_writer(args)

        # 如果 tb_writer 不为空,则重写日志并逐个处理
        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                # 如果值为整数或浮点数,则将其作为标量添加到 TensorBoard
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    # 否则记录警告信息,指出不正确的调用方式并丢弃该属性
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            # 刷新 TensorBoard 写入器
            self.tb_writer.flush()

    # 如果 tb_writer 存在,则关闭它,并将其置为 None
    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None
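
# 示意用法(非库内原文):通常不手动构造 TensorBoardCallback,而是通过 TrainingArguments 的
# report_to 让 Trainer 自动挂载;输出目录等参数仅为占位。
from transformers import TrainingArguments

tb_args = TrainingArguments(
    output_dir="out",
    logging_dir="runs/exp1",       # TensorBoard 日志目录
    report_to=["tensorboard"],
    logging_steps=50,
)
# trainer = Trainer(model=model, args=tb_args, train_dataset=train_ds)
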
    """
    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
    """

    # 初始化函数,检查是否安装了 wandb,如果未安装则抛出异常
    def __init__(self):
        # 检查是否安装了 wandb
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        # 如果 wandb 可用,则引入 wandb 模块
        if has_wandb:
            import wandb
            self._wandb = wandb
        # 标记初始化状态为 False
        self._initialized = False
        
        # 根据环境变量设置是否记录模型,同时给出警告信息
        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
            DeprecationWarning(
                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
            )
            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
            # 将记录模型的时机设置为 "end"
            self._log_model = "end"
        else:
            # 如果环境变量中未设置,则默认为 false
            self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()

    # 当训练开始时调用的函数,根据状态进行相应操作
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # 如果未安装 wandb,则直接返回
        if self._wandb is None:
            return
        
        # 检查是否为超参数搜索,如果是,则结束 wandb 进程并重置初始化状态
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        
        # 如果未初始化,则进行设置
        if not self._initialized:
            self.setup(args, state, model, **kwargs)
    # 在训练结束时触发的回调函数,用于上传模型和日志
    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        # 如果未初始化或者未配置 WandB,直接返回
        if self._wandb is None:
            return
        # 如果需要在结束或者检查点时记录模型,并且已经初始化,并且当前进程是主进程
        if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
            # 导入 Trainer 类用于模型保存
            from ..trainer import Trainer

            # 创建一个假的 Trainer 对象用于保存模型
            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            # 使用临时目录保存模型
            with tempfile.TemporaryDirectory() as temp_dir:
                # 将模型保存到临时目录
                fake_trainer.save_model(temp_dir)
                # 准备上传的元数据
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                # 记录上传日志
                logger.info("Logging model artifacts. ...")
                # 确定模型的名称
                model_name = (
                    f"model-{self._wandb.run.id}"
                    if (args.run_name is None or args.run_name == args.output_dir)
                    else f"model-{self._wandb.run.name}"
                )
                # 创建一个 WandB Artifact 对象,用于上传模型
                artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
                # 遍历临时目录下的所有文件
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        # 将每个文件添加到 Artifact 中
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                # 上传 Artifact 到 WandB
                self._wandb.run.log_artifact(artifact)

    # 在日志记录时触发的回调函数,用于记录单值标量和非标量日志
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        # 定义单值标量日志的关键字列表
        single_value_scalars = [
            "train_runtime",
            "train_samples_per_second",
            "train_steps_per_second",
            "train_loss",
            "total_flos",
        ]

        # 如果未配置 WandB,直接返回
        if self._wandb is None:
            return
        # 如果未初始化,则进行初始化
        if not self._initialized:
            self.setup(args, state, model)
        # 如果当前进程是主进程
        if state.is_world_process_zero:
            # 遍历 logs 中的每个键值对
            for k, v in logs.items():
                # 如果键在单值标量列表中,则更新 WandB 的 summary
                if k in single_value_scalars:
                    self._wandb.run.summary[k] = v
            # 从 logs 中提取非标量日志,并进行重写处理
            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
            non_scalar_logs = rewrite_logs(non_scalar_logs)
            # 记录非标量日志和全局步数到 WandB
            self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
    # 当保存操作触发时调用此方法,接收参数 args, state, control 和其他关键字参数 kwargs
    def on_save(self, args, state, control, **kwargs):
        # 检查日志模式是否为 "checkpoint",且对象已初始化,并且当前进程是主进程
        if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
            # 创建一个包含非私有数值的摘要元数据字典
            checkpoint_metadata = {
                k: v
                for k, v in dict(self._wandb.summary).items()
                if isinstance(v, numbers.Number) and not k.startswith("_")
            }

            # 根据全局步数创建检查点目录名
            ckpt_dir = f"checkpoint-{state.global_step}"
            # 构造完整的存储路径,放置检查点
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            # 记录日志,指示正在保存检查点工件
            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
            # 根据运行名和ID创建检查点名称
            checkpoint_name = (
                f"checkpoint-{self._wandb.run.id}"
                if (args.run_name is None or args.run_name == args.output_dir)
                else f"checkpoint-{self._wandb.run.name}"
            )
            # 创建一个 W&B Artifact 对象,类型为 "model",并附带元数据
            artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
            # 将检查点目录及其内容添加到工件中
            artifact.add_dir(artifact_path)
            # 使用全局步数作为别名,将工件记录到 W&B
            self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
# 定义一个名为 `CometCallback` 的类,继承自 `TrainerCallback`
class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
    """

    # 初始化方法,检查是否安装了 comet-ml 库,若未安装则抛出运行时错误
    def __init__(self):
        if not _has_comet:
            raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
        self._initialized = False  # 标记是否已初始化
        self._log_assets = False  # 标记是否记录训练资源

    # 设置方法,用于配置 Comet.ml 集成
    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
        - **COMET_MODE** (`str`, *optional*, defaults to `ONLINE`):
            Whether to create an online, offline experiment or disable Comet logging. Can be `OFFLINE`, `ONLINE`, or
            `DISABLED`.
        - **COMET_PROJECT_NAME** (`str`, *optional*):
            Comet project name for experiments.
        - **COMET_OFFLINE_DIRECTORY** (`str`, *optional*):
            Folder to use for saving offline experiments when `COMET_MODE` is `OFFLINE`.
        - **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`):
            Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be `TRUE`, or
            `FALSE`.

        For a number of configurable items in the environment, see
        [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
        """
        self._initialized = True  # 标记已初始化
        # 检查是否需要记录训练资源,根据环境变量 COMET_LOG_ASSETS 的设置决定
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        # 如果是主进程(world_process_zero),根据环境变量 COMET_MODE 的设置创建相应的 Comet 实验
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            experiment = None
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            if comet_mode == "ONLINE":
                # 创建在线 Comet 实验,并记录相关信息
                experiment = comet_ml.Experiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                # 创建离线 Comet 实验,并记录相关信息
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            # 如果成功创建实验对象,则记录模型图和参数信息到 Comet
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    # 训练开始时的回调方法,如果未初始化则执行设置方法
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
    # 当日志事件触发时调用的方法,用于处理日志记录
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        # 如果对象尚未初始化,则执行初始化设置
        if not self._initialized:
            self.setup(args, state, model)
        # 如果当前进程是全局的第零个进程
        if state.is_world_process_zero:
            # 获取全局 Comet 实验对象
            experiment = comet_ml.config.get_global_experiment()
            # 如果实验对象不为空,则记录指标(metrics)
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")

    # 当训练结束时调用的方法
    def on_train_end(self, args, state, control, **kwargs):
        # 如果对象已经初始化,并且当前进程是全局的第零个进程
        if self._initialized and state.is_world_process_zero:
            # 获取全局 Comet 实验对象
            experiment = comet_ml.config.get_global_experiment()
            # 如果实验对象不为空
            if experiment is not None:
                # 如果设置了记录资产(_log_assets),则记录输出目录中的文件
                if self._log_assets is True:
                    logger.info("Logging checkpoints. This may take time.")
                    experiment.log_asset_folder(
                        args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                    )
                # 结束 Comet 实验
                experiment.end()
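Following the environment variables documented in `setup`, a typical way to drive this callback is to configure Comet entirely through the environment before building the `Trainer`. A short sketch, with illustrative values:

```python
import os

from transformers import TrainingArguments

os.environ["COMET_MODE"] = "OFFLINE"                    # create an OfflineExperiment instead of an online one
os.environ["COMET_PROJECT_NAME"] = "my-experiments"     # illustrative project name
os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet_logs"  # where offline experiments are stored
os.environ["COMET_LOG_ASSETS"] = "TRUE"                 # upload checkpoints at the end of training

args = TrainingArguments(output_dir="out", report_to=["comet_ml"])
# Trainer(..., args=args) will then instantiate CometCallback and call setup() when training begins.
```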
class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        # 检查是否安装了 AzureML SDK,如果没有则抛出运行时错误
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        # 导入 AzureML 的 Run 类
        from azureml.core.run import Run

        # 如果未提供 azureml_run 并且是主进程,则获取当前运行的上下文
        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        # 如果有提供 azureml_run 并且是主进程
        if self.azureml_run and state.is_world_process_zero:
            # 遍历 logs 字典,将其键值对作为日志项传递给 AzureML 的 run 对象
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)
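Inside an AzureML job the callback resolves the active run by itself in `on_init_end`; outside of one, an existing run can be handed to it explicitly. A small sketch (the `trainer` object in the commented-out line is assumed to exist already):

```python
from azureml.core import Run

from transformers.integrations import AzureMLCallback

azureml_run = Run.get_context()  # same call used in on_init_end above
# trainer.add_callback(AzureMLCallback(azureml_run=azureml_run))  # attach to an existing Trainer
```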


class MLflowCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
    environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
    """

    def __init__(self):
        # 检查是否安装了 MLflow,如果没有则抛出运行时错误
        if not is_mlflow_available():
            raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
        import mlflow

        # 设置 MLflow 相关的最大参数值长度和每批次参数标签的最大数
        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        self._initialized = False
        self._auto_end_run = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # 如果尚未初始化,则进行设置
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        # 如果尚未初始化,则进行设置
        if not self._initialized:
            self.setup(args, state, model)
        # 如果是主进程
        if state.is_world_process_zero:
            metrics = {}
            # 遍历 logs 字典,将数值类型的值作为指标(metrics)记录
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
                else:
                    # 如果值不是数值类型,则记录警告日志并忽略该值
                    logger.warning(
                        f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
                        "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                    )

            # 如果是异步日志,则异步记录指标(metrics)
            if self._async_log:
                self._ml_flow.log_metrics(metrics=metrics, step=state.global_step, synchronous=False)
            else:
                self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        # 如果已初始化并且是主进程,则根据设置自动结束 MLflow 的运行
        if self._initialized and state.is_world_process_zero:
            if self._auto_end_run and self._ml_flow.active_run():
                self._ml_flow.end_run()
    # 当保存操作触发时执行的方法,接收参数 args, state, control 和 kwargs
    def on_save(self, args, state, control, **kwargs):
        # 检查对象是否已初始化,并且当前进程是世界中的主进程,同时日志记录已启用
        if self._initialized and state.is_world_process_zero and self._log_artifacts:
            # 构建检查点目录名称,使用全局步数来唯一标识
            ckpt_dir = f"checkpoint-{state.global_step}"
            # 构建检查点的完整路径,基于指定的输出目录
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            # 记录信息日志,指示正在将检查点数据记录为 artifacts
            logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. This may take time.")
            # 使用 MLflow 的 pyfunc 接口记录模型
            self._ml_flow.pyfunc.log_model(
                ckpt_dir,  # 记录模型时使用的名称
                artifacts={"model_path": artifact_path},  # 附加的 artifacts,指定模型路径
                python_model=self._ml_flow.pyfunc.PythonModel(),  # 使用的 Python 模型
            )

    # 析构函数,在对象被销毁时调用
    def __del__(self):
        # 如果设置了自动结束运行,并且 MLflow 的活跃运行状态是可调用的且不为空
        if (
            self._auto_end_run
            and callable(getattr(self._ml_flow, "active_run", None))
            and self._ml_flow.active_run() is not None
        ):
            # 结束当前的 MLflow 运行
            self._ml_flow.end_run()
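The numeric-only filtering performed in `on_log` above can be isolated into a small helper for clarity. This standalone sketch mirrors the same logic; the helper name is made up for illustration and is not part of the library:

```python
def split_numeric_logs(logs):
    """Split a Trainer `logs` dict the way MLflowCallback.on_log does: keep int/float
    values as metrics, collect everything else so it can be warned about and dropped."""
    metrics, dropped = {}, {}
    for key, value in logs.items():
        if isinstance(value, (int, float)):
            metrics[key] = value
        else:
            dropped[key] = value
    return metrics, dropped


metrics, dropped = split_numeric_logs({"loss": 0.42, "epoch": 1.0, "run_name": "demo"})
print(metrics)   # {'loss': 0.42, 'epoch': 1.0}
print(dropped)   # {'run_name': 'demo'}
```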
class DagsHubCallback(MLflowCallback):
    """
    A [`TrainerCallback`] that logs to [DagsHub](https://dagshub.com/). Extends [`MLflowCallback`]
    """

    def __init__(self):
        super().__init__()
        # 检查是否安装了 DagsHub 相关库
        if not is_dagshub_available():
            raise ImportError("DagsHubCallback requires dagshub to be installed. Run `pip install dagshub`.")

        # 导入 DagsHub 的 Repo 类
        from dagshub.upload import Repo
        self.Repo = Repo

    def setup(self, *args, **kwargs):
        """
        Setup the DagsHub's Logging integration.

        Environment:
        - **HF_DAGSHUB_LOG_ARTIFACTS** (`str`, *optional*):
                Whether to save the data and model artifacts for the experiment. Default to `False`.
        """

        # 检查是否要记录数据和模型的 artifacts
        self.log_artifacts = os.getenv("HF_DAGSHUB_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        # 获取模型名称,如果未指定则默认为 "main"
        self.name = os.getenv("HF_DAGSHUB_MODEL_NAME") or "main"
        # 获取 MLflow 的远程跟踪 URI
        self.remote = os.getenv("MLFLOW_TRACKING_URI")
        # 根据远程跟踪 URI 创建 DagsHub Repo 对象
        self.repo = self.Repo(
            owner=self.remote.split(os.sep)[-2],
            name=self.remote.split(os.sep)[-1].split(".")[0],
            branch=os.getenv("BRANCH") or "main",
        )
        # 设置路径为 "artifacts"
        self.path = Path("artifacts")

        # 如果未设置远程跟踪 URI,则抛出运行时错误
        if self.remote is None:
            raise RuntimeError(
                "DagsHubCallback requires the `MLFLOW_TRACKING_URI` environment variable to be set. Did you run"
                " `dagshub.init()`?"
            )

        # 调用父类的 setup 方法
        super().setup(*args, **kwargs)

    def on_train_end(self, args, state, control, **kwargs):
        # 如果要记录 artifacts
        if self.log_artifacts:
            # 如果存在 train_dataloader 属性,则保存数据集到 "dataset.pt" 文件
            if getattr(self, "train_dataloader", None):
                torch.save(self.train_dataloader.dataset, os.path.join(args.output_dir, "dataset.pt"))

            # 将输出目录下的内容添加到 DagsHub Repo 的指定目录下
            self.repo.directory(str(self.path)).add_dir(args.output_dir)
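`setup()` above derives the repository owner and name from `MLFLOW_TRACKING_URI`, which is normally populated by `dagshub.init()` (as the error message suggests). A rough sketch, assuming `dagshub.init` accepts `repo_owner`/`repo_name` keyword arguments; the owner, repository, and model names are placeholders:

```python
import os

import dagshub  # assumed installed: pip install dagshub

dagshub.init(repo_owner="<owner>", repo_name="<repo>")   # sets MLFLOW_TRACKING_URI for the callback
os.environ["HF_DAGSHUB_LOG_ARTIFACTS"] = "TRUE"          # also upload data/model artifacts at train end
os.environ["HF_DAGSHUB_MODEL_NAME"] = "my-model"         # illustrative model name
# TrainingArguments(report_to=["dagshub"], ...) then selects DagsHubCallback.
```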


class NeptuneMissingConfiguration(Exception):
    def __init__(self):
        super().__init__(
            """
        ------ Unsupported ---- We were not able to create new runs. You provided a custom Neptune run to
        `NeptuneCallback` with the `run` argument. For the integration to work fully, provide your `api_token` and
        `project` by saving them as environment variables or passing them to the callback.
        """
        )


class NeptuneCallback(TrainerCallback):
    """TrainerCallback that sends the logs to [Neptune](https://app.neptune.ai).
    # NeptuneLogger 类定义,用于集成 Transformers 框架与 Neptune 平台
    class NeptuneLogger:
    
        # 集成版本键,用于 Neptune 运行日志中的源代码路径
        integration_version_key = "source_code/integrations/transformers"
        # 模型参数键,用于 Neptune 运行日志中的模型参数
        model_parameters_key = "model_parameters"
        # 实验名称键,用于 Neptune 运行日志中的实验名称
        trial_name_key = "trial"
        # 实验参数键,用于 Neptune 运行日志中的实验参数
        trial_params_key = "trial_params"
        # 训练器参数键,用于 Neptune 运行日志中的训练器参数
        trainer_parameters_key = "trainer_parameters"
        # 扁平化指标,用于 Neptune 运行日志中的扁平化指标记录
        flat_metrics = {"train/epoch"}
    
        # NeptuneLogger 类的初始化方法
        def __init__(
            self,
            *,
            api_token: Optional[str] = None,  # Neptune API token,可选参数,用于身份验证
            project: Optional[str] = None,  # Neptune 项目名称,可选参数,指定要记录的项目
            name: Optional[str] = None,  # 自定义运行名称,可选参数,指定 Neptune 运行的名称
            base_namespace: str = "finetuning",  # 基础命名空间,默认为 "finetuning",用于 Neptune 日志的根命名空间
            run=None,  # Neptune 运行对象,可选参数,如果要继续记录到现有运行中
            log_parameters: bool = True,  # 是否记录训练器参数和模型参数的标志,可选参数,默认为 True
            log_checkpoints: Optional[str] = None,  # 检查点记录选项,可选参数,指定何时上传检查点文件
            **neptune_run_kwargs,  # 其他 Neptune 初始化函数的关键字参数,用于创建新的 Neptune 运行时
        ):
    ):
        # 检查 Neptune 是否可用,如果不可用则抛出 ValueError 异常
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires the Neptune client library to be installed. "
                "To install the library, run `pip install neptune`."
            )

        try:
            # 尝试导入 Neptune 相关模块
            from neptune import Run
            from neptune.internal.utils import verify_type
        except ImportError:
            # 如果导入失败,则尝试从新路径导入 Neptune 相关模块
            from neptune.new.internal.utils import verify_type
            from neptune.new.metadata_containers.run import Run

        # 验证参数类型
        verify_type("api_token", api_token, (str, type(None)))
        verify_type("project", project, (str, type(None)))
        verify_type("name", name, (str, type(None)))
        verify_type("base_namespace", base_namespace, str)
        verify_type("run", run, (Run, type(None)))
        verify_type("log_parameters", log_parameters, bool)
        verify_type("log_checkpoints", log_checkpoints, (str, type(None)))

        # 设置内部变量
        self._base_namespace_path = base_namespace
        self._log_parameters = log_parameters
        self._log_checkpoints = log_checkpoints
        self._initial_run: Optional[Run] = run

        # 初始化变量
        self._run = None
        self._is_monitoring_run = False
        self._run_id = None
        self._force_reset_monitoring_run = False
        self._init_run_kwargs = {"api_token": api_token, "project": project, "name": name, **neptune_run_kwargs}

        self._volatile_checkpoints_dir = None
        self._should_upload_checkpoint = self._log_checkpoints is not None
        self._recent_checkpoint_path = None

        # 根据 log_checkpoints 的值设置目标检查点命名空间和是否清理最近上传的检查点
        if self._log_checkpoints in {"last", "best"}:
            self._target_checkpoints_namespace = f"checkpoints/{self._log_checkpoints}"
            self._should_clean_recently_uploaded_checkpoint = True
        else:
            self._target_checkpoints_namespace = "checkpoints"
            self._should_clean_recently_uploaded_checkpoint = False

    # 如果运行实例存在,则停止该运行实例
    def _stop_run_if_exists(self):
        if self._run:
            self._run.stop()
            del self._run
            self._run = None

    # 初始化 Neptune 运行实例
    def _initialize_run(self, **additional_neptune_kwargs):
        try:
            # 尝试从 neptune 包中导入 init_run 和异常处理类
            from neptune import init_run
            from neptune.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException
        except ImportError:
            # 如果导入失败,则尝试从新路径导入 init_run 和异常处理类
            from neptune.new import init_run
            from neptune.new.exceptions import NeptuneMissingApiTokenException, NeptuneMissingProjectNameException

        # 停止已存在的运行实例
        self._stop_run_if_exists()

        try:
            # 创建运行实例的参数集合
            run_params = additional_neptune_kwargs.copy()
            run_params.update(self._init_run_kwargs)
            # 使用参数初始化运行实例,并获取运行实例的 ID
            self._run = init_run(**run_params)
            self._run_id = self._run["sys/id"].fetch()
        except (NeptuneMissingProjectNameException, NeptuneMissingApiTokenException) as e:
            # 如果缺少项目名或 API token,则抛出 NeptuneMissingConfiguration 异常
            raise NeptuneMissingConfiguration() from e
    # 将初始运行设置为当前运行,并开启监控模式
    def _use_initial_run(self):
        self._run = self._initial_run
        self._is_monitoring_run = True
        self._run_id = self._run["sys/id"].fetch()
        self._initial_run = None

    # 确保存在带监控的运行环境
    def _ensure_run_with_monitoring(self):
        if self._initial_run is not None:
            # 如果存在初始运行,则使用初始运行
            self._use_initial_run()
        else:
            if not self._force_reset_monitoring_run and self._is_monitoring_run:
                return

            if self._run and not self._is_monitoring_run and not self._force_reset_monitoring_run:
                # 如果存在运行环境但未开启监控,则重新初始化运行并开启监控
                self._initialize_run(with_id=self._run_id)
                self._is_monitoring_run = True
            else:
                # 否则,初始化一个新的运行环境
                self._initialize_run()
                self._force_reset_monitoring_run = False

    # 确保至少存在一个不带监控的运行环境
    def _ensure_at_least_run_without_monitoring(self):
        if self._initial_run is not None:
            # 如果存在初始运行,则使用初始运行
            self._use_initial_run()
        else:
            if not self._run:
                # 如果没有运行环境,则初始化一个新的运行环境,不捕获 stdout、stderr、硬件指标和 traceback
                self._initialize_run(
                    with_id=self._run_id,
                    capture_stdout=False,
                    capture_stderr=False,
                    capture_hardware_metrics=False,
                    capture_traceback=False,
                )
                self._is_monitoring_run = False

    # 返回当前运行环境
    @property
    def run(self):
        if self._run is None:
            self._ensure_at_least_run_without_monitoring()
        return self._run

    # 返回运行环境的元数据命名空间
    @property
    def _metadata_namespace(self):
        return self.run[self._base_namespace_path]

    # 记录集成版本号到运行环境中
    def _log_integration_version(self):
        self.run[NeptuneCallback.integration_version_key] = version

    # 记录训练器参数到运行环境的元数据命名空间中
    def _log_trainer_parameters(self, args):
        self._metadata_namespace[NeptuneCallback.trainer_parameters_key] = args.to_sanitized_dict()

    # 记录模型参数到运行环境的元数据命名空间中
    def _log_model_parameters(self, model):
        from neptune.utils import stringify_unsupported

        if model and hasattr(model, "config") and model.config is not None:
            self._metadata_namespace[NeptuneCallback.model_parameters_key] = stringify_unsupported(
                model.config.to_dict()
            )

    # 记录超参数搜索参数到运行环境的元数据命名空间中
    def _log_hyper_param_search_parameters(self, state):
        if state and hasattr(state, "trial_name"):
            self._metadata_namespace[NeptuneCallback.trial_name_key] = state.trial_name

        if state and hasattr(state, "trial_params") and state.trial_params is not None:
            self._metadata_namespace[NeptuneCallback.trial_params_key] = state.trial_params
def _log_model_checkpoint(self, source_directory, checkpoint):
    # Join the source directory and the checkpoint name into the target path
    target_path = relative_path = os.path.join(source_directory, checkpoint)

    # 如果存在易失性检查点目录
    if self._volatile_checkpoints_dir is not None:
        # 构建一致性检查点路径
        consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
        try:
            # 从相对路径中移除开头的 ../,并去掉开头的路径分隔符
            cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
            copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
            # 复制整个目录树到一致性检查点路径
            shutil.copytree(relative_path, copy_path)
            # 更新目标路径为一致性检查点路径
            target_path = consistent_checkpoint_path
        except IOError as e:
            # 如果复制过程中出现 I/O 异常,则记录警告信息
            logger.warning(
                "NeptuneCallback was unable to made a copy of checkpoint due to I/O exception: '{}'. "
                "Could fail trying to upload.".format(e)
            )

    # 将目标路径中的文件上传到 Neptune 中
    self._metadata_namespace[self._target_checkpoints_namespace].upload_files(target_path)

    # 如果需要清理最近上传的检查点,并且最近的检查点路径不为 None,则删除它
    if self._should_clean_recently_uploaded_checkpoint and self._recent_checkpoint_path is not None:
        self._metadata_namespace[self._target_checkpoints_namespace].delete_files(self._recent_checkpoint_path)

    # 更新最近的检查点路径为相对路径
    self._recent_checkpoint_path = relative_path

def on_init_end(self, args, state, control, **kwargs):
    # 初始化易失性检查点目录为 None
    self._volatile_checkpoints_dir = None
    # 如果需要记录检查点,并且要求覆盖输出目录或设置了保存总数限制,则创建一个临时目录用于存储检查点
    if self._log_checkpoints and (args.overwrite_output_dir or args.save_total_limit is not None):
        self._volatile_checkpoints_dir = tempfile.TemporaryDirectory().name

    # 如果要求记录最佳检查点但未设置在训练结束时加载最佳模型,则引发 ValueError 异常
    if self._log_checkpoints == "best" and not args.load_best_model_at_end:
        raise ValueError("To save the best model checkpoint, the load_best_model_at_end argument must be enabled.")

def on_train_begin(self, args, state, control, model=None, **kwargs):
    # 如果不是全局进程的主进程,则直接返回
    if not state.is_world_process_zero:
        return

    # 确保在监控下运行
    self._ensure_run_with_monitoring()
    # 强制重置监控运行状态
    self._force_reset_monitoring_run = True

    # 记录集成版本信息
    self._log_integration_version()
    # 如果需要记录参数,则记录训练器参数和模型参数
    if self._log_parameters:
        self._log_trainer_parameters(args)
        self._log_model_parameters(model)

    # 如果是超参数搜索状态,则记录超参数搜索参数
    if state.is_hyper_param_search:
        self._log_hyper_param_search_parameters(state)

def on_train_end(self, args, state, control, **kwargs):
    # 如果存在运行,则停止该运行
    self._stop_run_if_exists()

def __del__(self):
    # 如果存在易失性检查点目录,则删除该目录及其内容,忽略所有错误
    if self._volatile_checkpoints_dir is not None:
        shutil.rmtree(self._volatile_checkpoints_dir, ignore_errors=True)

    # 停止 Neptune 运行,如果存在的话
    self._stop_run_if_exists()

def on_save(self, args, state, control, **kwargs):
    # 如果需要上传检查点,则记录模型检查点
    if self._should_upload_checkpoint:
        self._log_model_checkpoint(args.output_dir, f"checkpoint-{state.global_step}")
    # 定义一个方法,处理评估时的回调函数
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # 如果设置了日志保存最佳模型
        if self._log_checkpoints == "best":
            # 获取用于最佳模型判定的指标名称
            best_metric_name = args.metric_for_best_model
            # 如果指标名称不以"eval_"开头,则添加前缀"eval_"
            if not best_metric_name.startswith("eval_"):
                best_metric_name = f"eval_{best_metric_name}"

            # 获取指定名称的指标值
            metric_value = metrics.get(best_metric_name)

            # 根据参数指定的条件判断函数选择比较操作符
            operator = np.greater if args.greater_is_better else np.less

            # 判断是否应上传检查点,判断标准是当前指标值是否优于之前保存的最佳指标值
            self._should_upload_checkpoint = state.best_metric is None or operator(metric_value, state.best_metric)

    # 类方法:获取训练器关联的 NeptuneCallback 实例的运行配置
    @classmethod
    def get_run(cls, trainer):
        # 遍历训练器回调处理程序中的回调函数
        for callback in trainer.callback_handler.callbacks:
            # 如果回调函数是 NeptuneCallback 的实例,则返回其运行配置
            if isinstance(callback, cls):
                return callback.run

        # 如果没有 NeptuneCallback 配置,抛出异常
        raise Exception("The trainer doesn't have a NeptuneCallback configured.")

    # 定义一个方法,处理记录日志时的回调函数
    def on_log(self, args, state, control, logs: Optional[Dict[str, float]] = None, **kwargs):
        # 如果不是全局进程的主进程,直接返回
        if not state.is_world_process_zero:
            return

        # 如果有日志内容
        if logs is not None:
            # 对每个重写后的日志项进行处理
            for name, value in rewrite_logs(logs).items():
                # 如果值是整数或浮点数
                if isinstance(value, (int, float)):
                    # 如果日志名称在 NeptuneCallback 的平坦指标中
                    if name in NeptuneCallback.flat_metrics:
                        # 将值记录到元数据命名空间中
                        self._metadata_namespace[name] = value
                    else:
                        # 否则,将值记录到元数据命名空间中并指定步骤为全局步骤数
                        self._metadata_namespace[name].log(value, step=state.global_step)
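Besides the automatic logging above, the `get_run` classmethod shown earlier gives direct access to the underlying Neptune run, so extra metadata can be attached from user code. The namespace path and note below are only examples, and `trainer` is assumed to have been created with a `NeptuneCallback` in its callbacks list:

```python
from transformers.integrations import NeptuneCallback

# run = NeptuneCallback.get_run(trainer)
# run["finetuning/extra/notes"] = "baseline run with default hyperparameters"
```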
# 定义一个名为 CodeCarbonCallback 的类,继承自 TrainerCallback 类
class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.
    """

    # 初始化方法
    def __init__(self):
        # 检查是否安装了 codecarbon 库,若未安装则引发运行时错误
        if not is_codecarbon_available():
            raise RuntimeError(
                "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
            )
        # 导入 codecarbon 库
        import codecarbon

        # 将 codecarbon 模块赋值给 self._codecarbon
        self._codecarbon = codecarbon
        # 初始化追踪器为 None
        self.tracker = None

    # 当初始化结束时触发的回调方法
    def on_init_end(self, args, state, control, **kwargs):
        # 如果追踪器为 None 并且是本地进程的第零号进程
        if self.tracker is None and state.is_local_process_zero:
            # 使用指定的输出目录创建 CO2 排放追踪器对象
            self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)

    # 当训练开始时触发的回调方法
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # 如果追踪器存在并且是本地进程的第零号进程
        if self.tracker and state.is_local_process_zero:
            # 启动 CO2 排放追踪器
            self.tracker.start()

    # 当训练结束时触发的回调方法
    def on_train_end(self, args, state, control, **kwargs):
        # 如果追踪器存在并且是本地进程的第零号进程
        if self.tracker and state.is_local_process_zero:
            # 停止 CO2 排放追踪器
            self.tracker.stop()
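Since the tracker writes its report into `args.output_dir`, enabling the callback is just a matter of listing it in `report_to`. A minimal sketch:

```python
from transformers import TrainingArguments

# codecarbon must be installed (`pip install codecarbon`), otherwise __init__ above raises.
args = TrainingArguments(output_dir="out", report_to=["codecarbon"])
# After training, the EmissionsTracker writes its report (typically emissions.csv) to the output directory.
```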


# 定义一个名为 ClearMLCallback 的类,继承自 TrainerCallback 类
class ClearMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [ClearML](https://clear.ml/).

    Environment:
    - **CLEARML_PROJECT** (`str`, *optional*, defaults to `HuggingFace Transformers`):
        ClearML project name.
    - **CLEARML_TASK** (`str`, *optional*, defaults to `Trainer`):
        ClearML task name.
    - **CLEARML_LOG_MODEL** (`bool`, *optional*, defaults to `False`):
        Whether to log models as artifacts during training.
    """

    # 类级别的属性
    log_suffix = ""

    _hparams_section = "Transformers"
    _model_config_section = "Model Configuration"
    _ignore_hparams_overrides = "_ignore_hparams_ui_overrides_"
    _ignoge_model_config_overrides = "_ignore_model_config_ui_overrides_"
    _model_config_description = "The configuration of model number {}."
    _model_config_description_note = (
        "Note that, when cloning this task and running it remotely,"
        " the configuration might be applied to another model instead of this one."
        " To avoid this, initialize the task externally by calling `Task.init`"
        " before the `ClearMLCallback` is instantiated."
    )
    _train_run_counter = 0
    _model_connect_counter = 0
    _task_created_in_callback = False
    _should_close_on_train_end = None

    # 初始化方法
    def __init__(self):
        # 检查是否安装了 clearml 库,若未安装则引发运行时错误
        if is_clearml_available():
            import clearml

            # 导入 clearml 库
            self._clearml = clearml
        else:
            raise RuntimeError("ClearMLCallback requires 'clearml' to be installed. Run `pip install clearml`.")

        # 初始化标志为 False
        self._initialized = False
        # 初始化 ClearML 任务为 None
        self._clearml_task = None

        # 初始化日志模型为 False
        self._log_model = False
        # 初始化检查点保存列表为空列表
        self._checkpoints_saved = []
    # 当训练开始时调用的方法,初始化训练过程中需要用到的参数和模型
    def on_train_begin(self, args, state, control, model=None, tokenizer=None, **kwargs):
        # 如果未初始化 ClearML,直接返回
        if self._clearml is None:
            return
        # 初始化一个空列表来存储保存的检查点文件名
        self._checkpoints_saved = []
        # 如果当前训练是超参数搜索,则标记为未初始化状态
        if state.is_hyper_param_search:
            self._initialized = False
        # 如果未初始化,调用 setup 方法来设置参数、模型和分词器等
        if not self._initialized:
            self.setup(args, state, model, tokenizer, **kwargs)
    
    # 当训练结束时调用的方法,用于清理和关闭 ClearML 相关的任务和计数器
    def on_train_end(self, args, state, control, **kwargs):
        # 如果应该在训练结束时关闭 ClearML 任务,则关闭当前任务
        if ClearMLCallback._should_close_on_train_end:
            self._clearml_task.close()
            # 重置训练运行计数器为零
            ClearMLCallback._train_run_counter = 0
    # 定义一个方法,用于处理日志信息,将其发送到 ClearML 平台
    def on_log(self, args, state, control, model=None, tokenizer=None, logs=None, **kwargs):
        # 如果 ClearML 客户端未初始化,则直接返回
        if self._clearml is None:
            return
        # 如果未初始化,则进行初始化设置
        if not self._initialized:
            self.setup(args, state, model, tokenizer, **kwargs)
        # 如果是全局进程的第一个进程(通常是主进程)
        if state.is_world_process_zero:
            # 定义评估数据的前缀和长度
            eval_prefix = "eval_"
            eval_prefix_len = len(eval_prefix)
            # 定义测试数据的前缀和长度
            test_prefix = "test_"
            test_prefix_len = len(test_prefix)
            # 定义单值标量的列表,这些值通常用于表示单个标量的日志信息
            single_value_scalars = [
                "train_runtime",
                "train_samples_per_second",
                "train_steps_per_second",
                "train_loss",
                "total_flos",
                "epoch",
            ]
            # 遍历日志中的每个键值对
            for k, v in logs.items():
                # 如果值 v 是整数或浮点数
                if isinstance(v, (int, float)):
                    # 如果键 k 在单值标量列表中,则将其作为单值报告到 ClearML
                    if k in single_value_scalars:
                        self._clearml_task.get_logger().report_single_value(
                            name=k + ClearMLCallback.log_suffix, value=v
                        )
                    # 如果键 k 以评估数据前缀开头,则将其作为评估数据报告到 ClearML
                    elif k.startswith(eval_prefix):
                        self._clearml_task.get_logger().report_scalar(
                            title="eval" + ClearMLCallback.log_suffix,
                            series=k[eval_prefix_len:],
                            value=v,
                            iteration=state.global_step,
                        )
                    # 如果键 k 以测试数据前缀开头,则将其作为测试数据报告到 ClearML
                    elif k.startswith(test_prefix):
                        self._clearml_task.get_logger().report_scalar(
                            title="test" + ClearMLCallback.log_suffix,
                            series=k[test_prefix_len:],
                            value=v,
                            iteration=state.global_step,
                        )
                    # 否则,将其作为训练数据报告到 ClearML
                    else:
                        self._clearml_task.get_logger().report_scalar(
                            title="train" + ClearMLCallback.log_suffix,
                            series=k,
                            value=v,
                            iteration=state.global_step,
                        )
                else:
                    # 如果值 v 的类型不是整数或浮点数,则记录警告信息
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of ClearML logger's  report_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
    # 定义一个保存模型回调函数,根据特定条件执行保存操作
    def on_save(self, args, state, control, **kwargs):
        # 如果启用模型日志、存在 ClearML 任务并且当前进程为主进程
        if self._log_model and self._clearml_task and state.is_world_process_zero:
            # 根据全局步数创建检查点目录名
            ckpt_dir = f"checkpoint-{state.global_step}"
            # 构建检查点在文件系统中的路径
            artifact_path = os.path.join(args.output_dir, ckpt_dir)
            # 定义保存的模型名,包括 ClearML 日志后缀
            name = ckpt_dir + ClearMLCallback.log_suffix
            # 输出日志,指示正在记录检查点信息
            logger.info(f"Logging checkpoint artifact `{name}`. This may take some time.")
            # 创建 ClearML 的 OutputModel 对象,关联到当前任务并设置名称
            output_model = self._clearml.OutputModel(task=self._clearml_task, name=name)
            output_model.connect(task=self._clearml_task, name=name)
            # 更新模型权重包,将指定路径的权重文件打包,指定迭代次数,并禁止自动删除文件
            output_model.update_weights_package(
                weights_path=artifact_path,
                target_filename=ckpt_dir,
                iteration=state.global_step,
                auto_delete_file=False,
            )
            # 将保存的模型对象添加到检查点列表中
            self._checkpoints_saved.append(output_model)
            # 当设置了保存总数限制并且当前保存的检查点数量超过限制时执行
            while args.save_total_limit and args.save_total_limit < len(self._checkpoints_saved):
                try:
                    # 尝试移除最早的检查点模型及其关联的权重文件
                    self._clearml.model.Model.remove(
                        self._checkpoints_saved[0],
                        delete_weights_file=True,
                        force=True,
                        raise_on_errors=True,
                    )
                except Exception as e:
                    # 记录警告,指示在超过保存限制后无法移除检查点的错误信息
                    logger.warning(
                        "Could not remove checkpoint `{}` after going over the `save_total_limit`. Error is: {}".format(
                            self._checkpoints_saved[0].name, e
                        )
                    )
                    # 中断循环,保持检查点列表不变
                    break
                # 移除成功后,更新检查点列表,去除最早的一个检查点对象
                self._checkpoints_saved = self._checkpoints_saved[1:]

    # 将训练参数复制为超参数,并将其传递给 ClearML 任务
    def _copy_training_args_as_hparams(self, training_args, prefix):
        # 将训练参数对象中的每个字段转换为字典,排除以 "_token" 结尾的字段
        as_dict = {
            field.name: getattr(training_args, field.name)
            for field in fields(training_args)
            if field.init and not field.name.endswith("_token")
        }
        # 扁平化字典,将所有键转换为字符串
        flat_dict = {str(k): v for k, v in self._clearml.utilities.proxy_object.flatten_dictionary(as_dict).items()}
        # 将扁平化后的字典作为超参数设置到 ClearML 任务中
        self._clearml_task._arguments.copy_from_dict(flat_dict, prefix=prefix)
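As documented in the class docstring, the callback is configured through environment variables. A minimal sketch of turning on model/checkpoint upload; the values are illustrative:

```python
import os

from transformers import TrainingArguments

os.environ["CLEARML_PROJECT"] = "HuggingFace Transformers"  # ClearML project name
os.environ["CLEARML_TASK"] = "Trainer"                      # ClearML task name
os.environ["CLEARML_LOG_MODEL"] = "TRUE"                    # enables the on_save artifact logging above

args = TrainingArguments(output_dir="out", report_to=["clearml"], save_total_limit=2)
# With save_total_limit set, on_save above also prunes the oldest uploaded checkpoints.
```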
# 定义一个名为 FlyteCallback 的类,继承自 TrainerCallback 类
class FlyteCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
    NOTE: This callback only works within a Flyte task.

    Args:
        save_log_history (`bool`, *optional*, defaults to `True`):
            When set to True, the training logs are saved as a Flyte Deck.

        sync_checkpoints (`bool`, *optional*, defaults to `True`):
            When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
            interruption.

    Example:

    ```
    # Note: This example skips over some setup steps for brevity.
    from flytekit import current_context, task


    @task
    def train_hf_transformer():
        cp = current_context().checkpoint
        trainer = Trainer(..., callbacks=[FlyteCallback()])
        output = trainer.train(resume_from_checkpoint=cp.restore())
    ```
    """

    # 初始化方法,接受两个可选参数:save_log_history 和 sync_checkpoints
    def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
        # 调用父类的初始化方法
        super().__init__()
        
        # 检查 flytekit 是否可用,如果不可用则抛出 ImportError
        if not is_flytekit_available():
            raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")
        
        # 检查是否安装了 flytekitplugins-deck-standard 和 pandas,如果未安装,则记录警告并将 save_log_history 设置为 False
        if not is_flyte_deck_standard_available() or not is_pandas_available():
            logger.warning(
                "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
                "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
            )
            save_log_history = False
        
        # 导入当前上下文的 checkpoint 对象
        from flytekit import current_context
        self.cp = current_context().checkpoint
        
        # 初始化实例变量
        self.save_log_history = save_log_history
        self.sync_checkpoints = sync_checkpoints

    # 在保存方法回调时执行的操作
    def on_save(self, args, state, control, **kwargs):
        # 如果 sync_checkpoints 为 True,并且当前状态的全局进程是零(即主进程)
        if self.sync_checkpoints and state.is_world_process_zero:
            # 构建检查点目录和存储路径
            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)

            # 记录信息,将检查点同步到 Flyte。这可能需要一些时间。
            logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
            self.cp.save(artifact_path)

    # 在训练结束时执行的操作
    def on_train_end(self, args, state, control, **kwargs):
        # 如果 save_log_history 为 True
        if self.save_log_history:
            # 导入 pandas、Deck 类以及 TableRenderer
            import pandas as pd
            from flytekit import Deck
            from flytekitplugins.deck.renderer import TableRenderer

            # 创建日志历史的 DataFrame
            log_history_df = pd.DataFrame(state.log_history)
            
            # 创建一个名为 "Log History" 的 Flyte Deck,使用 TableRenderer 将 DataFrame 转换为 HTML 格式
            Deck("Log History", TableRenderer().to_html(log_history_df))


# The DVCLiveCallback class
class DVCLiveCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [DVCLive](https://www.dvc.org/doc/dvclive).

    Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
    those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

    Args:
        live (`dvclive.Live`, *optional*, defaults to `None`):
            Optional Live instance. If None, a new instance will be created using **kwargs.
        log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
            the final checkpoint is logged at the end of training. If set to `"all"`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
    """
    
    # DVCLiveCallback 类的初始化方法,用于设置 DVCLive 相关的参数和实例
    def __init__(
        self,
        live: Optional[Any] = None,
        log_model: Optional[Union[Literal["all"], bool]] = None,
        **kwargs,
    ):
        # 检查 dvclive 是否可用,如果不可用则抛出运行时错误
        if not is_dvclive_available():
            raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
        # 导入 dvclive.Live
        from dvclive import Live

        # 初始化实例变量
        self._initialized = False
        self.live = None
        
        # 根据 live 参数的类型来设置 self.live 实例
        if isinstance(live, Live):
            self.live = live
        elif live is not None:
            raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")

        # 设置日志模型的方式
        self._log_model = log_model
        if self._log_model is None:
            # 从环境变量 HF_DVCLIVE_LOG_MODEL 获取日志模型设置
            log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
            if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
                self._log_model = True
            elif log_model_env.lower() == "all":
                self._log_model = "all"

    # 设置 DVCLiveCallback 的初始化状态,并在主进程中初始化 dvclive.Live 实例并记录参数
    def setup(self, args, state, model):
        """
        Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
        [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

        Environment:
        - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
            *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint.
        """
        # 导入 dvclive.Live
        from dvclive import Live

        # 设置初始化状态为 True
        self._initialized = True
        
        # 如果是主进程中的第一个进程,则初始化 dvclive.Live 实例并记录参数
        if state.is_world_process_zero:
            if not self.live:
                self.live = Live()
            self.live.log_params(args.to_dict())

    # 在训练开始时检查是否初始化,如果未初始化则调用 setup 方法初始化
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
    # 当日志事件发生时调用,处理日志相关操作
    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        # 如果对象尚未初始化,则进行初始化设置
        if not self._initialized:
            self.setup(args, state, model)
        # 如果是全局进程中的主进程
        if state.is_world_process_zero:
            # 导入必要的库:Metric 类和标准化指标名称的工具函数
            from dvclive.plots import Metric
            from dvclive.utils import standardize_metric_name

            # 遍历日志中的键值对
            for key, value in logs.items():
                # 检查当前值是否可记录为 Metric
                if Metric.could_log(value):
                    # 使用标准化的名称记录指标到 DVCLive
                    self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
                else:
                    # 如果记录的值不符合 Metric 要求,发出警告并丢弃该属性
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{value}" of type {type(value)} for key "{key}" as a scalar. '
                        "This invocation of DVCLive's Live.log_metric() "
                        "is incorrect so we dropped this attribute."
                    )
            # 在 DVCLive 中记录下一个步骤
            self.live.next_step()

    # 当保存事件发生时调用,处理保存模型相关操作
    def on_save(self, args, state, control, **kwargs):
        # 如果设置为保存所有模型,并且对象已初始化且是全局进程中的主进程
        if self._log_model == "all" and self._initialized and state.is_world_process_zero:
            # 将输出目录作为 artifact 记录到 DVCLive 中
            self.live.log_artifact(args.output_dir)

    # 当训练结束事件发生时调用,处理训练结束相关操作
    def on_train_end(self, args, state, control, **kwargs):
        # 如果对象已初始化且是全局进程中的主进程
        if self._initialized and state.is_world_process_zero:
            # 导入 Transformers 库中的 Trainer 类
            from transformers.trainer import Trainer

            # 如果设置为保存模型
            if self._log_model is True:
                # 创建一个虚拟 Trainer 对象用于保存模型
                fake_trainer = Trainer(args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer"))
                # 根据设置选择保存最佳模型还是最后模型
                name = "best" if args.load_best_model_at_end else "last"
                output_dir = os.path.join(args.output_dir, name)
                # 保存模型到指定目录
                fake_trainer.save_model(output_dir)
                # 将保存的模型目录作为 artifact 记录到 DVCLive 中
                self.live.log_artifact(output_dir, name=name, type="model", copy=True)
            # 在 DVCLive 中结束记录
            self.live.end()
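Mirroring the `HF_DVCLIVE_LOG_MODEL` handling in `__init__` and `setup`, the callback can either be driven by the environment or be given a preconfigured `dvclive.Live` instance. A short sketch; the `trainer` in the commented-out line is assumed to exist:

```python
import os

from dvclive import Live

from transformers.integrations import DVCLiveCallback

os.environ["HF_DVCLIVE_LOG_MODEL"] = "TRUE"  # log the final checkpoint in on_train_end above

live = Live()                       # or pass a Live instance you configured yourself
callback = DVCLiveCallback(live=live)
# trainer.add_callback(callback)
```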
# 定义一个映射,将集成名称映射到相应的回调类
INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
    "clearml": ClearMLCallback,
    "dagshub": DagsHubCallback,
    "flyte": FlyteCallback,
    "dvclive": DVCLiveCallback,
}

# 根据给定的报告集成列表,返回对应的回调类列表
def get_reporting_integration_callbacks(report_to):
    # 遍历报告集成列表中的每个集成
    for integration in report_to:
        # 如果集成不在预定义的映射中,则引发 ValueError 异常
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )

    # 返回一个包含各个集成对应的回调类的列表
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
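This mapping is what translates the `report_to` field of `TrainingArguments` into callback classes. For example (the printed output is abbreviated in the comment, and the order follows the input list):

```python
from transformers.integrations import get_reporting_integration_callbacks

callbacks = get_reporting_integration_callbacks(["tensorboard", "wandb"])
print(callbacks)
# [<class '...TensorBoardCallback'>, <class '...WandbCallback'>]

# get_reporting_integration_callbacks(["unknown"])  # would raise ValueError listing the supported integrations
```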

.\integrations\peft.py

# 导入所需的模块和函数
import inspect  # 导入 inspect 模块,用于检查和分析 Python 对象的属性和结构
import warnings  # 导入 warnings 模块,用于管理警告信息
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union  # 从 typing 模块导入类型提示相关的类型

from ..utils import (
    check_peft_version,  # 导入 check_peft_version 函数,用于检查 PEFT 版本
    find_adapter_config_file,  # 导入 find_adapter_config_file 函数,用于查找适配器配置文件
    is_accelerate_available,  # 导入 is_accelerate_available 函数,用于检查是否可用 accelerate 库
    is_peft_available,  # 导入 is_peft_available 函数,用于检查是否可用 PEFT 库
    is_torch_available,  # 导入 is_torch_available 函数,用于检查是否可用 torch 库
    logging,  # 导入 logging 模块,用于记录日志
)


if is_accelerate_available():  # 如果 accelerate 库可用,则执行以下导入
    from accelerate import dispatch_model  # 导入 dispatch_model 函数,用于调度模型
    from accelerate.utils import get_balanced_memory, infer_auto_device_map  # 导入 get_balanced_memory 和 infer_auto_device_map 函数,用于内存管理和设备映射推断

# PEFT 集成所需的最低版本
MIN_PEFT_VERSION = "0.5.0"

if TYPE_CHECKING:  # 如果是类型检查阶段,则执行以下导入
    if is_torch_available():  # 如果 torch 库可用,则导入 torch 库
        import torch  # 导入 torch 库

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapter weights that are supported by the PEFT library.
    For more details about adapters and how to inject them into transformer-based models, check out the PEFT library
    documentation: https://huggingface.co/docs/peft/index

    Currently supported PEFT methods are all the non prefix-tuning methods. Below is the list of supported PEFT
    methods that anyone can load, train and run with this mixin class:
    - Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora
    - IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3
    - AdaLora: https://arxiv.org/abs/2303.10512

    Other PEFT models such as prompt tuning and prompt learning are out of scope, as those adapters cannot be
    "injected" into a torch module. To use those methods, please refer to the usage guide of the PEFT library.

    With this mixin, if the correct PEFT version is installed, it is possible to:
    - Load an adapter stored on a local path or in a remote Hub repository, and inject it into the model
    - Attach new adapters to the model and train them with Trainer or with your own training loop
    - Attach multiple adapters and iteratively activate / deactivate them
    - Activate / deactivate all adapters from the model
    - Get the `state_dict` of the active adapter

    _hf_peft_config_loaded = False  # 初始配置标志,指示 PEFT 配置是否已加载
    def load_adapter(
        self,
        peft_model_id: Optional[str] = None,
        adapter_name: Optional[str] = None,
        revision: Optional[str] = None,
        token: Optional[str] = None,
        device_map: Optional[str] = "auto",
        max_memory: Optional[str] = None,
        offload_folder: Optional[str] = None,
        offload_index: Optional[int] = None,
        peft_config: Dict[str, Any] = None,
        adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
        adapter_kwargs: Optional[Dict[str, Any]] = None,
    ):
        ...

    def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
        r"""
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default
        name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the
        default adapter name).

        Args:
            adapter_config (`~peft.PeftConfig`):
                The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
                methods
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
        """
        # 检查 PEFT 版本是否符合最低要求
        check_peft_version(min_version=MIN_PEFT_VERSION)

        # 导入 PEFT 配置和适配器注入函数
        from peft import PeftConfig, inject_adapter_in_model

        # 如果未提供适配器名称,则使用默认名称 "default"
        adapter_name = adapter_name or "default"

        # 如果 PEFT 配置未加载,则标记为已加载
        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        # 如果同名适配器已存在,则抛出 ValueError 异常
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        # 如果 adapter_config 不是 PeftConfig 的实例,则抛出 ValueError 异常
        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        # 获取模型的名称或路径,以保持与 PEFT 中的一致性
        adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None)

        # 将适配器注入到模型中
        inject_adapter_in_model(adapter_config, self, adapter_name)

        # 设置当前模型的适配器
        self.set_adapter(adapter_name)
    def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.

        Args:
            adapter_name (`Union[List[str], str]`):
                The name of the adapter to set. Can be also a list of strings to set multiple adapters.
        """
        # 检查 PEFT 的最小版本要求
        check_peft_version(min_version=MIN_PEFT_VERSION)
        
        # 如果尚未加载 PEFT 配置,则引发 ValueError
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
        
        # 如果 adapter_name 是一个列表,检查列表中的适配器是否存在于当前配置中
        elif isinstance(adapter_name, list):
            missing = set(adapter_name) - set(self.peft_config)
            if len(missing) > 0:
                raise ValueError(
                    f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
                    f" current loaded adapters are: {list(self.peft_config.keys())}"
                )
        
        # 如果 adapter_name 不在当前配置中,引发 ValueError
        elif adapter_name not in self.peft_config:
            raise ValueError(
                f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
            )

        # 导入 PEFT 中必要的模块
        from peft.tuners.tuners_utils import BaseTunerLayer
        from peft.utils import ModulesToSaveWrapper

        # 标记是否成功设置了适配器
        _adapters_has_been_set = False

        # 遍历模型的所有模块
        for _, module in self.named_modules():
            # 如果模块是 BaseTunerLayer 或 ModulesToSaveWrapper 的实例
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                # 对于兼容旧版 PEFT 的情况,检查是否有 set_adapter 方法
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                else:
                    # 否则直接设置 active_adapter 属性
                    module.active_adapter = adapter_name
                _adapters_has_been_set = True

        # 如果没有成功设置适配器,引发 ValueError
        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
            )
    def disable_adapters(self) -> None:
        r"""
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Disable all adapters that are attached to the model. This leads to inferring with the base model only.
        """
        # 检查 PEFT 版本是否符合要求
        check_peft_version(min_version=MIN_PEFT_VERSION)

        # 如果 PEFT 配置未加载,则抛出数值错误异常
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        # 导入必要的 PEFT 模块
        from peft.tuners.tuners_utils import BaseTunerLayer
        from peft.utils import ModulesToSaveWrapper

        # 遍历模型的所有模块
        for _, module in self.named_modules():
            # 检查模块是否属于 BaseTunerLayer 或 ModulesToSaveWrapper 类型
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                # 如果模块具有 enable_adapters 方法,则调用以禁用适配器
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=False)
                else:
                    # 否则,将模块的 disable_adapters 属性设置为 True
                    module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Enable adapters that are attached to the model. The model will use `self.active_adapter()`
        """
        # 检查 PEFT 版本是否符合要求
        check_peft_version(min_version=MIN_PEFT_VERSION)

        # 如果 PEFT 配置未加载,则抛出数值错误异常
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        # 导入必要的 PEFT 模块
        from peft.tuners.tuners_utils import BaseTunerLayer

        # 遍历模型的所有模块
        for _, module in self.named_modules():
            # 检查模块是否属于 BaseTunerLayer 类型
            if isinstance(module, BaseTunerLayer):
                # 如果模块具有 enable_adapters 方法,则调用以启用适配器
                if hasattr(module, "enable_adapters"):
                    module.enable_adapters(enabled=True)
                else:
                    # 否则,将模块的 disable_adapters 属性设置为 False
                    module.disable_adapters = False
    def active_adapters(self) -> List[str]:
        """
        Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple
        adapters for inference), this returns the list of all active adapters so that users can deal with them
        accordingly.

        For previous PEFT versions (that do not support multi-adapter inference), `module.active_adapter` will
        return a single string.
        """
        # 检查 PEFT 版本是否符合最低要求
        check_peft_version(min_version=MIN_PEFT_VERSION)

        # 检查 PEFT 是否可用,如果不可用则抛出 ImportError
        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        # 如果没有加载 PEFT 配置,抛出 ValueError
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        # 导入 PEFT 的 BaseTunerLayer
        from peft.tuners.tuners_utils import BaseTunerLayer

        # 遍历模型的所有子模块,查找 BaseTunerLayer 类型的模块,获取其活跃适配器
        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                active_adapters = module.active_adapter
                break

        # 对于之前的 PEFT 版本,确保 active_adapters 是列表类型
        if isinstance(active_adapters, str):
            active_adapters = [active_adapters]

        # 返回活跃适配器列表
        return active_adapters

    def active_adapter(self) -> str:
        """
        Warning: the `active_adapter` method is deprecated and will be removed in a future version.
        """
        # 发出警告:方法已弃用
        warnings.warn(
            "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
        )

        # 返回当前活跃适配器列表中的第一个适配器
        return self.active_adapters()[0]

    def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
        """
        Gets the adapter state dict, which should only contain the weights tensors of the specified adapter_name
        adapter. If no adapter_name is passed, the active adapter is used.

        Args:
            adapter_name (`str`, *optional*):
                The name of the adapter whose state dict to return. If no name is passed, the active adapter is used.
        """
        # 检查 PEFT 版本是否符合最低要求
        check_peft_version(min_version=MIN_PEFT_VERSION)

        # 如果没有加载 PEFT 配置,抛出 ValueError
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        # 导入 PEFT 的 get_peft_model_state_dict 函数
        from peft import get_peft_model_state_dict

        # 如果未传适配器名称,使用当前的活跃适配器
        if adapter_name is None:
            adapter_name = self.active_adapter()

        # 获取指定适配器名称的状态字典
        adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
        return adapter_state_dict

    def _dispatch_accelerate_model(
        self,
        device_map: str,
        max_memory: Optional[int] = None,
        offload_folder: Optional[str] = None,
        offload_index: Optional[int] = None,
    ) -> None:
        """
        Optional re-dispatch the model and attach new hooks to the model in case the model has been loaded with
        accelerate (i.e. with `device_map=xxx`)

        Args:
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
                like `1`) on which the model will be allocated, the device map will map the entire model to this
                device. Passing `device_map = 0` means put the whole model on GPU 0.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
                GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
            offload_index (`int`, *optional*):
                The offload_index argument to be passed to `accelerate.dispatch_model` method.
        """
        # Prepare arguments for dispatching the model
        dispatch_model_kwargs = {}

        # Safety checker for previous `accelerate` versions
        # Check if `offload_index` is supported by the `dispatch_model` function
        if "offload_index" in inspect.signature(dispatch_model).parameters:
            dispatch_model_kwargs["offload_index"] = offload_index

        # Get the list of module classes that should not be split during dispatch
        no_split_module_classes = self._no_split_modules

        # Calculate balanced memory allocation if device_map is not "sequential"
        if device_map != "sequential":
            max_memory = get_balanced_memory(
                self,
                max_memory=max_memory,
                no_split_module_classes=no_split_module_classes,
                low_zero=(device_map == "balanced_low_0"),
            )

        # Infer an automatic device_map if device_map is a string
        if isinstance(device_map, str):
            device_map = infer_auto_device_map(
                self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
            )

        # Dispatch the model with the specified parameters
        dispatch_model(
            self,
            device_map=device_map,
            offload_dir=offload_folder,
            **dispatch_model_kwargs,
        )
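
For context, the hedged sketch below shows the user-facing path that ends up calling `_dispatch_accelerate_model`: loading a model with a `device_map` attaches accelerate hooks, and loading an adapter afterwards triggers a re-dispatch. It assumes `accelerate` is installed and a GPU is available; the model and adapter names are illustrative.

```
# Sketch only: requires `accelerate` and at least one GPU; names are illustrative.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    device_map="auto",  # accelerate computes a device map and attaches dispatch hooks
)
# Loading an adapter on an accelerate-dispatched model re-runs dispatch internally
# (via _dispatch_accelerate_model) so the new weights land on the right devices.
model.load_adapter("ybelkada/opt-350m-lora")
print(model.hf_device_map)  # inferred module -> device mapping
```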

.\integrations\quanto.py

# Import the is_torch_available helper to check whether torch can be imported
from ..utils import is_torch_available

# Import torch only if it is available
if is_torch_available():
    import torch

# replace_with_quanto_layers recursively replaces the Linear layers of the given model with Quanto quantized layers
# and returns the converted model together with a boolean indicating whether the conversion succeeded
def replace_with_quanto_layers(
    model,  # the model to convert, any torch.nn.Module instance
    quantization_config=None,  # quantization config object holding the quantization parameters, defaults to None
    modules_to_not_convert=None,  # list of module names to skip, defaults to None
    current_key_name=None,  # list of current key names, used for recursion and should not be passed by the user
    has_been_replaced=False,  # whether a replacement has happened, used for recursion and should not be passed by the user
):
    """
    Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
    Returns the converted model and a boolean that indicates if the conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`QuantoConfig`, defaults to `None`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*, defaults to `None`):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*, defaults to `False`):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """

    # Import init_empty_weights from accelerate
    from accelerate import init_empty_weights
    # Import the quantized layer classes and dtypes from quanto
    from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

    # Mapping tables from config strings to quanto weight / activation dtypes
    w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
    a_mapping = {None: None, "float8": qfloat8, "int8": qint8}

    # Default to an empty skip list if modules_to_not_convert is None
    if modules_to_not_convert is None:
        modules_to_not_convert = []
    # Iterate over each direct child module of the model, getting its name and object
    for name, module in model.named_children():
        # Initialize the current key name list if it is None
        if current_key_name is None:
            current_key_name = []
        # Append the name of the current child module to the key name list
        current_key_name.append(name)

        # Only convert the module if its fully qualified name does not match any entry in the skip list
        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            # Create the replacement layers with empty (meta) weights
            with init_empty_weights():
                # If the module is a Linear layer
                if isinstance(module, torch.nn.Linear):
                    # Replace it with a quantized QLinear layer
                    model._modules[name] = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=w_mapping[quantization_config.weights],
                        activations=a_mapping[quantization_config.activations],
                    )
                    # The new module does not require gradients
                    model._modules[name].requires_grad_(False)
                    # Mark that a replacement happened
                    has_been_replaced = True
                # If the module is a LayerNorm layer
                elif isinstance(module, torch.nn.LayerNorm):
                    # Only replace it with a quantized QLayerNorm if activation quantization is configured
                    if quantization_config.activations is not None:
                        model._modules[name] = QLayerNorm(
                            module.normalized_shape,
                            module.eps,
                            module.elementwise_affine,
                            module.bias is not None,
                            activations=a_mapping[quantization_config.activations],
                        )
                        # Mark that a replacement happened
                        has_been_replaced = True

        # If the current module has children, recurse into them to replace their layers as well
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_quanto_layers(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )

        # After handling this child, pop its key so the next sibling is processed with a clean prefix
        current_key_name.pop(-1)

    # Return the converted model and the flag indicating whether any replacement took place
    return model, has_been_replaced
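
A hedged usage sketch of `replace_with_quanto_layers` on a toy model (assumes `quanto` and `accelerate` are installed; `QuantoConfig` is the transformers quantization config with `weights`/`activations` fields). Note that the replacement layers are created with empty (meta) weights, so a real checkpoint still has to be loaded before inference.

```
# Sketch only: requires `quanto` and `accelerate`; the toy model is illustrative.
import torch
from transformers import QuantoConfig
from transformers.integrations.quanto import replace_with_quanto_layers

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.LayerNorm(32),
    torch.nn.Linear(32, 4),
)
config = QuantoConfig(weights="int8", activations=None)

model, replaced = replace_with_quanto_layers(model, quantization_config=config)
print(replaced)  # True: both Linear layers are now quanto QLinear modules
# The new layers were created under init_empty_weights(), so their weights live on the
# meta device until a real checkpoint is loaded into them.
```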

.\integrations\tpu.py

# Import DataLoader from torch.utils.data and the is_torch_xla_available helper from ..utils
from torch.utils.data import DataLoader
from ..utils import is_torch_xla_available

# tpu_spmd_dataloader takes a DataLoader and, on XLA, shards its input along the fsdp mesh axis
def tpu_spmd_dataloader(dataloader: DataLoader):
    # If Torch XLA is available
    if is_torch_xla_available():
        # Import torch_xla.distributed.parallel_loader as pl
        import torch_xla.distributed.parallel_loader as pl

        # The dataloader must already be a torch_xla MpDeviceLoader
        assert isinstance(
            dataloader, pl.MpDeviceLoader
        ), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."

        # This supports PyTorch/XLA FSDP via SPMD:
        # dimension 0 of the input data is sharded along the fsdp axis.

        # Import torch_xla.distributed.spmd as xs
        import torch_xla.distributed.spmd as xs

        # Build a ShardingSpec over the global mesh that shards along the fsdp axis
        sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))

        # Pass the sharding spec as the input_sharding argument of the parallel loader
        dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec

        # Return the modified dataloader
        return dataloader
    else:
        # If Torch XLA is not available, return the original dataloader unchanged
        return dataloader
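
A hedged sketch of how this helper might be used. It only runs on a TPU host with `torch_xla` installed and SPMD enabled; the mesh layout and sizes below are illustrative, not the exact setup used by the Trainer.

```
# Sketch only: runs on a TPU host with torch_xla installed and SPMD enabled.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
from transformers.integrations.tpu import tpu_spmd_dataloader

xr.use_spmd()                                    # switch the runtime to SPMD mode
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "tensor"))  # illustrative mesh
xs.set_global_mesh(mesh)                         # get_global_mesh() reads this

loader = DataLoader(TensorDataset(torch.randn(64, 16)), batch_size=8)
mp_loader = pl.MpDeviceLoader(loader, xm.xla_device())  # input must be an MpDeviceLoader
sharded_loader = tpu_spmd_dataloader(mp_loader)         # batches now sharded along the fsdp axis
```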

.\integrations\__init__.py

# Copyright and license notice
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import the TYPE_CHECKING flag
from typing import TYPE_CHECKING

# Import the _LazyModule class from this package's utils module
from ..utils import _LazyModule

# Define the module's lazy import structure: each submodule maps to the public names it exposes
_import_structure = {
    "aqlm": ["replace_with_aqlm_linear"],
    "awq": [
        "fuse_awq_modules",
        "post_init_awq_exllama_modules",
        "replace_with_awq_linear",
    ],
    "bitsandbytes": [
        "get_keys_to_not_convert",
        "replace_8bit_linear",
        "replace_with_bnb_linear",
        "set_module_8bit_tensor_to_device",
        "set_module_quantized_tensor_to_device",
    ],
    "deepspeed": [
        "HfDeepSpeedConfig",
        "HfTrainerDeepSpeedConfig",
        "deepspeed_config",
        "deepspeed_init",
        "deepspeed_load_checkpoint",
        "deepspeed_optim_sched",
        "is_deepspeed_available",
        "is_deepspeed_zero3_enabled",
        "set_hf_deepspeed_config",
        "unset_hf_deepspeed_config",
    ],
    "integration_utils": [
        "INTEGRATION_TO_CALLBACK",
        "AzureMLCallback",
        "ClearMLCallback",
        "CodeCarbonCallback",
        "CometCallback",
        "DagsHubCallback",
        "DVCLiveCallback",
        "FlyteCallback",
        "MLflowCallback",
        "NeptuneCallback",
        "NeptuneMissingConfiguration",
        "TensorBoardCallback",
        "WandbCallback",
        "get_available_reporting_integrations",
        "get_reporting_integration_callbacks",
        "hp_params",
        "is_azureml_available",
        "is_clearml_available",
        "is_codecarbon_available",
        "is_comet_available",
        "is_dagshub_available",
        "is_dvclive_available",
        "is_flyte_deck_standard_available",
        "is_flytekit_available",
        "is_mlflow_available",
        "is_neptune_available",
        "is_optuna_available",
        "is_ray_available",
        "is_ray_tune_available",
        "is_sigopt_available",
        "is_tensorboard_available",
        "is_wandb_available",
        "rewrite_logs",
        "run_hp_search_optuna",
        "run_hp_search_ray",
        "run_hp_search_sigopt",
        "run_hp_search_wandb",
    ],
    "peft": ["PeftAdapterMixin"],
    "quanto": ["replace_with_quanto_layers"],
}

# When type checking, import the names directly so static analyzers can resolve them
if TYPE_CHECKING:
    from .aqlm import replace_with_aqlm_linear
    from .awq import (
        fuse_awq_modules,
        post_init_awq_exllama_modules,
        replace_with_awq_linear,
    )
    from .bitsandbytes import (
        get_keys_to_not_convert,
        replace_8bit_linear,
        replace_with_bnb_linear,
        set_module_8bit_tensor_to_device,
        set_module_quantized_tensor_to_device,
    )
    from .deepspeed import (
        HfDeepSpeedConfig,
        HfTrainerDeepSpeedConfig,
        deepspeed_config,
        deepspeed_init,
        deepspeed_load_checkpoint,
        deepspeed_optim_sched,
        is_deepspeed_available,
        is_deepspeed_zero3_enabled,
        set_hf_deepspeed_config,
        unset_hf_deepspeed_config,
    )
    from .integration_utils import (
        INTEGRATION_TO_CALLBACK,
        AzureMLCallback,
        ClearMLCallback,
        CodeCarbonCallback,
        CometCallback,
        DagsHubCallback,
        DVCLiveCallback,
        FlyteCallback,
        MLflowCallback,
        NeptuneCallback,
        NeptuneMissingConfiguration,
        TensorBoardCallback,
        WandbCallback,
        get_available_reporting_integrations,
        get_reporting_integration_callbacks,
        hp_params,
        is_azureml_available,
        is_clearml_available,
        is_codecarbon_available,
        is_comet_available,
        is_dagshub_available,
        is_dvclive_available,
        is_flyte_deck_standard_available,
        is_flytekit_available,
        is_mlflow_available,
        is_neptune_available,
        is_optuna_available,
        is_ray_available,
        is_ray_tune_available,
        is_sigopt_available,
        is_tensorboard_available,
        is_wandb_available,
        rewrite_logs,
        run_hp_search_optuna,
        run_hp_search_ray,
        run_hp_search_sigopt,
        run_hp_search_wandb,
    )

    from .peft import PeftAdapterMixin
    from .quanto import replace_with_quanto_layers
else:
    # Otherwise register this module as a lazy module: submodules are only imported on first attribute access
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
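
The `else` branch above is what makes `import transformers.integrations` cheap: submodules are only imported when one of their names is first accessed. Below is a minimal standalone sketch of that idea using module-level `__getattr__` (PEP 562); it is not transformers' `_LazyModule` implementation, and the submodule name is hypothetical.

```
# Standalone sketch of a lazy package __init__.py using PEP 562 module __getattr__;
# "heavy_submodule" / "expensive_function" are hypothetical names.
import importlib

_import_structure = {
    "heavy_submodule": ["expensive_function"],
}
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}


def __getattr__(name):
    if name in _name_to_module:
        # Import the owning submodule only now, then fetch the requested attribute from it
        module = importlib.import_module(f".{_name_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```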

.\keras_callbacks.py

import logging  # standard logging module
import os  # operating system utilities
from pathlib import Path  # filesystem path handling
from time import sleep  # used to wait for pending Hub uploads
from typing import Callable, List, Optional, Union  # type hints

import numpy as np  # NumPy
import tensorflow as tf  # TensorFlow
from huggingface_hub import Repository, create_repo  # Hugging Face Hub helpers
from packaging.version import parse  # version parsing

from . import IntervalStrategy, PreTrainedTokenizerBase  # classes from the current package
from .modelcard import TrainingSummary  # training summary used to build model cards
from .modeling_tf_utils import keras  # Keras re-exported by the TF modeling utilities

logger = logging.getLogger(__name__)  # logger for this module

class KerasMetricCallback(keras.callbacks.Callback):
    """
    Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
    compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
    operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
    `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
    metrics and return a dict mapping metric names to metric values.

    We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
    this example skips some post-processing for readability and simplicity, and should probably not be used as-is!

    ```
    from datasets import load_metric

    rouge_metric = load_metric("rouge")


    def rouge_fn(predictions, labels):
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
        return {key: value.mid.fmeasure * 100 for key, value in result.items()}
    ```

    The above function will return a dict containing values which will be logged like any other Keras metric:

    ```
    {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
    ```
    """
    # Constructor for the metric callback
    def __init__(
        self,
        metric_fn: Callable,  # metric function taking predictions and labels and returning a dict of metric names to values
        eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],  # dataset (or dict/tuple/array of data) used for evaluation
        output_cols: Optional[List[str]] = None,  # optional list of model output columns to keep, defaults to all
        label_cols: Optional[List[str]] = None,  # optional list of label columns from the input dataset, auto-detected if not supplied
        batch_size: Optional[int] = None,  # batch size, only used when the data is not a pre-batched tf.data.Dataset
        predict_with_generate: bool = False,  # whether to use model.generate() to obtain model outputs
        use_xla_generation: bool = False,  # when generating, whether to compile generation with XLA for a large speedup
        generate_kwargs: Optional[dict] = None,  # keyword arguments forwarded to model.generate(), only used when predict_with_generate is True
    ):
        # Call the parent constructor
        super().__init__()
        # Store the metric function and batch size
        self.metric_fn = metric_fn
        self.batch_size = batch_size
        # If the evaluation data is not already a tf.data.Dataset, convert it
        if not isinstance(eval_dataset, tf.data.Dataset):
            if batch_size is None:
                # Without a batch size we cannot batch raw data ourselves, so raise an error
                raise ValueError(
                    "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
                    "the batch_size argument must be set."
                )
            # Wrap the raw data in a tf.data.Dataset and batch it with the requested batch size
            eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
        # Store the evaluation dataset
        self.eval_dataset = eval_dataset
        self.predict_with_generate = predict_with_generate
        self.output_cols = output_cols

        # The next block parses the dataset's element spec to decide which elements should be appended
        # to the label list that is passed to metric_fn
        if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
            # If the element spec is a 2-tuple, assume the first element is the inputs and the second the labels
            input_spec, label_spec = eval_dataset.element_spec
        else:
            # Otherwise treat the whole element spec as the input spec and leave the label spec unset
            input_spec = eval_dataset.element_spec
            label_spec = None
        # If label_cols were explicitly provided
        if label_cols is not None:
            # Check that every requested label column exists in the inputs, otherwise raise an error
            for label in label_cols:
                if label not in input_spec:
                    raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
            self.label_cols = label_cols
            self.use_keras_label = False
        elif label_spec is not None:
            # If the dataset element spec is a tuple and no label_cols were given, assume the second element holds the labels
            self.label_cols = None
            self.use_keras_label = True
        elif "labels" in input_spec:
            # If the inputs contain a "labels" key, use it as the label column
            self.label_cols = ["labels"]
            self.use_keras_label = False
            logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
        elif "start_positions" in input_spec and "end_positions" in input_spec:
            # If the inputs contain "start_positions" and "end_positions", use them as the label columns
            self.label_cols = ["start_positions", "end_positions"]
            self.use_keras_label = False
            logging.warning(
                "No label_cols specified for KerasMetricCallback, assuming you want the "
                "start_positions and end_positions keys."
            )
        else:
            # Otherwise the labels cannot be auto-detected, so raise an error
            raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
        # Warn about potential issues on TensorFlow versions older than 2.7
        if parse(tf.__version__) < parse("2.7"):
            logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")

        # Whether to compile generation with XLA
        self.use_xla_generation = use_xla_generation
        # Extra keyword arguments for generation
        self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs

        # The (possibly XLA-compiled) generation function is created lazily, so start with None
        self.generation_function = None

    @staticmethod
    def _concatenate_batches(batches, padding_index=-100):
        # If every batch is 1-D or all batches share the same sequence length, a plain concatenation is enough
        if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
            return np.concatenate(batches, axis=0)

        # Otherwise the batches have different lengths, so pad them to the longest one
        max_len = max([batch.shape[1] for batch in batches])  # longest sequence length
        num_samples = sum([batch.shape[0] for batch in batches])  # total number of samples
        output = np.full_like(
            batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
        )
        # i tracks the row offset where the next batch should be written
        i = 0
        for batch in batches:
            output[i : i + len(batch), : batch.shape[1]] = batch  # copy each batch into the padded output
            i += len(batch)
        return output
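
A tiny worked example of the padding behaviour (values are arbitrary): two batches with different sequence lengths are packed into one array, and the shorter rows are filled with the `-100` padding index.

```
import numpy as np

batch_a = np.array([[1, 2, 3]])        # shape (1, 3)
batch_b = np.array([[4, 5], [6, 7]])   # shape (2, 2)

print(KerasMetricCallback._concatenate_batches([batch_a, batch_b]))
# [[   1    2    3]
#  [   4    5 -100]
#  [   6    7 -100]]
```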

    def _postprocess_predictions_or_labels(self, inputs):
        if isinstance(inputs[0], dict):
            outputs = {}
            for key in inputs[0].keys():
                outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
            # If the outputs form a dict with a single key, return the array directly
            if len(outputs) == 1:
                outputs = list(outputs.values())[0]
        elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
            outputs = []
            for input_list in zip(*inputs):
                outputs.append(self._concatenate_batches(input_list))
            if len(outputs) == 1:
                outputs = outputs[0]  # if the outputs form a single-element list, return the array directly
        elif isinstance(inputs[0], np.ndarray):
            outputs = self._concatenate_batches(inputs)
        elif isinstance(inputs[0], tf.Tensor):
            outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
        else:
            raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")  # unsupported batch type
        return outputs
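
A hedged sketch of wiring the callback into training. It assumes `model`, `tokenizer`, `train_dataset` and `validation_dataset` already exist, and reuses the `rouge_fn` from the docstring example above.

```
# Sketch only: assumes model, tokenizer, train_dataset and validation_dataset exist.
from transformers.keras_callbacks import KerasMetricCallback

metric_callback = KerasMetricCallback(
    metric_fn=rouge_fn,               # the metric function from the docstring example
    eval_dataset=validation_dataset,  # a pre-batched tf.data.Dataset
    predict_with_generate=True,       # use model.generate() for seq2seq metrics like ROUGE
    use_xla_generation=True,          # optional: XLA-compile generation for speed
)

model.fit(train_dataset, callbacks=[metric_callback], epochs=3)
```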
# A callback that regularly saves the model and pushes it to the Hub. By default it pushes once per epoch,
# but this can be changed with the `save_strategy` argument.
class PushToHubCallback(keras.callbacks.Callback):
    """
    Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
    be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
    as with the `from_pretrained` method.

    ```
    from transformers.keras_callbacks import PushToHubCallback

    push_to_hub_callback = PushToHubCallback(
        output_dir="./model_save",
        tokenizer=tokenizer,
        hub_model_id="gpt5-7xlarge",
    )

    model.fit(train_dataset, callbacks=[push_to_hub_callback])
    ```

    Args:
        output_dir (`str`):
            The output directory where the model predictions and checkpoints will be written and synced with the
            repository on the Hub.
        save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
            The checkpoint save strategy to adopt during training. Possible values are:

                - `"no"`: Save is done at the end of training.
                - `"epoch"`: Save is done at the end of each epoch.
                - `"steps"`: Save is done every `save_steps`
        save_steps (`int`, *optional*):
            The number of steps between saves when using the "steps" `save_strategy`.
        tokenizer (`PreTrainedTokenizerBase`, *optional*):
            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
        hub_model_id (`str`, *optional*):
            The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
            which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
            for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
            `"organization_name/model"`.

            Will default to the name of `output_dir`.
        hub_token (`str`, *optional*):
            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
            `huggingface-cli login`.
        checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
            resumed. Only usable when `save_strategy` is `"epoch"`.
    """

    def __init__(
        self,
        output_dir: Union[str, Path],
        save_strategy: Union[str, IntervalStrategy] = "epoch",
        save_steps: Optional[int] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        hub_model_id: Optional[str] = None,
        hub_token: Optional[str] = None,
        checkpoint: bool = False,
        **model_card_args,
    ):
        super().__init__()
        # Checkpoints can only be saved when pushing once per epoch
        if checkpoint and save_strategy != "epoch":
            raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
        # If the save strategy is a string, normalize it to an IntervalStrategy
        if isinstance(save_strategy, str):
            save_strategy = IntervalStrategy(save_strategy.lower())
        self.save_strategy = save_strategy
        # With a step-based strategy, save_steps must be a positive integer
        if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
            raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
        self.save_steps = save_steps
        output_dir = Path(output_dir)

        # Create the repo and retrieve its repo_id
        if hub_model_id is None:
            # Default the Hub model id to the name of the output directory
            hub_model_id = output_dir.absolute().name
        self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id

        self.output_dir = output_dir
        # Clone the Hub repository into the output directory
        self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)

        self.tokenizer = tokenizer
        self.last_job = None          # handle of the last (asynchronous) push job
        self.checkpoint = checkpoint
        self.training_history = None  # filled in at the start of training
        self.model_card_args = model_card_args

    def on_train_begin(self, logs=None):
        # Although we can access model.history, we have no guarantees that the History callback will fire before this
        # one, so we keep track of it here too
        self.training_history = []  # start with an empty training history

    def on_train_batch_end(self, batch, logs=None):
        if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
            # With a step-based strategy, only act when the current batch count is a multiple of save_steps
            if self.last_job is not None and not self.last_job.is_done:
                return  # The last upload is still running, don't start another
            # Save the model to the output directory
            self.model.save_pretrained(self.output_dir)
            if self.tokenizer is not None:
                # If a tokenizer was supplied, save it alongside the model
                self.tokenizer.save_pretrained(self.output_dir)
            # Push to the Hub repository (non-blocking), using the step count in the commit message
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress steps {batch}", blocking=False
            )
    # Called at the end of every epoch to record the logs and, depending on the strategy, save and push the model
    def on_epoch_end(self, epoch, logs=None):
        logs = logs.copy()  # copy the logs so later Keras reads are not affected
        if "epoch" not in logs:
            logs["epoch"] = epoch  # make sure the current epoch number is recorded
        self.training_history.append(logs)  # append this epoch's logs to the training history
        if self.save_strategy == IntervalStrategy.EPOCH:
            if self.last_job is not None and not self.last_job.is_done:
                return  # the previous upload is still running, don't start another one
            self.model.save_pretrained(self.output_dir)  # save the model to the output directory
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)  # save the tokenizer next to the model, if present
            if self.checkpoint:
                checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
                self.model._save_checkpoint(checkpoint_dir, epoch)  # save a full training checkpoint
            # Build a training summary from the Keras history and model information
            train_summary = TrainingSummary.from_keras(
                model=self.model,
                model_name=self.hub_model_id,
                keras_history=self.training_history,
                **self.model_card_args,
            )
            model_card = train_summary.to_model_card()  # turn the training summary into a model card
            with (self.output_dir / "README.md").open("w") as f:
                f.write(model_card)  # write the model card to README.md
            # Push to the Hub (non-blocking) and keep a handle on the upload job
            _, self.last_job = self.repo.push_to_hub(
                commit_message=f"Training in progress epoch {epoch}", blocking=False
            )

    # Called at the end of training to make sure the latest version of the model has been uploaded to the Hub
    def on_train_end(self, logs=None):
        if self.last_job is not None and not self.last_job.is_done:
            logging.info("Pushing the last epoch to the Hub, this may take a while...")
            while not self.last_job.is_done:
                sleep(1)  # wait for the previous push job to finish
        else:
            self.model.save_pretrained(self.output_dir)  # save the final model to the output directory
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(self.output_dir)  # save the tokenizer next to the model, if present
            # Build a training summary from the Keras history and model information
            train_summary = TrainingSummary.from_keras(
                model=self.model,
                model_name=self.hub_model_id,
                keras_history=self.training_history,
                **self.model_card_args,
            )
            model_card = train_summary.to_model_card()  # turn the training summary into a model card
            with (self.output_dir / "README.md").open("w") as f:
                f.write(model_card)  # write the model card to README.md
            self.repo.push_to_hub(commit_message="End of training", blocking=True)  # final, blocking push to the Hub