Transformers Source Code Analysis (135)

.\tools\agent_types.py

# Import the required libraries and modules
import os  # operating-system utilities
import pathlib  # object-oriented filesystem paths
import tempfile  # temporary files and directories
import uuid  # UUID generation

import numpy as np  # NumPy

from ..utils import (  # helpers from the library's utils module
    is_soundfile_availble,  # checks whether soundfile is installed (the misspelling is in the library)
    is_torch_available,  # checks whether PyTorch is installed
    is_vision_available,  # checks whether the vision dependencies (Pillow) are installed
    logging,  # logging utilities
)

logger = logging.get_logger(__name__)  # logger for this module

# If the vision dependencies are available, import PIL and define the image type alias
if is_vision_available():
    import PIL.Image  # PIL image module
    from PIL import Image  # PIL Image module, used to open and save files
    from PIL.Image import Image as ImageType  # alias for the PIL image class
else:
    ImageType = object  # fall back to a generic object type

# If PyTorch is available, import torch
if is_torch_available():
    import torch

# If soundfile is available, import it
if is_soundfile_availble():
    import soundfile as sf


class AgentType:
    """
    抽象类,用于定义代理返回的对象类型。
    
    这些对象具有以下三个目的:
    - 它们表现为它们所代表的类型,例如文本的字符串,图像的 PIL.Image
    - 它们可以转化为字符串形式:str(object) 返回对象定义的字符串
    - 它们应该在 ipython 笔记本/colab/jupyter 中正确显示
    """

    def __init__(self, value):
        self._value = value  # 初始化对象的值

    def __str__(self):
        return self.to_string()  # 返回对象的字符串形式

    def to_raw(self):
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )
        return self._value  # 返回对象的原始值

    def to_string(self) -> str:
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )
        return str(self._value)  # 返回对象的字符串形式


class AgentText(AgentType, str):
    """
    Text type returned by the agent. Behaves as a string.
    """

    def to_raw(self):
        return self._value  # the raw underlying string

    def to_string(self):
        return self._value  # the string form is the value itself
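
Because `AgentText` subclasses both `AgentType` and `str`, an instance passes `isinstance` checks for plain strings while still exposing the agent-type interface. A minimal usage sketch, assuming a transformers release that still ships the `transformers.tools` module (it was later superseded by `transformers.agents`):

from transformers.tools.agent_types import AgentText

text = AgentText("hello world")
print(isinstance(text, str))  # True: behaves as the type it represents
print(text.upper())           # "HELLO WORLD": ordinary str methods work
print(str(text))              # "hello world": __str__ delegates to to_string()
print(text.to_raw())          # "hello world": the raw wrapped value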


class AgentImage(AgentType, ImageType):
    """
    Image type returned by the agent. Behaves as a PIL.Image.
    """

    def __init__(self, value):
        super().__init__(value)  # store the value on the AgentType side

        if not is_vision_available():
            raise ImportError("PIL must be installed in order to handle images.")

        self._path = None  # lazily-populated path to a serialized copy
        self._raw = None  # lazily-populated PIL.Image
        self._tensor = None  # lazily-populated torch.Tensor

        # Initialize the right backing field depending on the type of the value
        if isinstance(value, ImageType):
            self._raw = value  # a PIL image is kept as the raw image
        elif isinstance(value, (str, pathlib.Path)):
            self._path = value  # a string or path object is kept as the image path
        elif isinstance(value, torch.Tensor):
            self._tensor = value  # a PyTorch tensor is kept as the image tensor
        else:
            raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Image, display

        # Display the serialized image
        display(Image(self.to_string()))

    def to_raw(self):
        """
        Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
        """
        # If a PIL image is already cached, return it directly
        if self._raw is not None:
            return self._raw

        # If only a path is known, open the file and cache the resulting PIL.Image
        if self._path is not None:
            self._raw = Image.open(self._path)
            return self._raw

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
        version of the image.
        """
        # If a path already exists, return it directly
        if self._path is not None:
            return self._path

        # If a PIL image is cached, save it as a PNG in a temporary directory and return the path
        if self._raw is not None:
            # Create a temporary directory
            directory = tempfile.mkdtemp()
            # Use a UUID to generate a unique file name for the PNG
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
            self._raw.save(self._path)

            return self._path

        # If a tensor is cached, convert it to an image first, then save it
        if self._tensor is not None:
            # Convert the tensor to a numpy array, scale it to the 0-255 range and build a PIL.Image
            array = self._tensor.cpu().detach().numpy()
            img = Image.fromarray((array * 255).astype(np.uint8))

            # Create a temporary directory
            directory = tempfile.mkdtemp()
            # Use a UUID to generate a unique file name for the PNG
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")

            # Save the image as a PNG file
            img.save(self._path)

            return self._path
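
A short sketch of the lazy serialization above: the image only touches disk when `to_string` is first called, and the resulting path is cached for later calls (assumes Pillow and a transformers release with `transformers.tools`):

import PIL.Image

from transformers.tools.agent_types import AgentImage

raw = PIL.Image.new("RGB", (8, 8), color="red")
agent_image = AgentImage(raw)

print(agent_image.to_raw() is raw)      # True: the PIL image is kept as-is
path = agent_image.to_string()          # first call writes a .png to a temp dir
print(path.endswith(".png"))            # True
print(agent_image.to_string() == path)  # True: the cached path is reused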

class AgentAudio(AgentType):
    """
    Audio type returned by the agent.
    """

    def __init__(self, value, samplerate=16_000):
        super().__init__(value)

        # soundfile is required to read and write audio files
        if not is_soundfile_availble():
            raise ImportError("soundfile must be installed in order to handle audio.")

        # Lazily-populated path / tensor backing fields
        self._path = None
        self._tensor = None

        # Sample rate used for playback and serialization
        self.samplerate = samplerate

        # Initialize the right backing field depending on the type of the value
        if isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif isinstance(value, torch.Tensor):
            self._tensor = value
        else:
            raise ValueError(f"Unsupported audio type: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Audio, display

        # Render an audio player widget in the notebook
        display(Audio(self.to_string(), rate=self.samplerate))

    def to_raw(self):
        """
        Returns the "raw" version of that object. It is a `torch.Tensor` object.
        """
        # If a tensor is already cached, return it directly
        if self._tensor is not None:
            return self._tensor

        # If only a path is known, read the audio file into a tensor
        if self._path is not None:
            tensor, self.samplerate = sf.read(self._path)
            self._tensor = torch.tensor(tensor)
            return self._tensor

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
        version of the audio.
        """
        # If a path already exists, return it directly
        if self._path is not None:
            return self._path

        # If a tensor is cached, write it to a temporary .wav file and return its path
        if self._tensor is not None:
            directory = tempfile.mkdtemp()
            self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
            sf.write(self._path, self._tensor, samplerate=self.samplerate)
            return self._path
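
The same lazy pattern applies to audio. A sketch of a tensor-to-file-and-back round trip, assuming torch and soundfile are installed:

import torch

from transformers.tools.agent_types import AgentAudio

waveform = torch.zeros(16_000)                   # one second of silence at 16 kHz
audio = AgentAudio(waveform, samplerate=16_000)

wav_path = audio.to_string()                     # writes a temporary .wav via soundfile
reloaded = AgentAudio(wav_path)                  # a path-backed instance
print(reloaded.to_raw().shape)                   # torch.Size([16000]): read back as a tensor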


AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText}

# If the vision dependencies are available, also map PIL images to AgentImage
if is_vision_available():
    INSTANCE_TYPE_MAPPING[PIL.Image] = AgentImage


def handle_agent_inputs(*args, **kwargs):
    """
    Handles input arguments by converting AgentType objects to their raw form (if applicable).
    """
    # Unwrap every AgentType argument to its raw value; leave everything else untouched
    args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
    kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
    return args, kwargs
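
In other words, `handle_agent_inputs` is the boundary where agent types are unwrapped before a tool body runs. A quick sketch:

from transformers.tools.agent_types import AgentText, handle_agent_inputs

args, kwargs = handle_agent_inputs(AgentText("hola"), n=3)
print(args)    # ['hola'] -- the AgentText was unwrapped to a plain str
print(kwargs)  # {'n': 3} -- non-agent values pass through untouched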


def handle_agent_outputs(outputs, output_types=None):
    """
    Reads the outputs of the agent and wraps them in the matching agent types.
    """
    # Dictionary outputs: decode each value individually
    if isinstance(outputs, dict):
        decoded_outputs = {}
        for i, (k, v) in enumerate(outputs.items()):
            if output_types is not None:
                # If the tool declared its outputs, map directly via AGENT_TYPE_MAPPING
                if output_types[i] in AGENT_TYPE_MAPPING:
                    decoded_outputs[k] = AGENT_TYPE_MAPPING[output_types[i]](v)
                else:
                    # Otherwise fall back to the generic AgentType wrapper
                    decoded_outputs[k] = AgentType(v)
            else:
                # No declared output types: map according to the value's Python type
                for _k, _v in INSTANCE_TYPE_MAPPING.items():
                    if isinstance(v, _k):
                        decoded_outputs[k] = _v(v)
                # If no mapping matched, fall back to the generic AgentType wrapper
                if k not in decoded_outputs:
                    decoded_outputs[k] = AgentType(v)

    # List or tuple outputs: decode each element, preserving the container type
    elif isinstance(outputs, (list, tuple)):
        decoded_outputs = type(outputs)()
        for i, v in enumerate(outputs):
            if output_types is not None:
                # If the tool declared its outputs, map directly via AGENT_TYPE_MAPPING
                if output_types[i] in AGENT_TYPE_MAPPING:
                    decoded_outputs.append(AGENT_TYPE_MAPPING[output_types[i]](v))
                else:
                    # Otherwise fall back to the generic AgentType wrapper
                    decoded_outputs.append(AgentType(v))
            else:
                # No declared output types: map according to the value's Python type
                found = False
                for _k, _v in INSTANCE_TYPE_MAPPING.items():
                    if isinstance(v, _k):
                        decoded_outputs.append(_v(v))
                        found = True
                # If no mapping matched, fall back to the generic AgentType wrapper
                if not found:
                    decoded_outputs.append(AgentType(v))

    else:
        # Single output (this branch assumes output_types was provided)
        if output_types[0] in AGENT_TYPE_MAPPING:
            # The tool declared its output type: map directly
            decoded_outputs = AGENT_TYPE_MAPPING[output_types[0]](outputs)
        else:
            # Otherwise map according to the value's Python type
            for _k, _v in INSTANCE_TYPE_MAPPING.items():
                if isinstance(outputs, _k):
                    return _v(outputs)
            # If no mapping matched, fall back to the generic AgentType wrapper
            return AgentType(outputs)

    # Return the decoded outputs
    return decoded_outputs
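
And `handle_agent_outputs` is the opposite boundary: raw tool outputs are wrapped back into agent types, either via the declared `output_types` or via the `isinstance` fallback. A sketch:

from transformers.tools.agent_types import AgentText, handle_agent_outputs

# With declared output types, AGENT_TYPE_MAPPING is used directly.
wrapped = handle_agent_outputs(["some caption"], output_types=["text"])
print(type(wrapped[0]).__name__)  # AgentText

# Without them, list elements fall back to INSTANCE_TYPE_MAPPING isinstance checks.
inferred = handle_agent_outputs(["plain string"])
print(isinstance(inferred[0], AgentText))  # True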

.\tools\base.py

# The original file starts with a Python shebang and a UTF-8 coding declaration

# Standard library and third-party imports
import base64  # base64 encoding/decoding
import importlib  # dynamic module imports
import inspect  # signature and introspection helpers
import io  # in-memory byte streams
import json  # JSON serialization
import os  # operating-system utilities
import tempfile  # temporary files and directories
from typing import Any, Dict, List, Optional, Union  # type hints

# Hugging Face Hub helpers
from huggingface_hub import create_repo, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session

# Library-internal imports
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
    CONFIG_NAME,
    cached_file,
    is_accelerate_available,
    is_torch_available,
    is_vision_available,
    logging,
)

# Logger for this module
logger = logging.get_logger(__name__)

# If torch is available, import it
if is_torch_available():
    import torch

# If accelerate is available, import the device helpers
if is_accelerate_available():
    from accelerate import PartialState
    from accelerate.utils import send_to_device

# Name of the tool configuration file
TOOL_CONFIG_FILE = "tool_config.json"


# Resolve the repo type of a repo_id on the Hub
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
    # If a repo_type was passed explicitly, return it directly
    if repo_type is not None:
        return repo_type

    # Try downloading the tool config from a "space" repo
    try:
        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
        return "space"
    except RepositoryNotFoundError:
        # Not a space: try a "model" repo instead
        try:
            hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
            return "model"
        except RepositoryNotFoundError:
            # Neither a space nor a model: not a valid repo identifier
            raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
        except Exception:
            # Any other error while probing: default to "model"
            return "model"
    except Exception:
        # Any other error while probing: default to "space"
        return "space"


# Multi-line template used to generate the app.py file of a tool Space
# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}

launch_gradio_demo({class_name})
"""


# Base class for the functions used by the agent
class Tool:
    """
    A base class for the functions used by the agent. Subclass it and implement the `__call__` method as well as the
    following class attributes:

    - **description** (`str`) -- A short description of what the tool does, the inputs it expects and the output it
      returns. For instance, 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
      returns the text contained in the file'.
    - **name** (`str`) -- The name of the tool.
    - **inputs** (`List[str]`) -- The list of modalities of the inputs the tool accepts.
    - **outputs** (`List[str]`) -- The list of modalities of the outputs the tool returns.
    """

    # Short description of what the tool does
    description: str = "This is a tool that ..."
    # Name of the tool
    name: str = ""
    # Modalities of the inputs the tool accepts
    inputs: List[str]
    # Modalities of the outputs the tool returns
    outputs: List[str]

    def __init__(self, *args, **kwargs):
        # A tool starts out uninitialized; setup() flips this flag
        self.is_initialized = False

    def __call__(self, *args, **kwargs):
        # To be implemented in subclasses of `Tool`
        return NotImplemented("Write this method in your subclass of `Tool`.")

    # Performs any expensive operation (such as loading a large model) needed before the tool is used
    def setup(self):
        # Mark the tool as initialized
        self.is_initialized = True
    def save(self, output_dir):
        """
        Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
        tool in `output_dir` as well as autogenerate:

        - a config file named `tool_config.json`
        - an `app.py` file so that your tool can be converted to a space
        - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
          code)

        You should only use this method to save tools that are defined in a separate module (not `__main__`).

        Args:
            output_dir (`str`): The folder in which you want to save your tool.
        """
        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Check if the tool's class is defined in the __main__ module
        if self.__module__ == "__main__":
            raise ValueError(
                f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
                "have to put this code in a separate module so we can include it in the saved folder."
            )

        # Save the module files using a custom function
        module_files = custom_object_save(self, output_dir)

        # Get the name of the module containing the class
        module_name = self.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{self.__class__.__name__}"

        # Save or update the tool's configuration file
        config_file = os.path.join(output_dir, "tool_config.json")
        if os.path.isfile(config_file):
            # Load existing configuration if file already exists
            with open(config_file, "r", encoding="utf-8") as f:
                tool_config = json.load(f)
        else:
            tool_config = {}

        # Update tool configuration with class information
        tool_config = {"tool_class": full_name, "description": self.description, "name": self.name}
        with open(config_file, "w", encoding="utf-8") as f:
            # Write updated configuration to file in a human-readable format
            f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")

        # Save the app.py file using a template specific to the tool
        app_file = os.path.join(output_dir, "app.py")
        with open(app_file, "w", encoding="utf-8") as f:
            # Write the app file content based on a predefined template
            f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))

        # Save the requirements.txt file listing all dependencies
        requirements_file = os.path.join(output_dir, "requirements.txt")
        imports = []
        for module in module_files:
            # Gather all imports used by the modules of the tool
            imports.extend(get_imports(module))
        imports = list(set(imports))  # Ensure uniqueness of imports
        with open(requirements_file, "w", encoding="utf-8") as f:
            # Write each import as a separate line in the requirements file
            f.write("\n".join(imports) + "\n")
    def push_to_hub(
        self,
        repo_id: str,
        commit_message: str = "Upload tool",
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
    ) -> str:
        """
        Upload the tool to the Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your tool to. It should contain your organization name when
                pushing to a given organization.
            commit_message (`str`, *optional*, defaults to `"Upload tool"`):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        """
        # Create the repo (or reuse an existing one) and get its URL
        repo_url = create_repo(
            repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
        )
        # Update the repo metadata to add the "tool" tag
        repo_id = repo_url.repo_id
        metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")

        # Use a temporary directory to stage the files
        with tempfile.TemporaryDirectory() as work_dir:
            # Save all the tool files to the temporary directory
            self.save(work_dir)
            # Log the list of files about to be uploaded
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            # Upload the staged folder to the target repo
            return upload_folder(
                repo_id=repo_id,
                commit_message=commit_message,
                folder_path=work_dir,
                token=token,
                create_pr=create_pr,
                repo_type="space",
            )

    @staticmethod
    def from_gradio(gradio_tool):
        """
        Creates a [`Tool`] from a gradio tool.
        """

        # Wrapper class that adapts the gradio tool to the Tool interface
        class GradioToolWrapper(Tool):
            def __init__(self, _gradio_tool):
                super().__init__()
                # Copy over the name and description
                self.name = _gradio_tool.name
                self.description = _gradio_tool.description

        # Use the gradio tool's run method as the wrapper's __call__
        GradioToolWrapper.__call__ = gradio_tool.run
        # Return a wrapper instance around the gradio tool
        return GradioToolWrapper(gradio_tool)
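
Putting the `Tool` interface together, here is a minimal custom tool; the class and its behavior are illustrative, not library code:

from transformers.tools.base import Tool


class ReverseTextTool(Tool):
    # Hypothetical example tool: reverses its text input
    name = "text_reverser"
    description = "This is a tool that reverses a `text` input and returns the result."
    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, text: str) -> str:
        return text[::-1]


tool = ReverseTextTool()
print(tool("hello"))  # 'olleh'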

# A Tool that forwards calls to a remote inference endpoint
class RemoteTool(Tool):
    """
    A [`Tool`] that will make requests to an inference endpoint.

    Args:
        endpoint_url (`str`, *optional*):
            The url of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        tool_class (`type`, *optional*):
            The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
            the output should be converted to another type (like images).
    """

    # Constructor: all three arguments are optional
    def __init__(self, endpoint_url=None, token=None, tool_class=None):
        # URL of the inference endpoint
        self.endpoint_url = endpoint_url
        # EndpointClient handling the HTTP communication with the endpoint
        self.client = EndpointClient(endpoint_url, token=token)
        # Tool class this remote tool mirrors, if any
        self.tool_class = tool_class

    def prepare_inputs(self, *args, **kwargs):
        """
        Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
        matched with the signature of the `tool_class` if it was provided at instantiation. Images will be encoded into
        bytes.

        You can override this method in your custom class of [`RemoteTool`].
        """
        # Start from a copy of the keyword arguments
        inputs = kwargs.copy()

        # Positional arguments were passed
        if len(args) > 0:
            # A tool_class was provided: match the positional args against its signature
            if self.tool_class is not None:
                # For pipeline tools the relevant signature is encode's; otherwise __call__'s
                if issubclass(self.tool_class, PipelineTool):
                    call_method = self.tool_class.encode
                else:
                    call_method = self.tool_class.__call__
                signature = inspect.signature(call_method).parameters
                # Collect the parameter names, excluding *args and **kwargs
                parameters = [
                    k
                    for k, p in signature.items()
                    if p.kind not in [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD]
                ]
                # Drop the leading self parameter if present
                if parameters[0] == "self":
                    parameters = parameters[1:]
                # More positional args than parameters: raise a ValueError
                if len(args) > len(parameters):
                    raise ValueError(
                        f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
                    )
                # Map each positional arg to its parameter name
                for arg, name in zip(args, parameters):
                    inputs[name] = arg
            # Without a tool_class, more than one positional arg is ambiguous
            elif len(args) > 1:
                raise ValueError("A `RemoteTool` can only accept one positional input.")
            # A single positional arg: encode PIL images to bytes, pass anything else through
            elif len(args) == 1:
                if is_pil_image(args[0]):
                    return {"inputs": self.client.encode_image(args[0])}
                return {"inputs": args[0]}

        # Encode any PIL image found among the keyword inputs
        for key, value in inputs.items():
            if is_pil_image(value):
                inputs[key] = self.client.encode_image(value)

        # Return the payload under the "inputs" key
        return {"inputs": inputs}
    def extract_outputs(self, outputs):
        """
        You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
        outputs of the endpoint.
        """
        # By default, return the outputs unchanged
        return outputs

    def __call__(self, *args, **kwargs):
        # Unwrap any AgentType arguments to their raw values
        args, kwargs = handle_agent_inputs(*args, **kwargs)

        # Determine whether the endpoint is expected to return an image, then build the payload
        output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
        inputs = self.prepare_inputs(*args, **kwargs)

        # Call the endpoint client, forwarding the output_image flag
        if isinstance(inputs, dict):
            outputs = self.client(**inputs, output_image=output_image)
        else:
            outputs = self.client(inputs, output_image=output_image)

        # Unwrap a singleton nested list
        if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
            outputs = outputs[0]

        # Wrap the outputs into agent types, using the tool class's declared outputs if available
        outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None)

        # Apply any custom post-processing and return the result
        return self.extract_outputs(outputs)

# A Tool tailored to Transformer models: wraps a pre-processor, a model and a post-processor
class PipelineTool(Tool):
    """
    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
    need to specify:

    - **model_class** (`type`) -- The class to use to load the model in this tool.
    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      pre-processor
    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      post-processor (when different from the pre-processor).

    Args:
        model (`str` or [`PreTrainedModel`], *optional*):
            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
            value of the class attribute `default_checkpoint`.
        pre_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
            unset.
        post_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
            unset.
        device (`int`, `str` or `torch.device`, *optional*):
            The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
            CPU otherwise.
        device_map (`str` or `dict`, *optional*):
            If passed along, will be used to instantiate the model.
        model_kwargs (`dict`, *optional*):
            Any keyword argument to send to the model instantiation.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        hub_kwargs (additional keyword arguments, *optional*):
            Any additional keyword argument to send to the methods that will load the data from the Hub.
    """

    # Default pre-processor class: AutoProcessor
    pre_processor_class = AutoProcessor
    # The model class; must be set by subclasses
    model_class = None
    # Default post-processor class: also AutoProcessor
    post_processor_class = AutoProcessor
    # Default checkpoint used when the user does not specify one
    default_checkpoint = None

    # Constructor: every argument configures how the model and processors are loaded
    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        token=None,
        **hub_kwargs,
    ):
        # torch is required to run the model
        if not is_torch_available():
            raise ImportError("Please install torch in order to use this tool.")

        # accelerate is required for device placement
        if not is_accelerate_available():
            raise ImportError("Please install accelerate in order to use this tool.")

        # Without an explicit model, fall back to the class's default checkpoint
        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
            model = self.default_checkpoint

        # Without an explicit pre-processor, reuse the model checkpoint
        if pre_processor is None:
            pre_processor = model

        # Store the model, processors, device configuration and model kwargs
        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs

        # Forward the device map to the model instantiation kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map

        # Store the Hub kwargs, including the token
        self.hub_kwargs = hub_kwargs
        self.hub_kwargs["token"] = token

        # Call the parent constructor
        super().__init__()

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        """
        # If the pre-processor is a checkpoint name, instantiate it
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        # If the model is a checkpoint name, instantiate it
        if isinstance(self.model, str):
            self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)

        # Without an explicit post-processor, reuse the pre-processor
        if self.post_processor is None:
            self.post_processor = self.pre_processor
        # If the post-processor is a checkpoint name, instantiate it
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        # Without an explicit device, infer one from the device map or from accelerate
        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = PartialState().default_device

        # Without a device map, move the model to the selected device
        if self.device_map is None:
            self.model.to(self.device)

        # Call the parent setup method (marks the tool as initialized)
        super().setup()

    def encode(self, raw_inputs):
        """
        Uses the `pre_processor` to prepare the inputs for the `model`.
        """
        # Encode the raw inputs with the pre-processor
        return self.pre_processor(raw_inputs)

    def forward(self, inputs):
        """
        Sends the inputs through the `model`.
        """
        # Run the model without gradient tracking
        with torch.no_grad():
            return self.model(**inputs)

    def decode(self, outputs):
        """
        Uses the `post_processor` to decode the model output.
        """
        # Decode the model outputs with the post-processor
        return self.post_processor(outputs)
    def __call__(self, *args, **kwargs):
        # Unwrap any AgentType arguments to their raw values
        args, kwargs = handle_agent_inputs(*args, **kwargs)

        # Lazily run setup() on first use
        if not self.is_initialized:
            self.setup()

        # Encode the inputs with the pre-processor
        encoded_inputs = self.encode(*args, **kwargs)
        # Move the encoded inputs to the execution device
        encoded_inputs = send_to_device(encoded_inputs, self.device)
        # Run the model
        outputs = self.forward(encoded_inputs)
        # Move the outputs back to the CPU
        outputs = send_to_device(outputs, "cpu")
        # Decode the outputs with the post-processor
        decoded_outputs = self.decode(outputs)

        # Wrap the decoded outputs into agent types and return them
        return handle_agent_outputs(decoded_outputs, self.outputs)
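
So a `PipelineTool` subclass mostly just fills in `encode`/`decode` plus the class attributes; `__call__` above stitches them together with device placement. An illustrative subclass, mirroring the shape of the library's own summarization tool (the checkpoint name and tool name here are assumptions for the sketch, not prescribed by the library):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.tools.base import PipelineTool


class SummarizerSketchTool(PipelineTool):
    # Hypothetical example; the checkpoint is chosen only for illustration
    default_checkpoint = "sshleifer/distilbart-cnn-12-6"
    description = "This is a tool that summarizes a `text` and returns the summary."
    name = "summarizer_sketch"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        # Override forward to use generate() instead of a plain forward pass
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.decode(outputs[0], skip_special_tokens=True)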

# Launch a gradio demo for a given tool. The tool class must correctly implement the `inputs` and `outputs` class attributes.
def launch_gradio_demo(tool_class: Tool):
    try:
        import gradio as gr  # gradio is an optional dependency
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    tool = tool_class()  # instantiate the given tool class

    # Thin wrapper so gradio calls the tool instance's __call__
    def fn(*args, **kwargs):
        return tool(*args, **kwargs)

    # Build the gr.Interface from the tool's declared inputs/outputs and metadata
    gr.Interface(
        fn=fn,
        inputs=tool_class.inputs,
        outputs=tool_class.outputs,
        title=tool_class.__name__,
        article=tool.description,
    ).launch()  # start the demo


# Mapping from supported task ids to the names of the corresponding tool classes
TASK_MAPPING = {
    "document-question-answering": "DocumentQuestionAnsweringTool",
    "image-captioning": "ImageCaptioningTool",
    "image-question-answering": "ImageQuestionAnsweringTool",
    "image-segmentation": "ImageSegmentationTool",
    "speech-to-text": "SpeechToTextTool",
    "summarization": "TextSummarizationTool",
    "text-classification": "TextClassificationTool",
    "text-question-answering": "TextQuestionAnsweringTool",
    "text-to-speech": "TextToSpeechTool",
    "translation": "TranslationTool",
}


def get_default_endpoints():
    # Fetch the default endpoint configuration file from the Hub and parse it
    endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
    with open(endpoints_file, "r", encoding="utf-8") as f:
        endpoints = json.load(f)  # parse the JSON endpoint configuration
    return endpoints


def supports_remote(task_or_repo_id):
    endpoints = get_default_endpoints()  # fetch the default endpoints
    return task_or_repo_id in endpoints  # check whether the task or repo id has a default endpoint


def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
    """
    Main function to quickly load a tool, whether it lives on the Hub or in the Transformers library.

    <Tip warning={true}>

    Loading a tool means that you'll download the tool and execute it locally.
    Always inspect the tool you're downloading before loading it within your runtime, as you would do when
    installing a package using pip/npm/apt.

    </Tip>
    """
    # If the given task or repo id is one of the built-in tasks
    if task_or_repo_id in TASK_MAPPING:
        # Look up the tool class name for the task
        tool_class_name = TASK_MAPPING[task_or_repo_id]
        # Dynamically import the main transformers module
        main_module = importlib.import_module("transformers")
        # Get the tools submodule
        tools_module = main_module.tools
        # Resolve the actual tool class by name
        tool_class = getattr(tools_module, tool_class_name)

        # Remote execution was requested
        if remote:
            # If no model_repo_id was passed, look up the default endpoints
            if model_repo_id is None:
                endpoints = get_default_endpoints()
                # No default endpoint for this task: the caller must provide one
                if task_or_repo_id not in endpoints:
                    raise ValueError(
                        f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
                        "`model_repo_id` argument."
                    )
                # Use the default endpoint as the model repo id
                model_repo_id = endpoints[task_or_repo_id]
            # Return a RemoteTool pointing at the endpoint
            return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
        else:
            # Instantiate the tool locally, forwarding the checkpoint and the extra kwargs
            return tool_class(model_repo_id, token=token, **kwargs)
    else:
        # Unknown task id: treat it as a Hub repo id and warn about executing remote code
        logger.warning_once(
            f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
            f"trust as the code within that tool will be executed on your machine. Always verify the code of "
            f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
            f"code that you have checked."
        )
        # Load the tool from the Hub
        return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)
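
Usage sketch for `load_tool`: a known task id instantiates the built-in class locally, anything else is treated as a Hub repo id whose code gets downloaded and executed:

from transformers import load_tool

# Task id present in TASK_MAPPING: instantiates TextSummarizationTool locally.
summarizer = load_tool("summarization")

# Anything else is resolved on the Hub; only load repositories you trust and
# consider pinning a `revision` (the repo id below is hypothetical).
# custom_tool = load_tool("some-user/some-tool")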

# Decorator that attaches a description to a function
def add_description(description):
    """
    A decorator that adds a description to a function.
    """

    def inner(func):
        # Attach the description to the function object
        func.description = description
        # Record the function's name
        func.name = func.__name__
        return func

    return inner
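
A quick sketch of the decorator in use; the `adder` function is illustrative:

from transformers.tools.base import add_description


@add_description("This is a tool that adds two numbers.")
def adder(a, b):
    return a + b


print(adder.name)         # 'adder'
print(adder.description)  # 'This is a tool that adds two numbers.'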


## Will move to the Hub
# Client managing the HTTP communication with an endpoint
class EndpointClient:
    def __init__(self, endpoint_url: str, token: Optional[str] = None):
        # Build the HTTP request headers, including the authorization token
        self.headers = {**build_hf_headers(token=token), "Content-Type": "application/json"}
        # Remember the endpoint URL
        self.endpoint_url = endpoint_url

    @staticmethod
    def encode_image(image):
        # Encode an image as a base64 string of its PNG bytes
        _bytes = io.BytesIO()
        image.save(_bytes, format="PNG")
        b64 = base64.b64encode(_bytes.getvalue())
        return b64.decode("utf-8")

    @staticmethod
    def decode_image(raw_image):
        # Decode a base64 string back into a PIL image
        if not is_vision_available():
            # Pillow is required to build the image object
            raise ImportError(
                "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
            )

        from PIL import Image

        b64 = base64.b64decode(raw_image)
        _bytes = io.BytesIO(b64)
        return Image.open(_bytes)

    def __call__(
        self,
        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
        params: Optional[Dict] = None,
        data: Optional[bytes] = None,
        output_image: bool = False,
    ) -> Any:
        # Build the request payload
        payload = {}
        if inputs:
            payload["inputs"] = inputs
        if params:
            payload["parameters"] = params

        # Make the API call
        response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)

        # Either decode the response as an image or parse it as JSON
        if output_image:
            # Decode the raw bytes of the response into a PIL image
            return self.decode_image(response.content)
        else:
            # Parse the JSON response
            return response.json()
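
The two static helpers above make images transportable over the JSON API. A round-trip sketch, assuming Pillow is installed (no endpoint is needed for the static methods):

import PIL.Image

from transformers.tools.base import EndpointClient

image = PIL.Image.new("RGB", (4, 4), color="blue")
payload = EndpointClient.encode_image(image)     # PNG bytes -> base64 str
restored = EndpointClient.decode_image(payload)  # base64 str -> PIL.Image
print(restored.size)  # (4, 4)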

.\tools\document_question_answering.py

#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re  # regular expressions

from ..models.auto import AutoProcessor  # the auto processor class
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel  # the vision encoder-decoder model
from ..utils import is_vision_available  # vision availability check
from .base import PipelineTool  # the PipelineTool base class


if is_vision_available():  # only import PIL when the vision dependencies are present
    from PIL import Image  # PIL's Image module


class DocumentQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"  # default checkpoint
    description = (
        "This is a tool that answers a question about an document (pdf). It takes an input named `document` which "
        "should be the document containing the information, as well as a `question` that is the question about the "
        "document. It returns a text that contains the answer to the question."
    )  # tool description
    name = "document_qa"  # tool name
    pre_processor_class = AutoProcessor  # pre-processor class
    model_class = VisionEncoderDecoderModel  # model class

    inputs = ["image", "text"]  # input modalities: image, text
    outputs = ["text"]  # output modality: text

    def __init__(self, *args, **kwargs):
        if not is_vision_available():  # Pillow is required to handle the document image
            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)  # call the parent constructor

    def encode(self, document: "Image", question: str):
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"  # Donut task prompt template
        prompt = task_prompt.replace("{user_input}", question)  # inject the question into the prompt
        decoder_input_ids = self.pre_processor.tokenizer(  # tokenize the prompt for the decoder
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values  # preprocess the document image

        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}  # encoded model inputs

    def forward(self, inputs):
        return self.model.generate(  # generate the answer with the model
            inputs["pixel_values"].to(self.device),  # move the pixel values to the device
            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),  # move the decoder input ids to the device
            max_length=self.model.decoder.config.max_position_embeddings,  # maximum generation length
            early_stopping=True,  # stop early when possible
            pad_token_id=self.pre_processor.tokenizer.pad_token_id,  # padding token id
            eos_token_id=self.pre_processor.tokenizer.eos_token_id,  # end-of-sequence token id
            use_cache=True,  # use the generation cache
            num_beams=1,  # greedy decoding (a single beam)
            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],  # forbid the unknown token
            return_dict_in_generate=True,  # return a dict from generate
        ).sequences  # keep only the generated sequences

    def decode(self, outputs):
        # Batch-decode the generated ids and keep the first sequence
        sequence = self.pre_processor.batch_decode(outputs)[0]
        # Strip the end-of-sequence token
        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
        # Strip the padding token
        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
        # Remove the first task-start token with a regex
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        # Convert the token sequence to JSON with the processor
        sequence = self.pre_processor.token2json(sequence)

        # Return the answer field of the JSON result
        return sequence["answer"]
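
Usage sketch for the tool above (downloads the Donut checkpoint on first use; the blank page is only a stand-in, a real document scan should be substituted for a meaningful answer):

import PIL.Image

from transformers import load_tool

document_qa = load_tool("document-question-answering")
page = PIL.Image.new("RGB", (1280, 960), color="white")  # stand-in for a scanned page
answer = document_qa(document=page, question="What is the invoice number?")
print(answer)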

.\tools\evaluate_agent.py

### Fake tools for test
# A set of stub tool functions used for testing purposes

def classifier(text, labels):
    # Return a string standing in for a text-classification result
    return f"This is the classification of {text} along {labels}."

def translator(text, src_lang, tgt_lang):
    # Return a string standing in for a translation result
    return f"This is the translation of {text} from {src_lang} to {tgt_lang}."

def speaker(text):
    # Return a string standing in for reading the text out loud
    return f"This is actually a sound reading {text}."

def transcriber(audio):
    # Raise if the input does not look like a sound
    if "sound" not in audio:
        raise ValueError(f"`audio` ({audio}) is not a sound.")
    # Return a string standing in for the transcription of the sound
    return f"This is the transcribed text from {audio}."

def image_generator(prompt):
    # Return a string standing in for an image generated from the prompt
    return f"This is actually an image representing {prompt}."

def image_captioner(image):
    # Raise if the input does not look like an image
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    # Return a string standing in for the image caption
    return f"This is a description of {image}."

def image_transformer(image, prompt):
    # Raise if the input does not look like an image
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    # Return a string standing in for the transformed image
    return f"This is a transformation of {image} according to {prompt}."

def question_answerer(text, question):
    # Return a string standing in for the answer extracted from the text
    return f"This is the answer to {question} from {text}."

def image_qa(image, question):
    # Raise if the input does not look like an image
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    # Return a string standing in for the answer extracted from the image
    return f"This is the answer to {question} from {image}."

def text_downloader(url):
    # Return a string standing in for the downloaded content of the URL
    return f"This is the content of {url}."

def summarizer(text):
    # Return a string standing in for the summary of the text
    return f"This is a summary of {text}."

def video_generator(prompt, seconds=2):
    # Return a string standing in for a generated video
    return f"A video of {prompt}"

def document_qa(image, question):
    # Return a string standing in for the answer extracted from the document image
    return f"This is the answer to {question} from the document {image}."

def image_segmenter(image, prompt):
    # Return a string standing in for the segmentation mask
    return f"This is the mask of {prompt} in {image}"

TEST_TOOLS = {
    "text_classifier": classifier,
    "translator": translator,
    "text_reader": speaker,
    "summarizer": summarizer,
    "transcriber": transcriber,
    "image_generator": image_generator,
    "image_captioner": image_captioner,
    "image_transformer": image_transformer,
    "text_qa": question_answerer,
    "text_downloader": text_downloader,
    "image_qa": image_qa,
    "video_generator": video_generator,
    "document_qa": document_qa,
    "image_segmenter": image_segmenter,
}
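
These stubs compose purely by string formatting, which is exactly what the `answer` fields of the evaluation problems below exercise. For example:

# Chaining two fake tools, as the evaluation answers do:
result = image_transformer(image_generator("a beaver"), "a lobster")
print(result)
# -> This is a transformation of This is actually an image representing a beaver. according to a lobster.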

class Problem:
    """
    A class regrouping all the information needed to solve a problem on which an agent will be evaluated.

    Args:
        task (`str` or `list[str]`):
            One or several descriptions of the task to perform. If a list, it should contain variations on the
            phrasing of the same task.
        inputs (`list[str]` or `dict[str, str]`):
            The inputs that will be fed to the tools. In this testing environment, only strings are accepted as
            values. Pass a dict when you want to specify the value of each input; otherwise pass the list of expected
            inputs directly (in which case `<<input_name>>` is used as the value).
        answer (`str` or `list[str]`):
            The theoretical answer (or list of possible valid answers) to the problem, as code.
    """

    # Constructor: store the task description(s), inputs and reference answer(s)
    def __init__(self, task, inputs, answer):
        self.task = task      # the task description(s)
        self.inputs = inputs  # the inputs provided to the tools
        self.answer = answer  # the reference answer(s), as code
### A list of evaluation tasks, made up of multiple Problem instances

EVALUATION_TASKS = [
    # Problem: decide whether the given Spanish `text` is positive or negative
    Problem(
        task=[
            "Is the following `text` (in Spanish) positive or negative?",
            "Is the text in the variable `text` (in Spanish) positive or negative?",
            "Translate the following `text` from Spanish to English then tell me if its positive or negative.",
        ],
        inputs=["text"],
        # The answer is a code string chaining translation and classification
        answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
    ),

    # Problem: describe what the given `image` contains, out loud
    Problem(
        task=[
            "Tell me out loud what the `image` contains.",
            "Describe the following `image` out loud.",
            "Find what is in the picture stored in `image` then read it out loud.",
        ],
        inputs=["image"],
        # Two valid ways of describing the image are accepted
        answer=[
            "text_reader(image_captioner(image))",
            "text_reader(image_qa(image, question='What is in the image?'))",
        ],
    ),

    # Problem: generate an image from `text_input`, then transform it according to `prompt`
    Problem(
        task=[
            "Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
            "Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
        ],
        inputs=["text_input", "prompt"],
        # The answer chains image generation and image transformation
        answer="image_transformer(image_generator(text_input), prompt)",
    ),

    # Problem: download the content of `url`, summarize it, and generate an image from the summary
    Problem(
        task=[
            "Download the content of `url`, summarize it then generate an image from its content.",
            "Use a summary of the web page at `url` to generate an image.",
            "Summarize the content of the web page at `url`, and use the result to generate an image.",
        ],
        inputs=["url"],
        # The answer chains download, summarization and image generation
        answer="image_generator(summarizer(text_downloader(url)))",
    ),

    # Problem: transform the `image` according to the Spanish prompt in `text`
    Problem(
        task=[
            "Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
            "Use the text prompt in `text` (in Spanish) to transform the following `image`.",
            "Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
        ],
        inputs=["text", "image"],
        # The answer chains translation and image transformation
        answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
    ),

    # Problem: download the content of `url`, summarize it, and read the summary out loud
    Problem(
        task=[
            "Download the content of `url`, summarize it then read it out loud to me.",
            "Read me a summary of the web page at `url`.",
        ],
        inputs=["url"],
        # The answer chains download, summarization and text-to-speech
        answer="text_reader(summarizer(text_downloader(url)))",
    ),

    # Problem: generate an image from `text_input`
    Problem(
        task=[
            "Generate an image from the text given in `text_input`.",
        ],
        inputs=["text_input"],
        # The answer is a single image-generation call
        answer="image_generator(text_input)",
    ),
    Problem(
        task=[
            "Replace the beaver in the `image` by the `prompt`.",
            "Transform the `image` so that it contains the `prompt`.",
            "Use `prompt` to transform this `image`.",
        ],
        inputs=["image", "prompt"],
        answer="image_transformer(image, prompt)",
    ),
    Problem(
        task=[
            "Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
            "Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
            "Read me a summary of the `text` out loud. Transcribe this and translate it in French.",
        ],
        inputs=["text"],
        answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
    ),
    Problem(
        task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
        inputs={"prompt": "A lobster swimming"},
        answer="video_generator('A lobster swimming')",
    ),
    Problem(
        task=[
            "Download the following file `url`, summarize it in a few words and generate a video from it."
            "Fetch the file at this `url`, summarize it, and create an animation out of it."
        ],
        inputs=["url"],
        answer="video_generator(summarizer(text_downloader(url)))",
    ),
]

EVALUATION_CHATS = [
    [  # first chat: translation followed by sentiment classification
        Problem(
            task=[
                "Translate the following `text` from Spanish to English.",
                "Translate the following `text` from Spanish to English.",
            ],
            inputs=["text"],
            # Call the translator from Spanish to English and store the result
            answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Is it positive or negative?",
                "Tell me if its positive or negative.",
            ],
            inputs=[],
            # Classify the previously translated text
            answer="text_classifier(translated_text, labels=['positive', 'negative'])",
        ),
    ],
    [  # second chat: describe an image, then read the description out loud
        Problem(
            task=[
                "What does this `image` contain?",
                "Describe the following `image`.",
                "Find what is in the picture stored in `image`",
            ],
            inputs=["image"],
            # Two accepted ways of producing the description
            answer=[
                "description=image_captioner(image)",
                "description=image_qa(image, question='What is in the image?')",
            ],
        ),
        Problem(
            task=[
                "Now, read the description out loud.",
                "Great! Can you read it out loud?",
                "Read it out loud.",
            ],
            inputs=[],
            # Read the stored description with the text reader
            answer=["audio=text_reader(description)", "audio=text_reader(description)"],
        ),
    ],
    [  # third chat: generate an image, then transform it
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
                "Use the following `text_input` to generate an image",
            ],
            inputs=["text_input"],
            # Generate the image and store it
            answer="image = image_generator(text_input)",
        ),
        Problem(
            task=[
                "Transform it according to the text in `prompt`.",
                "Transform it by using the text in `prompt`.",
            ],
            inputs=["prompt"],
            # Transform the stored image according to the prompt
            answer="image_transformer(image, prompt)",
        ),
    ],
    [  # fourth chat: download and summarize a page, then illustrate it
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            # Download the page and summarize it
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=[
                "Generate an image from its content.",
                "Use the previous result to generate an image.",
            ],
            inputs=[],
            # Generate an image from the stored summary
            answer="image_generator(summary)",
        ),
    ],
    [
        # 第一个问题组
        Problem(
            # 任务描述:将这段西班牙文`text`翻译成英文。
            task=[
                "Translate this Spanish `text` in English.",
                "Translate the `text` from Spanish to English.",
            ],
            # 输入参数:text,需要翻译的文本
            inputs=["text"],
            # 答案:调用translator函数进行翻译,从西班牙语到英语
            answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            # 任务描述:使用翻译后的`text`来转换以下的`image`。
            task=[
                "Transform the following `image` using the translated `text`.",
                "Use the previous result to transform the following `image`.",
            ],
            # 输入参数:image,需要进行转换的图像;translated_text,已翻译的文本
            inputs=["image"],
            # 答案:调用image_transformer函数,使用翻译后的文本来转换图像
            answer="image_transformer(image, translated_text)",
        ),
    ],
    [
        # 第二个问题组
        Problem(
            # 任务描述:下载`url`的内容。
            task=["Download the content of `url`.", "Get me the text on the web page `url`."],
            # 输入参数:url,需要下载内容的网址
            inputs=["url"],
            # 答案:调用text_downloader函数下载网页内容
            answer="text = text_downloader(url)",
        ),
        Problem(
            # 任务描述:对文本进行总结。
            task=["Summarize this text.", "Summarize this text."],
            # 输入参数:无(使用前面下载的文本)
            inputs=[],
            # 答案:调用summarizer函数对文本进行总结
            answer="summary = summarizer(text)",
        ),
        Problem(
            # 任务描述:朗读给我听。
            task=["Read it out loud to me.", "Read me the previous result."],
            # 输入参数:无(使用前面生成的总结文本)
            inputs=[],
            # 答案:调用text_reader函数朗读总结文本
            answer="text_reader(summary)",
        ),
    ],
    [
        # 第三个问题组
        Problem(
            # 任务描述:根据给定的`text_input`生成一张图像。
            task=["Generate an image from the text given in `text_input`."],
            # 输入参数:text_input,用于生成图像的文本输入
            inputs=["text_input"],
            # 答案:调用image_generator函数生成图像
            answer="image_generator(text_input)",
        ),
    ],
    [
        # Fourth chat problem group
        Problem(
            # Task: replace the beaver in the `image` with the `prompt`
            task=[
                "Replace the beaver in the `image` by the `prompt`.",
                "Transform the `image` so that it contains the `prompt`.",
                "Use `prompt` to transform this `image`.",
            ],
            inputs=["image", "prompt"],
            answer="image_transformer(image, prompt)",
        ),
    ],
    [
        # Fifth chat problem group
        Problem(
            # Task: summarize `text`
            task=["Provide me the summary of the `text`.", "Summarize `text`."],
            inputs=["text"],
            answer="summary = summarizer(text)",
        ),
        Problem(
            # Task: read the summary out loud
            task=["Read this summary to me.", "Read it out loud."],
            inputs=[],  # Uses the summary from the previous step
            answer="audio = text_reader(summarizer(text))",
        ),
        Problem(
            # Task: transcribe the audio back to text
            task=["Transcribing the previous result back in text.", "Transcribe the audio."],
            inputs=[],  # Uses the audio generated in the previous step
            answer="text = transcriber(audio)",
        ),
        Problem(
            # Task: translate the result into French
            task=["Translating the last result in French.", "Translate this in French."],
            inputs=[],  # Uses the transcribed text from the previous step
            answer="translator(text, src_lang='English', tgt_lang='French')",
        ),
    ],
    [
        # Sixth chat problem group
        Problem(
            # Task: generate a video from `prompt`
            task=[
                "Generate a video of the `prompt`",
                "Animate a `prompt`",
                "Make me a short video using `prompt`.",
            ],
            inputs={"prompt": "A lobster swimming"},  # A concrete prompt value is supplied directly
            answer="video_generator('A lobster swimming')",
        ),
    ],
    [
        # Seventh chat problem group
        Problem(
            # Task: download the content of `url` and summarize it
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            # Download the page with `text_downloader`, then summarize it with `summarizer`
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            # Task: generate a video from the summary
            task=["generate a video from it.", "Create an animation from the last result."],
            inputs=[],  # Uses the `summary` produced in the previous step
            answer="video_generator(summary)",
        ),
    ],
]
# Compare the theoretical tool set against the tools actually used in the code
def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
    # If the theoretical answer is not a list, return the test tools that appear in the code answer
    if not isinstance(theoretical_answer, list):
        return {name for name in TEST_TOOLS if name in code_answer}

    # If the agent answer is a dict, compare each theoretical answer with the corresponding code answer
    if isinstance(agent_answer, dict):
        for one_answer, one_code in zip(theoretical_answer, code_answer):
            # If one of the agent's values matches this theoretical answer, return the tools used in that code
            if one_answer in agent_answer.values():
                return {name for name in TEST_TOOLS if name in one_code}

    # Compare the agent answer with each theoretical answer in turn
    for one_answer, one_code in zip(theoretical_answer, code_answer):
        # If the agent answer equals one of the theoretical answers, return the tools used in that code
        if agent_answer == one_answer:
            return {name for name in TEST_TOOLS if name in one_code}

    # Fall back to the tools used in the first code answer
    return {name for name in TEST_TOOLS if name in code_answer[0]}


# Evaluate the given code with the restricted interpreter
def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
    # Start from a copy of the base Python tools
    tools = BASE_PYTHON_TOOLS.copy()

    # Add every test tool that is referenced in the code
    for name, tool in TEST_TOOLS.items():
        if name not in code:
            continue
        tools[name] = tool

    # If the inputs are a dict, copy them
    if isinstance(inputs, dict):
        inputs = inputs.copy()
    # Otherwise, map each input name to a placeholder of the form <<name>>
    elif inputs is not None:
        inputs = {inp: f"<<{inp}>>" for inp in inputs}

    # Update the existing state with the inputs, or use the inputs as the state
    if state is not None:
        state.update(inputs)
    else:
        state = inputs

    try:
        # Evaluate the code with the selected tools and state
        return evaluate(code, tools, state)
    except InterpretorError as e:
        # On an interpreter error, return the error message as a string
        return str(e)
    except Exception as e:
        # On any other exception, optionally print it and return None
        if verbose:
            print(e)
        return None


# Score the agent's answer against the theoretical answer
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
    # Optionally print both answers
    if verbose:
        print(agent_answer, theoretical_answer)

    # Normalize the theoretical answer to a list
    theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]

    # Exact match: full score of 1.0
    if agent_answer in theoretical_answer:
        if verbose:
            print("Perfect!")
        return 1
    # The right value is present in the agent's state dict: partial score of 0.75
    elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
        if verbose:
            print("Almost perfect, result in state!")
        return 0.75
    # Otherwise the code executed but produced the wrong result: score of 0.3
    else:
        if verbose:
            print("Result is not the right one but code executed.")
        return 0.3
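
To make the three scoring tiers concrete, here is a minimal sketch (the answer values are invented for illustration):

```
# Hypothetical answers, for illustration only
theoretical = "a summary"

score_code("a summary", theoretical)               # exact match -> 1
score_code({"summary": "a summary"}, theoretical)  # correct value left in the state dict -> 0.75
score_code("something else", theoretical)          # code ran but produced the wrong result -> 0.3
```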


# Evaluate a single result: the explanation, the generated code, and the answers
def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
    # Tools mentioned in the explanation, written as `tool_name`
    tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}

    # Tools that the theoretical answer is expected to use
    theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)

    # If the explanation mentions exactly the theoretical tools, the selection score is perfect
    if tools_in_explanation == theoretical_tools:
        tool_selection_score = 1.0
        tool_selection_errors = None
    else:
        # Otherwise, penalize 0.25 per missing and per unexpected tool
        missing_tools = len(theoretical_tools - tools_in_explanation)
        unexpected_tools = len(tools_in_explanation - theoretical_tools)
        tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        # Record which tools were selected versus expected
        tool_selection_errors = {
            "selected_tools": tools_in_explanation,
            "theoretical_tools": theoretical_tools,
        }

    # Tools actually referenced in the generated code
    tools_in_code = {name for name in TEST_TOOLS if name in code}
    # If the code uses exactly the theoretical tools
    if tools_in_code == theoretical_tools:
        # Perfect tool-usage score, no errors to report
        tool_used_score = 1.0
        tool_used_errors = None
    else:
        # Count missing and unexpected tools
        missing_tools = len(theoretical_tools - tools_in_code)
        unexpected_tools = len(tools_in_code - theoretical_tools)
        # Apply the same 0.25 penalty per missing or unexpected tool
        tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        # Record which tools were used versus expected
        tool_used_errors = {
            "selected_tools": tools_in_explanation,
            "theoretical_tools": theoretical_tools,
        }

    # Score the code result itself
    score = score_code(agent_answer, theoretical_answer, verbose=verbose)
    if score < 1.0:
        # Keep the produced code, its evaluation, and the theoretical answer for error reporting
        code_errors = {
            "code_produced": code,
            "evaluation": agent_answer,
            "theoretical_answer": theoretical_answer,
        }
    else:
        # Perfect score: no error information
        code_errors = None

    # Return the three scores and the corresponding error records
    return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
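
Putting the helpers together, a single generation is judged roughly like the following sketch (the explanation and code strings are invented; `evaluate_code` resolves the `<<text>>` placeholder inputs):

```
# Hypothetical agent output, for illustration only
explanation = "I will use the `summarizer` tool to summarize the text."
code = "summary = summarizer(text)"
reference = "summary = summarizer(text)"

agent_answer = evaluate_code(code, inputs=["text"])
theoretical_answer = evaluate_code(reference, inputs=["text"])

scores, errors = evaluate_one_result(explanation, code, agent_answer, theoretical_answer, reference)
# scores == (tool_selection_score, tool_used_score, code_score)
```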
# Evaluate an agent in single-run mode, checking first that it has the right tool set
def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_TASKS`.
    """
    # Consistency check: the agent's toolbox must contain exactly the expected test tools
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        # Report which tools are missing and which are unexpected
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    # Build the flat list of evaluation tasks and the index of the problem each task belongs to
    eval_tasks = []
    eval_idx = []
    for idx, pb in enumerate(EVALUATION_TASKS):
        if isinstance(pb.task, list):
            # Flatten the list of task formulations and repeat the problem index for each of them
            eval_tasks.extend(pb.task)
            eval_idx.extend([idx] * len(pb.task))
        else:
            # Single task formulation
            eval_tasks.append(pb.task)
            eval_idx.append(idx)

    # Initialize the accumulated scores
    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0

    # If errors should be returned, initialize the error dictionaries
    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    # Process the evaluation tasks in batches
    for start_idx in range(0, len(eval_tasks), batch_size):
        end_idx = min(start_idx + batch_size, len(eval_tasks))
        batch_tasks = eval_tasks[start_idx:end_idx]

        # Build a prompt for each task in the batch
        prompts = [agent.format_prompt(task) for task in batch_tasks]
        # Let the agent generate code for the whole batch, stopping at "Task:"
        results = agent.generate_many(prompts, stop=["Task:"])

        # Walk through the results of the batch
        for idx, result in enumerate(results):
            # Retrieve the problem this task belongs to
            problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
            if verbose:
                # Print the task being evaluated
                print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")

            # Split the generation into the explanation and the code to run
            explanation, code = agent.clean_code_for_run(result)

            # Evaluate the agent's code, and the theoretical answer(s) of the problem
            agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
            if isinstance(problem.answer, list):
                theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
            else:
                theoretical_answer = evaluate_code(problem.answer, problem.inputs)

            # Score this result and collect any errors
            scores, errors = evaluate_one_result(
                explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
            )

            # Accumulate the three scores
            tool_selection_score += scores[0]
            tool_used_score += scores[1]
            code_score += scores[2]

            # Record the errors if requested
            if return_errors:
                if errors[0] is not None:
                    tool_selection_errors[batch_tasks[idx]] = errors[0]
                if errors[1] is not None:
                    tool_used_errors[batch_tasks[idx]] = errors[1]
                if errors[2] is not None:
                    code_errors[batch_tasks[idx]] = errors[2]
    # Convert each accumulated score to a percentage of the number of evaluation tasks
    scores = {
        "tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
        "tool used score": 100 * (tool_used_score / len(eval_tasks)),
        "code score": 100 * (code_score / len(eval_tasks)),
    }

    # Return the scores, along with the error dictionaries if requested
    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores
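
Assuming an agent that exposes the full `TEST_TOOLS` toolbox, a typical call looks like this sketch:

```
# Hypothetical usage, with `agent` being any agent exposing the TEST_TOOLS toolbox
scores, tool_sel_errs, tool_used_errs, code_errs = evaluate_agent(
    agent, batch_size=8, verbose=False, return_errors=True
)
print(scores)  # {"tool selection score": ..., "tool used score": ..., "code score": ...}
```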
# Evaluate the given agent in chat mode, first checking that it has the right tool set
def evaluate_chat_agent(agent, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_CHATS`.

    Example:

    ```
    agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_chat_agent(agent)
    ```
    """
    # Check that the agent's toolbox matches the expected test tools
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        # Compute the missing and unexpected tools
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        # Ask for the test tools in this module to be fixed
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    # Initialize the accumulated scores and the step counter
    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0
    total_steps = 0

    # If errors should be returned, initialize the error dictionaries
    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}
    # Iterate over every chat scenario in the evaluation set
    for chat_problem in EVALUATION_CHATS:
        # If the first task is a plain string, the scenario has a single formulation
        if isinstance(chat_problem[0].task, str):
            resolved_problems = [chat_problem]
        else:
            # Otherwise, build one resolved scenario per task formulation
            resolved_problems = [
                [Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
                for i in range(len(chat_problem[0].task))
            ]

        # Iterate over the resolved scenarios
        for problem in resolved_problems:
            # Reset the agent and its state for a new chat
            agent.prepare_for_new_chat()
            agent_state = {}
            # The theoretical state is a list of dicts if the first answer is a list, else a single dict
            theoretical_state = (
                [{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
            )
            
            # Walk through each step of the scenario
            for step, step_problem in enumerate(problem):
                # Optionally print the task of the current step
                if verbose:
                    print(step_problem.task)

                total_steps += 1
                # Build the chat-mode prompt for this step
                prompt = agent.format_prompt(step_problem.task, chat_mode=True)
                # Generate one answer, stopping on the chat delimiters
                result = agent.generate_one(prompt, stop=["Human:", "====="])
                # Append the exchange to the agent's chat history
                agent.chat_history = prompt + result + "\n"

                # Split the generation into the explanation and the code to run
                explanation, code = clean_code_for_chat(result)

                # Optionally print what the agent produced
                if verbose:
                    print(f"==Explanation from the agent==\n{explanation}")
                    print(f"\n==Code generated by the agent==\n{code}")

                # Evaluate the agent's code against its running state
                agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)

                answer = step_problem.answer
                if isinstance(answer, list):
                    # If the answer is a list, evaluate each variant against its own theoretical state
                    theoretical_answer = [
                        evaluate_code(a, step_problem.inputs, state=state)
                        for a, state in zip(answer, theoretical_state)
                    ]
                else:
                    # Otherwise, evaluate the single answer against the theoretical state
                    theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)

                # Score this step and collect any errors
                scores, errors = evaluate_one_result(
                    explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
                )

                # Accumulate the three scores
                tool_selection_score += scores[0]
                tool_used_score += scores[1]
                code_score += scores[2]

                # Record the errors if requested
                if return_errors:
                    if errors[0] is not None:
                        tool_selection_errors[step_problem.task] = errors[0]
                    if errors[1] is not None:
                        tool_used_errors[step_problem.task] = errors[1]
                    if errors[2] is not None:
                        code_errors[step_problem.task] = errors[2]

    # Convert each accumulated score to a percentage of the total number of steps
    scores = {
        "tool selection score": 100 * (tool_selection_score / total_steps),
        "tool used score": 100 * (tool_used_score / total_steps),
        "code score": 100 * (code_score / total_steps),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores

.\tools\image_captioning.py

#!/usr/bin/env python
# coding=utf-8

# TYPE_CHECKING is used for static type checking only
from typing import TYPE_CHECKING

# AutoModelForVision2Seq loads a vision-to-sequence model automatically
from ..models.auto import AutoModelForVision2Seq
# requires_backends checks that the required backends are installed
from ..utils import requires_backends
# PipelineTool is the base class for pipeline-backed tools
from .base import PipelineTool

# Only import Image when type checking
if TYPE_CHECKING:
    from PIL import Image

# A tool that generates an English caption for an image
class ImageCaptioningTool(PipelineTool):
    # Default model checkpoint
    default_checkpoint = "Salesforce/blip-image-captioning-base"
    # Tool description
    description = (
        "This is a tool that generates a description of an image. It takes an input named `image` which should be the "
        "image to caption, and returns a text that contains the description in English."
    )
    # Tool name
    name = "image_captioner"
    # Model class
    model_class = AutoModelForVision2Seq

    # Input specification: an image
    inputs = ["image"]
    # Output specification: a text
    outputs = ["text"]

    # The constructor checks that the vision backend is available
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    # Encode the image into model inputs
    def encode(self, image: "Image"):
        return self.pre_processor(images=image, return_tensors="pt")

    # Run generation on the encoded inputs
    def forward(self, inputs):
        return self.model.generate(**inputs)

    # Decode the generated ids into the caption text
    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
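
Like the other pipeline tools, the class is callable directly; a minimal sketch (the checkpoint is downloaded on first use, and the image path is a placeholder):

```
from PIL import Image

captioner = ImageCaptioningTool()
caption = captioner(Image.open("photo.png"))  # runs encode -> forward -> decode
print(caption)
```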

.\tools\image_question_answering.py

#!/usr/bin/env python
# coding=utf-8

# Imports needed by the tool
from typing import TYPE_CHECKING
import torch

# Auto classes for visual question answering, plus shared utilities
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from .base import PipelineTool

# Only import Image when type checking
if TYPE_CHECKING:
    from PIL import Image

# A tool that answers a question about an image
class ImageQuestionAnsweringTool(PipelineTool):
    # Default model checkpoint
    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    # Tool description
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    # Tool name
    name = "image_qa"
    # Pre-processor class for the inputs
    pre_processor_class = AutoProcessor
    # Model class for visual question answering
    model_class = AutoModelForVisualQuestionAnswering

    # Input and output specifications
    inputs = ["image", "text"]
    outputs = ["text"]

    # The constructor checks that the vision backend is available
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    # Encode the image and the question into model inputs
    def encode(self, image: "Image", question: str):
        return self.pre_processor(image, question, return_tensors="pt")

    # Run the model without gradients and return the logits
    def forward(self, inputs):
        with torch.no_grad():
            return self.model(**inputs).logits

    # Map the highest logit to its label
    def decode(self, outputs):
        idx = outputs.argmax(-1).item()
        return self.model.config.id2label[idx]
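
A minimal usage sketch (placeholder image path):

```
from PIL import Image

vqa = ImageQuestionAnsweringTool()
answer = vqa(Image.open("photo.png"), "What animal is in the picture?")
print(answer)  # a single label such as "cat"
```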

.\tools\image_segmentation.py

# Required libraries
import numpy as np
import torch

# CLIPSeg model for text-conditioned image segmentation
from ..models.clipseg import CLIPSegForImageSegmentation
# Shared utilities
from ..utils import is_vision_available, requires_backends
# Base class for pipeline-backed tools
from .base import PipelineTool

# Only import PIL when the vision backend is available
if is_vision_available():
    from PIL import Image

# A tool that produces a segmentation mask for an image given a label
class ImageSegmentationTool(PipelineTool):
    # Tool description
    description = (
        "This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image. "
        "It takes two arguments named `image` which should be the original image, and `label` which should be a text "
        "describing the elements that should be identified in the segmentation mask. The tool returns the mask."
    )
    # Default model checkpoint
    default_checkpoint = "CIDAS/clipseg-rd64-refined"
    # Tool name
    name = "image_segmenter"
    # Model class
    model_class = CLIPSegForImageSegmentation

    # Input and output specifications
    inputs = ["image", "text"]
    outputs = ["image"]

    # The constructor checks that the vision backend is available
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    # Encode the image and the label into model inputs
    def encode(self, image: "Image", label: str):
        return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")

    # Run the model without gradients and return the logits
    def forward(self, inputs):
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return logits

    # Turn the logits into a binary mask image
    def decode(self, outputs):
        # Move the logits to a NumPy array
        array = outputs.cpu().detach().numpy()
        # Threshold at 0: non-positive values become 0, positive values become 1
        array[array <= 0] = 0
        array[array > 0] = 1
        # Scale to 0/255 and convert to a PIL image
        return Image.fromarray((array * 255).astype(np.uint8))
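
A short sketch of calling the segmenter; the returned mask is a binary PIL image (placeholder paths):

```
from PIL import Image

segmenter = ImageSegmentationTool()
mask = segmenter(Image.open("street.png"), "car")
mask.save("car_mask.png")
```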

.\tools\prompts.py

import re

from ..utils import cached_file

# Chat message template with the human request and the assistant's reply
CHAT_MESSAGE_PROMPT = """
Human: <<task>>

Assistant: """
# Default repository for the prompt templates
DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
# Prompt file name for each mode
PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}

def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
    """
    Downloads and caches the prompt from a repo, and returns its content (if necessary).
    """
    # Fall back to the default prompts repository
    if prompt_or_repo_id is None:
        prompt_or_repo_id = DEFAULT_PROMPTS_REPO

    # If the string contains whitespace, it cannot be a repo ID: treat it as the prompt itself
    if re.search("\\s", prompt_or_repo_id) is not None:
        return prompt_or_repo_id

    # Download (and cache) the prompt file for the requested mode
    prompt_file = cached_file(
        prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
    )
    # Read and return the prompt file content
    with open(prompt_file, "r", encoding="utf-8") as f:
        return f.read()
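
A quick sketch of the two behaviors, literal prompt versus repo ID:

```
# A string containing whitespace is treated as the prompt itself and returned unchanged
template = download_prompt("Answer the following task: <<task>>", agent_name="my-agent")

# A bare repo ID triggers a download of the mode-specific template file
template = download_prompt("huggingface-tools/default-prompts", agent_name="my-agent", mode="chat")
```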

.\tools\python_interpreter.py

    """
    Evaluate an abstract syntax tree (AST) node representing a Python expression, using variables from `state` and 
    restricted to functions in `tools`.

    Args:
        expression (`ast.AST`):
            The AST node to evaluate.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to their current values.
        tools (`Dict[str, Callable]`):
            Allowed functions that can be called during evaluation.

    Returns:
        Any:
            The result of evaluating the expression represented by `expression`.

    Raises:
        InterpretorError:
            If evaluation encounters an unsupported operation or other error.
    """
    try:
        # Parse the provided AST node
        line_result = ast.literal_eval(expression, globals=state, locals=tools)
    except (ValueError, TypeError, SyntaxError) as e:
        # Capture and raise an InterpretorError for unsupported operations or syntax errors
        raise InterpretorError(f"Failed to evaluate expression: {e}")
    return line_result
    This function will recurse trough the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
    """
    if isinstance(expression, ast.Assign):
        # If the expression is an assignment statement
        # Evaluate the assignment and return the assigned variable's value
        return evaluate_assign(expression, state, tools)
    elif isinstance(expression, ast.Call):
        # If the expression is a function call
        # Evaluate the function call and return its value
        return evaluate_call(expression, state, tools)
    elif isinstance(expression, ast.Constant):
        # If the expression is a constant value (literal)
        # Return the constant's value
        return expression.value
    elif isinstance(expression, ast.Dict):
        # If the expression is a dictionary literal
        # Evaluate all keys and values recursively and return a dictionary
        keys = [evaluate_ast(k, state, tools) for k in expression.keys]
        values = [evaluate_ast(v, state, tools) for v in expression.values]
        return dict(zip(keys, values))
    elif isinstance(expression, ast.Expr):
        # If the expression is an expression statement
        # Evaluate the expression and return its value
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.For):
        # If the expression is a for loop
        # Evaluate the loop and return its result
        return evaluate_for(expression, state, tools)
    elif isinstance(expression, ast.FormattedValue):
        # If the expression is a formatted value in an f-string
        # Evaluate the content and return its value
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.If):
        # If the expression is an if statement
        # Evaluate the condition and execute the appropriate branch
        return evaluate_if(expression, state, tools)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        # If the expression is an index operation
        # Evaluate the indexed value and return it
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.JoinedStr):
        # If the expression is a joined string (part of an f-string)
        # Evaluate the concatenated parts and return the resulting string
        return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # If the expression is a list literal
        # Evaluate all elements recursively and return a list
        return [evaluate_ast(elt, state, tools) for elt in expression.elts]
    elif isinstance(expression, ast.Name):
        # If the expression is a variable name
        # Retrieve its value from the state dictionary
        return evaluate_name(expression, state, tools)
    elif isinstance(expression, ast.Subscript):
        # If the expression is a subscript operation
        # Evaluate the subscripted value and return it
        return evaluate_subscript(expression, state, tools)
    else:
        # If the expression type is not recognized
        # Raise an interpreter error indicating the unsupported expression type
        raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
# Evaluate an assignment: update the state and return the assigned value
def evaluate_assign(assign, state, tools):
    # The assignment targets (left-hand side)
    var_names = assign.targets
    # Evaluate the right-hand side of the assignment
    result = evaluate_ast(assign.value, state, tools)

    # A single target gets the whole result
    if len(var_names) == 1:
        state[var_names[0].id] = result
    else:
        # For multiple targets, the result must have a matching length
        if len(result) != len(var_names):
            raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
        # Assign each value to the corresponding variable
        for var_name, r in zip(var_names, result):
            state[var_name.id] = r
    # Return the evaluated value
    return result


# Evaluate a function call and return its result
def evaluate_call(call, state, tools):
    # Only plain names may be called
    if not isinstance(call.func, ast.Name):
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
            f"type {type(call.func)}."
        )
    # The name of the called function
    func_name = call.func.id
    # The function must be one of the provided tools
    if func_name not in tools:
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
        )

    # Look up the function object
    func = tools[func_name]
    # Evaluate the positional and keyword arguments
    args = [evaluate_ast(arg, state, tools) for arg in call.args]
    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
    # Call the function and return its result
    return func(*args, **kwargs)


# Evaluate a subscript expression and return the indexed value
def evaluate_subscript(subscript, state, tools):
    # Evaluate the index and the subscripted value
    index = evaluate_ast(subscript.slice, state, tools)
    value = evaluate_ast(subscript.value, state, tools)

    # Lists and tuples are indexed by integer position
    if isinstance(value, (list, tuple)):
        return value[int(index)]
    # Direct key lookup
    if index in value:
        return value[index]
    # For string keys on mappings, fall back to the closest matching key
    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]

    # Nothing worked: raise an interpreter error
    raise InterpretorError(f"Could not index {value} with '{index}'.")


# Evaluate a name expression and return the variable's value
def evaluate_name(name, state, tools):
    # Exact match in the state
    if name.id in state:
        return state[name.id]
    # Otherwise, fall back to the closest matching variable name
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    # The variable is not defined
    raise InterpretorError(f"The variable `{name.id}` is not defined.")


# Evaluate a comparison and return its boolean result
def evaluate_condition(condition, state, tools):
    # Chained comparisons are not supported
    if len(condition.ops) > 1:
        raise InterpretorError("Cannot evaluate conditions with multiple operators")

    # Evaluate both sides of the comparison
    left = evaluate_ast(condition.left, state, tools)
    comparator = condition.ops[0]
    right = evaluate_ast(condition.comparators[0], state, tools)

    # Dispatch on the comparison operator
    if isinstance(comparator, ast.Eq):
        return left == right
    elif isinstance(comparator, ast.NotEq):
        return left != right
    elif isinstance(comparator, ast.Lt):
        return left < right
    elif isinstance(comparator, ast.LtE):
        return left <= right
    elif isinstance(comparator, ast.Gt):
        return left > right
    elif isinstance(comparator, ast.GtE):
        return left >= right
    elif isinstance(comparator, ast.Is):
        return left is right
    elif isinstance(comparator, ast.IsNot):
        return left is not right
    # 'in' membership test
    elif isinstance(comparator, ast.In):
        return left in right
    # 'not in' membership test
    elif isinstance(comparator, ast.NotIn):
        return left not in right
    else:
        # Any other operator is unsupported
        raise InterpretorError(f"Operator not supported: {comparator}")
# Evaluate an if statement and return the result of the last executed line
def evaluate_if(if_statement, state, tools):
    result = None
    # Condition is true: execute the body
    if evaluate_condition(if_statement.test, state, tools):
        for line in if_statement.body:
            # Evaluate each line of the body
            line_result = evaluate_ast(line, state, tools)
            # Keep the last non-None result
            if line_result is not None:
                result = line_result
    else:
        # Condition is false: execute the else branch
        for line in if_statement.orelse:
            # Evaluate each line of the else branch
            line_result = evaluate_ast(line, state, tools)
            # Keep the last non-None result
            if line_result is not None:
                result = line_result
    # Return the last result
    return result


# Evaluate a for loop and return the result of the last executed expression
def evaluate_for(for_loop, state, tools):
    result = None
    # Evaluate the iterator expression
    iterator = evaluate_ast(for_loop.iter, state, tools)
    # Iterate over each element
    for counter in iterator:
        # Bind the loop variable in the state
        state[for_loop.target.id] = counter
        # Evaluate each expression of the loop body
        for expression in for_loop.body:
            line_result = evaluate_ast(expression, state, tools)
            # Keep the last non-None result
            if line_result is not None:
                result = line_result
    # Return the last result
    return result
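
To see the dispatcher in action, here is a minimal sketch that parses a small program with `ast.parse` and feeds each statement to `evaluate_ast`; the `double` tool is a made-up example, and the state dict is threaded through the calls:

```
import ast

state = {}
tools = {"double": lambda x: 2 * x}

program = "x = 3\ny = double(x)\nresult = {'answer': y}"
for node in ast.parse(program).body:
    evaluate_ast(node, state, tools)

print(state["result"])  # {'answer': 6}
```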

.\tools\speech_to_text.py

#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Required model and processor classes
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool

# A tool that transcribes audio into text with Whisper
class SpeechToTextTool(PipelineTool):
    # Default model checkpoint
    default_checkpoint = "openai/whisper-base"
    # Tool description
    description = (
        "This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
        "transcribed text."
    )
    # Tool name
    name = "transcriber"
    # Pre-processor class for the inputs
    pre_processor_class = WhisperProcessor
    # Model class that generates the transcription
    model_class = WhisperForConditionalGeneration

    # Input and output specifications
    inputs = ["audio"]
    outputs = ["text"]

    # Encode the raw audio into input features
    def encode(self, audio):
        return self.pre_processor(audio, return_tensors="pt").input_features

    # Generate the transcription ids
    def forward(self, inputs):
        return self.model.generate(inputs=inputs)

    # Decode the generated ids into text
    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
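
A short sketch of transcribing a waveform loaded with `soundfile` (placeholder path; Whisper expects 16 kHz audio):

```
import soundfile as sf

audio, _sr = sf.read("speech.wav")
transcriber = SpeechToTextTool()
print(transcriber(audio))
```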

.\tools\text_classification.py

#!/usr/bin/env python
# coding=utf-8

# PyTorch is needed for the argmax in decode
import torch

# Auto classes for sequence classification
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer

# Base class for pipeline-backed tools
from .base import PipelineTool


class TextClassificationTool(PipelineTool):
    """
    Text classification tool, built on the `PipelineTool` base class.

    Example:

    ```
    from transformers.tools import TextClassificationTool

    classifier = TextClassificationTool()
    classifier("This is a super nice API!", labels=["positive", "negative"])
    ```
    """

    # Default pretrained model
    default_checkpoint = "facebook/bart-large-mnli"
    # Tool description
    description = (
        "This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
        "should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
        "It returns the most likely label in the list of provided `labels` for the input text."
    )
    # Tool name
    name = "text_classifier"
    # Pre-processor class
    pre_processor_class = AutoTokenizer
    # Model class
    model_class = AutoModelForSequenceClassification

    # Input specification: a text and a list of texts (the candidate labels)
    inputs = ["text", ["text"]]
    # Output specification
    outputs = ["text"]

    def setup(self):
        # Run the base setup first
        super().setup()
        # Find the "entailment" label index in the model config
        config = self.model.config
        self.entailment_id = -1
        for idx, label in config.id2label.items():
            if label.lower().startswith("entail"):
                self.entailment_id = int(idx)
        # Fail if no entailment label was found
        if self.entailment_id == -1:
            raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")

    def encode(self, text, labels):
        # Pair the text with one "This example is {label}" hypothesis per label, NLI style
        self._labels = labels
        return self.pre_processor(
            [text] * len(labels),
            [f"This example is {label}" for label in labels],
            return_tensors="pt",
            padding="max_length",
        )

    def decode(self, outputs):
        # Pick the label whose logit in column 2 (the entailment column for this checkpoint) is highest
        logits = outputs.logits
        label_id = torch.argmax(logits[:, 2]).item()
        return self._labels[label_id]

.\tools\text_question_answering.py

#!/usr/bin/env python
# coding=utf-8

# Auto classes for seq2seq generation
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool

# Prompt template with placeholders for the text and the question
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.

Can you answer this question about the text: '{question}'"""

# A tool that answers questions about a given text, built on PipelineTool
class TextQuestionAnsweringTool(PipelineTool):
    # Default model checkpoint
    default_checkpoint = "google/flan-t5-base"
    # Tool description
    description = (
        "This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
        "text where to find the answer, and `question`, which is the question, and returns the answer to the question."
    )
    # Tool name
    name = "text_qa"
    # Pre-processor class
    pre_processor_class = AutoTokenizer
    # Model class
    model_class = AutoModelForSeq2SeqLM

    # Input specification: the text and the question
    inputs = ["text", "text"]
    # Output specification: the answer text
    outputs = ["text"]

    # Format the text and the question into the prompt and tokenize it
    def encode(self, text: str, question: str):
        # Fill the prompt template
        prompt = QA_PROMPT.format(text=text, question=question)
        # Tokenize the prompt into PyTorch tensors
        return self.pre_processor(prompt, return_tensors="pt")

    # Generate the answer ids
    def forward(self, inputs):
        # Generate the output ids
        output_ids = self.model.generate(**inputs)

        # Batch sizes of the inputs and the outputs
        in_b, _ = inputs["input_ids"].shape
        out_b = output_ids.shape[0]

        # Reshape so that the generations are grouped per input, and keep the first one
        return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]

    # Decode the answer ids into text
    def decode(self, outputs):
        # Skip special tokens and clean up tokenization spaces
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
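
A short usage sketch:

```
qa = TextQuestionAnsweringTool()
text = "Hugging Face is based in New York and Paris."
print(qa(text, "Where is Hugging Face based?"))
```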

.\tools\text_summarization.py

#!/usr/bin/env python
# coding=utf-8

# Auto classes for seq2seq generation
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool

# Text summarization tool, built on the PipelineTool base class
class TextSummarizationTool(PipelineTool):
    """
    Example:

    ```
    from transformers.tools import TextSummarizationTool

    summarizer = TextSummarizationTool()
    summarizer(long_text)
    ```
    """

    # Default model checkpoint
    default_checkpoint = "philschmid/bart-large-cnn-samsum"
    # Tool description
    description = (
        "This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
        "and returns a summary of the text."
    )
    # Tool name
    name = "summarizer"
    # Pre-processor class
    pre_processor_class = AutoTokenizer
    # Model class
    model_class = AutoModelForSeq2SeqLM

    # Input and output specifications
    inputs = ["text"]
    outputs = ["text"]

    # Tokenize the text, truncating it to the model's maximum length
    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    # Generate the summary ids
    def forward(self, inputs):
        return self.model.generate(**inputs)[0]

    # Decode the summary ids into text
    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

.\tools\text_to_speech.py

#!/usr/bin/env python
# coding=utf-8

# PyTorch is used for tensors and no-grad inference
import torch

# SpeechT5 model, vocoder and processor, plus shared utilities
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool

# load_dataset is only needed for the default speaker embeddings
if is_datasets_available():
    from datasets import load_dataset


# A tool that reads English text out loud with SpeechT5
class TextToSpeechTool(PipelineTool):
    # Default model checkpoint
    default_checkpoint = "microsoft/speecht5_tts"
    # Tool description
    description = (
        "This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
        "text to read (in English) and returns a waveform object containing the sound."
    )
    # Tool name
    name = "text_reader"
    # Pre-processor class
    pre_processor_class = SpeechT5Processor
    # Model class
    model_class = SpeechT5ForTextToSpeech
    # Post-processor class (the vocoder)
    post_processor_class = SpeechT5HifiGan

    # Input and output specifications
    inputs = ["text"]
    outputs = ["audio"]

    # Default the post-processor checkpoint before the base setup
    def setup(self):
        if self.post_processor is None:
            self.post_processor = "microsoft/speecht5_hifigan"
        super().setup()

    # Encode the text (with truncation) and attach speaker embeddings
    def encode(self, text, speaker_embeddings=None):
        # Tokenize the text into input tensors
        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)

        # Load default speaker embeddings if none were provided
        if speaker_embeddings is None:
            if not is_datasets_available():
                # The datasets library is required to fetch the default embeddings
                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")

            # Load the speaker embeddings from the CMU Arctic x-vectors dataset
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            # Pick one speaker's embedding and add a batch dimension
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)

        # Return the input ids together with the speaker embeddings
        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}

    # Generate the speech waveform without gradients
    def forward(self, inputs):
        with torch.no_grad():
            return self.model.generate_speech(**inputs)

    # Run the vocoder on the generated output and move the result to CPU, detached from the graph
    def decode(self, outputs):
        with torch.no_grad():
            return self.post_processor(outputs).cpu().detach()
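
A short sketch that synthesizes a sentence and writes it to disk; SpeechT5 produces 16 kHz audio:

```
import soundfile as sf

reader = TextToSpeechTool()
speech = reader("Hello from the Transformers tools!")
sf.write("speech.wav", speech.numpy(), samplerate=16000)
```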

.\tools\translation.py

#!/usr/bin/env python
# coding=utf-8

# Auto classes for seq2seq generation
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool

# Mapping from plain-English language names to their NLLB codes; the suffix of each code names the
# script (Latn = Latin, Cyrl = Cyrillic, Arab = Arabic, Deva = Devanagari, and so on)
LANGUAGE_CODES = {
    "Acehnese Arabic": "ace_Arab",
    "Acehnese Latin": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta'izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic Romanized": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar Arabic": "bjn_Arab",
    "Banjar Latin": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri Arabic": "kas_Arab",
    "Kashmiri Devanagari": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri Arabic": "knc_Arab",
    "Central Kanuri Latin": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau Arabic": "min_Arab",
    "Minangkabau Latin": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei Bengali": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Tamasheq Latin": "taq_Latn",
    "Tamasheq Tifinagh": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese Simplified": "zho_Hans",
    "Chinese Traditional": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}



class TranslationTool(PipelineTool):
    """
    Example:

    ```
    from transformers.tools import TranslationTool

    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    description = (
        "This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
        "be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
        "which should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in "
        "plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
    )
    name = "translator"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    lang_to_code = LANGUAGE_CODES

    inputs = ["text", "text", "text"]
    outputs = ["text"]

    def encode(self, text, src_lang, tgt_lang):
        # Check that the source language is supported
        if src_lang not in self.lang_to_code:
            raise ValueError(f"{src_lang} is not a supported language.")
        # Check that the target language is supported
        if tgt_lang not in self.lang_to_code:
            raise ValueError(f"{tgt_lang} is not a supported language.")
        # Map both languages to their NLLB codes
        src_lang = self.lang_to_code[src_lang]
        tgt_lang = self.lang_to_code[tgt_lang]
        # Build the translation inputs with the tokenizer
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
        )

    def forward(self, inputs):
        # Generate the translated ids
        return self.model.generate(**inputs)

    def decode(self, outputs):
        # Decode the generated ids, skipping special tokens
        return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)

.\tools\__init__.py

#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TYPE_CHECKING for static type checking only
from typing import TYPE_CHECKING

# Lazy-module machinery and optional-dependency handling
from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)

# The import structure of the module
_import_structure = {
    "agents": ["Agent", "AzureOpenAiAgent", "HfAgent", "LocalAgent", "OpenAiAgent"],
    "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

# The Torch-backed tools are only registered when PyTorch is available
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # Torch is available: register the concrete tools
    _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
    _import_structure["image_captioning"] = ["ImageCaptioningTool"]
    _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
    _import_structure["image_segmentation"] = ["ImageSegmentationTool"]
    _import_structure["speech_to_text"] = ["SpeechToTextTool"]
    _import_structure["text_classification"] = ["TextClassificationTool"]
    _import_structure["text_question_answering"] = ["TextQuestionAnsweringTool"]
    _import_structure["text_summarization"] = ["TextSummarizationTool"]
    _import_structure["text_to_speech"] = ["TextToSpeechTool"]
    _import_structure["translation"] = ["TranslationTool"]

# During type checking, import everything eagerly so that type checkers see the real symbols
if TYPE_CHECKING:
    from .agents import Agent, AzureOpenAiAgent, HfAgent, LocalAgent, OpenAiAgent
    from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        # Torch is available: import the concrete tool classes
        from .document_question_answering import DocumentQuestionAnsweringTool
        from .image_captioning import ImageCaptioningTool
        from .image_question_answering import ImageQuestionAnsweringTool
        from .image_segmentation import ImageSegmentationTool
        from .speech_to_text import SpeechToTextTool
        from .text_classification import TextClassificationTool
        from .text_question_answering import TextQuestionAnsweringTool
        from .text_summarization import TextSummarizationTool
        from .text_to_speech import TextToSpeechTool
        from .translation import TranslationTool
else:
    import sys

    # At runtime, replace this module with a lazy module that imports submodules on first access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
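
From the user's perspective the lazy module is transparent: the submodule is only imported on first attribute access, so the line below triggers the import of `.translation` (and its Torch dependency) at that point:

```
from transformers.tools import TranslationTool  # lazily imports .translation here

translator = TranslationTool()
```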

.\trainer.py

# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import contextlib
import copy
import functools
import glob
import importlib.metadata
import inspect
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

# Integrations must be imported before ML frameworks:
# isort: off
from .integrations import (
    get_reporting_integration_callbacks,
    hp_params,
)

# isort: on

import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler

from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from .trainer_pt_utils import (
    DistributedTensorGatherer,
    IterableDatasetShard,
    LabelSmoother,
    LayerWiseDummyOptimizer,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_dataloader_sampler,
    get_model_param_count,
    get_module_class_from_name,  # Get a module class from its name
    get_parameter_names,  # Get the list of parameter names
    nested_concat,  # Recursively concatenate nested structures
    nested_detach,  # Recursively detach tensors in nested structures
    nested_numpify,  # Recursively convert nested structures to NumPy arrays
    nested_xla_mesh_reduce,  # Recursively reduce tensors across an XLA mesh
    reissue_pt_warnings,  # Re-issue caught PyTorch warnings
    remove_dummy_checkpoint,  # Remove a dummy checkpoint
)

# Import utility classes and functions from the `trainer_utils` module
from .trainer_utils import (
    PREFIX_CHECKPOINT_DIR,  # Prefix for checkpoint directories
    BestRun,  # Best run result class
    EvalLoopOutput,  # Evaluation loop output class
    EvalPrediction,  # Evaluation prediction class
    HPSearchBackend,  # Hyperparameter search backend class
    HubStrategy,  # Hub strategy class
    IntervalStrategy,  # Interval strategy class
    PredictionOutput,  # Prediction output class
    RemoveColumnsCollator,  # Collator that removes unused columns
    TrainerMemoryTracker,  # Trainer memory tracker class
    TrainOutput,  # Training output class
    check_target_module_exists,  # Check whether a target module exists
    default_compute_objective,  # Default objective computation function
    denumpify_detensorize,  # Convert NumPy arrays/tensors back to plain Python objects
    enable_full_determinism,  # Enable fully deterministic behavior
    find_executable_batch_size,  # Find a batch size that can actually run
    get_last_checkpoint,  # Get the last checkpoint
    has_length,  # Check whether an object has a length
    neftune_post_forward_hook,  # NEFTune post-forward hook
    number_of_arguments,  # Get the number of arguments of a callable
    seed_worker,  # Dataloader worker seeding function
    set_seed,  # Set the random seed
    speed_metrics,  # Speed metrics function
)

# Import optimizer names, the parallel mode, and the training arguments class from `training_args`
from .training_args import OptimizerNames, ParallelMode, TrainingArguments

# Import constants, classes, and functions from the `utils` module
from .utils import (
    ADAPTER_CONFIG_NAME,  # Adapter config file name
    ADAPTER_SAFE_WEIGHTS_NAME,  # Adapter safetensors weights file name
    ADAPTER_WEIGHTS_NAME,  # Adapter weights file name
    CONFIG_NAME,  # Config file name
    SAFE_WEIGHTS_INDEX_NAME,  # Safetensors weights index file name
    SAFE_WEIGHTS_NAME,  # Safetensors weights file name
    WEIGHTS_INDEX_NAME,  # Weights index file name
    WEIGHTS_NAME,  # Weights file name
    PushInProgress,  # Class tracking an in-progress Hub push
    PushToHubMixin,  # Mixin for pushing to the Hub
    can_return_loss,  # Whether the model can return a loss
    find_labels,  # Find the label arguments of a model
    is_accelerate_available,  # Whether accelerate is available
    is_apex_available,  # Whether APEX is available
    is_bitsandbytes_available,  # Whether bitsandbytes is available
    is_datasets_available,  # Whether the datasets library is available
    is_galore_torch_available,  # Whether GaLore torch is available
    is_in_notebook,  # Whether running inside a notebook
    is_ipex_available,  # Whether IPEX is available
    is_peft_available,  # Whether PEFT is available
    is_safetensors_available,  # Whether safetensors is available
    is_sagemaker_dp_enabled,  # Whether SageMaker data parallelism is enabled
    is_sagemaker_mp_enabled,  # Whether SageMaker model parallelism is enabled
    is_torch_compile_available,  # Whether torch.compile is available
    is_torch_neuroncore_available,  # Whether Torch NeuronCore is available
    is_torch_npu_available,  # Whether Torch NPU is available
    is_torch_xla_available,  # Whether Torch XLA is available
    logging,  # Logging utilities
    strtobool,  # Convert a string to a boolean
)

# Import the quantization method from `utils.quantization_config`
from .utils.quantization_config import QuantizationMethod

# Default list of callbacks
DEFAULT_CALLBACKS = [DefaultFlowCallback]

# Default progress callback
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# When running in a notebook, use the notebook progress callback instead
if is_in_notebook():
    from .utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

# If APEX is available, import AMP
if is_apex_available():
    from apex import amp

# If the datasets library is available, import it
if is_datasets_available():
    import datasets

# If Torch XLA is available, import the related modules
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.spmd as xs
    import torch_xla.runtime as xr

# If SageMaker model parallelism is enabled, import the related modules and check the version
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

# If safetensors is available, import it
if is_safetensors_available():
    import safetensors.torch

# If PEFT is available, import PeftModel
if is_peft_available():
    from peft import PeftModel

# If accelerate is available, import the related modules and functions
if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        GradientAccumulationPlugin,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )

    DATA_SAMPLERS = [RandomSampler]  # List of data samplers
    # If the accelerate version is greater than "0.23.0", SeedableRandomSampler is available
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        # Add SeedableRandomSampler to the DATA_SAMPLERS list
        DATA_SAMPLERS += [SeedableRandomSampler]

    # If DeepSpeed is available, import DeepSpeedSchedulerWrapper from accelerate.utils
    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper


# Check whether the given model is a PEFT model or one of its mixed variants
def _is_peft_model(model):
    # Check whether PEFT is installed
    if is_peft_available():
        # Start with PeftModel as the base class to check against
        classes_to_check = (PeftModel,)
        # From PEFT 0.7.0 on, PeftMixedModel also needs to be checked
        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
            from peft import PeftMixedModel

            classes_to_check = (*classes_to_check, PeftMixedModel)
        # Return whether the model is an instance of any of the classes to check
        return isinstance(model, classes_to_check)
    # If PEFT is not installed, return False
    return False


def _get_fsdp_ckpt_kwargs():
    # TODO: @AjayP13, @younesbelkada replace this check with a version check once the next `accelerate` release is out
    # Check that accelerate is installed and that save_fsdp_model accepts an "adapter_only" argument
    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
        # If so, return the kwargs suited for FSDP checkpointing
        return {"adapter_only": True}
    else:
        # Otherwise return an empty dict
        return {}
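
# Illustrative aside (not part of trainer.py): the same inspect.signature probing
# used by `_get_fsdp_ckpt_kwargs` above, applied to a hypothetical stand-in function
# so the feature-detection pattern can be run standalone.
def _save_model_stub(path, adapter_only=False):  # hypothetical stand-in for save_fsdp_model
    pass


_extra_kwargs = {"adapter_only": True} if "adapter_only" in inspect.signature(_save_model_stub).parameters else {}
# _extra_kwargs == {"adapter_only": True}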


if TYPE_CHECKING:
    import optuna  # optuna is only needed for type checking


logger = logging.get_logger(__name__)  # Logger for the current module


# File name constants used for checkpoint files
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"


class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.

    Important attributes:

        - **model** -- Always points to the core model. If using a transformers model, it will be a subclass of
          [`PreTrainedModel`].
        - **model_wrapped** -- Always points to the most external model. When using `DeepSpeed`, the inner model is
          wrapped in `DeepSpeed` and then in `torch.nn.DistributedDataParallel`. If the inner model has not been
          wrapped, then `self.model_wrapped` is the same as `self.model`.
        - **is_model_parallel** -- Whether the model has been switched to model-parallel mode (different from data
          parallelism; it means some of the model layers are split across different GPUs).
        - **place_model_on_device** -- Whether to automatically place the model on the device. It will be set to
          `False` if model parallelism or DeepSpeed is used, or if the default
          `TrainingArguments.place_model_on_device` is overridden to return `False`.
        - **is_in_train** -- Whether the model is currently running `train` (e.g. when `evaluate` is called while
          `train` is running).

    """

    # The methods below are example Trainer methods; they live in trainer_pt_utils and are attached here.
    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):
        """
        初始化一个 `Trainer` 对象,用于模型训练和评估。

        Args:
            model (Union[PreTrainedModel, nn.Module], optional): 要训练的模型。
            args (TrainingArguments, optional): 训练和评估的参数设置。
            data_collator (Optional[DataCollator], optional): 用于批处理数据的数据收集器。
            train_dataset (Optional[Dataset], optional): 训练数据集。
            eval_dataset (Optional[Union[Dataset, Dict[str, Dataset]]], optional): 评估数据集。
            tokenizer (Optional[PreTrainedTokenizerBase], optional): 用于处理输入数据的分词器。
            model_init (Optional[Callable[[], PreTrainedModel]], optional): 初始化模型的函数。
            compute_metrics (Optional[Callable[[EvalPrediction], Dict]], optional): 用于计算评估指标的函数。
            callbacks (Optional[List[TrainerCallback]], optional): 训练过程中使用的回调函数列表。
            optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], optional): 优化器和学习率调度器的元组。
            preprocess_logits_for_metrics (Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]], optional): 对预测结果进行预处理的函数。
        """

    def _activate_neftune(self, model):
        r"""
        激活 NEFTune 方法,参考代码和论文:
        https://github.com/neelsjain/NEFTune
        https://arxiv.org/abs/2310.05914
        """
        unwrapped_model = unwrap_model(model)

        if _is_peft_model(unwrapped_model):
            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
        else:
            embeddings = unwrapped_model.get_input_embeddings()

        del unwrapped_model

        embeddings.neftune_noise_alpha = self.neftune_noise_alpha
        hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
        self.neftune_hook_handle = hook_handle
        return model

    def _deactivate_neftune(self, model):
        """
        停用 NEFTune 方法。确保先调用 `_activate_neftune`。
        """
        if not hasattr(self, "neftune_hook_handle"):
            raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")

        unwrapped_model = unwrap_model(model)

        if _is_peft_model(unwrapped_model):
            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
        else:
            embeddings = unwrapped_model.get_input_embeddings()

        self.neftune_hook_handle.remove()
        del embeddings.neftune_noise_alpha, unwrapped_model
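
    # Illustrative aside (not part of trainer.py): a hedged re-implementation of the
    # core idea behind `neftune_post_forward_hook` (per the NEFTune paper): during
    # training, add uniform noise with magnitude alpha / sqrt(seq_len * hidden_dim)
    # to the embedding output. The hook and method names below are hypothetical.
    @staticmethod
    def _neftune_hook_sketch():
        def noisy_embedding_hook(module, inputs, output):
            if module.training:
                dims = output.size(1) * output.size(2)  # seq_len * hidden_dim
                mag = module.neftune_noise_alpha / dims**0.5
                output = output + torch.zeros_like(output).uniform_(-mag, mag)
            return output

        emb = nn.Embedding(100, 16)
        emb.neftune_noise_alpha = 5.0
        handle = emb.register_forward_hook(noisy_embedding_hook)
        emb(torch.randint(0, 100, (2, 8)))  # noisy while emb.training is True
        handle.remove()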

    def add_callback(self, callback):
        """
        向当前的 [`~transformers.TrainerCallback`] 列表中添加一个回调函数。

        Args:
           callback (type or [`~transformers.TrainerCallback`]):
               [`~transformers.TrainerCallback`] 类或其实例。如果是类,则实例化一个该类的成员。
        """
    def pop_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The callback removed, if found.
        """
        # Delegate to the callback handler to pop and return the given callback from the current list
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Remove a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
           callback (`type` or [`~transformers.TrainerCallback`]):
               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
               first case, will remove the first member of that class found in the list of callbacks.
        """
        # Delegate to the callback handler to remove the given callback from the current list
        self.callback_handler.remove_callback(callback)

    def _move_model_to_device(self, model, device):
        # Move the model to the given device
        model = model.to(device)
        # Moving a model to an XLA device disconnects the tied weights, so retie them
        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
            model.tie_weights()

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            # Inspect the model's forward signature to keep only the arguments it accepts
            model_to_inspect = self.model
            # If the model is a PEFT model, inspect the underlying base model instead
            if _is_peft_model(self.model):
                if hasattr(self.model, "get_base_model"):
                    model_to_inspect = self.model.get_base_model()
                else:
                    # PeftMixedModel does not provide a `get_base_model` method, so access base_model.model directly
                    model_to_inspect = self.model.base_model.model
            # Get the parameter names from the forward signature and store them as the signature columns
            signature = inspect.signature(model_to_inspect.forward)
            self._signature_columns = list(signature.parameters.keys())
            # Labels may be named label or label_ids; the default data collator handles those
            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
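
    # Illustrative aside (not part of trainer.py): deriving "signature columns" for a
    # toy model via inspect.signature, the same mechanism used just above. The names
    # below are hypothetical.
    @staticmethod
    def _signature_columns_sketch():
        class ToyModel(nn.Module):
            def forward(self, input_ids, attention_mask=None, labels=None):
                return input_ids

        cols = list(inspect.signature(ToyModel.forward).parameters.keys())
        cols.remove("self")
        return cols  # ['input_ids', 'attention_mask', 'labels']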
    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        # If unused columns should not be removed, return the dataset unchanged
        if not self.args.remove_unused_columns:
            return dataset
        # Set the signature columns if needed
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        # Find the columns in the dataset that are not among the signature columns
        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
        dset_description = "" if description is None else f"in the {description} set"
        # Log which columns are being ignored
        logger.info(
            f"The following columns {dset_description} don't have a corresponding argument in "
            f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
            " you can safely ignore this message."
        )

        # Keep only the signature columns that exist in the dataset
        columns = [k for k in signature_columns if k in dataset.column_names]

        # Handle the dataset differently depending on the datasets library version
        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            # Set the dataset format, keeping only the selected columns
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            # Remove the ignored columns from the dataset
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """Wrap the data collator in a callable removing unused columns."""
        # If unused columns should not be removed, return the original data collator
        if not self.args.remove_unused_columns:
            return data_collator
        # Set the signature columns if needed
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns

        # Build a collator that drops the unused columns
        remove_columns_collator = RemoveColumnsCollator(
            data_collator=data_collator,
            signature_columns=signature_columns,
            logger=logger,
            description=description,
            model_name=self.model.__class__.__name__,
        )
        return remove_columns_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        # Build the sampler.
        if self.args.group_by_length:
            # If the dataset is a datasets.Dataset, try to read precomputed lengths from it
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            # The model input name, usually the first input listed by the tokenizer
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            # Return a LengthGroupedSampler that groups samples of similar length together
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )

        else:
            # If not grouping by length, return a plain RandomSampler
            return RandomSampler(self.train_dataset)
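
    # Illustrative aside (not part of trainer.py): a simplified sketch of what
    # length-grouped sampling does -- shuffle once, then sort indices by length
    # inside "megabatches" so each batch holds sequences of similar length and
    # padding is minimized. This is an assumed simplification of
    # LengthGroupedSampler, not its exact implementation.
    @staticmethod
    def _length_grouping_sketch(lengths, batch_size, mega_factor=50):
        indices = list(range(len(lengths)))
        random.shuffle(indices)
        mega = batch_size * mega_factor
        grouped = []
        for i in range(0, len(indices), mega):
            chunk = indices[i : i + mega]
            grouped.extend(sorted(chunk, key=lambda j: lengths[j], reverse=True))
        return grouped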

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        # If the datasets library is available and train_dataset is a datasets.Dataset, remove the unused columns
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            # Otherwise, wrap the collator so that it removes the unused columns
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        # Dataloader parameters
        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        # If train_dataset is not an IterableDataset, configure the sampler and related parameters
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # Prepare the dataloader with the accelerator and return it
        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
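
    # Illustrative usage aside (hypothetical subclass, as the docstring above
    # suggests): custom behavior can be injected by overriding get_train_dataloader.
    @staticmethod
    def _dataloader_override_sketch():
        class LoggingTrainer(Trainer):
            def get_train_dataloader(self) -> DataLoader:
                dataloader = super().get_train_dataloader()
                logger.info(f"Training on approximately {self.num_examples(dataloader)} examples")
                return dataloader

        return LoggingTrainer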

    # Returns the sampler used for the evaluation dataset; depending on the environment this is a
    # distributed sampler, a sequential sampler, or None
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        # If the legacy prediction loop is used
        if self.args.use_legacy_prediction_loop:
            # On Torch XLA, return a sequential distributed sampler
            if is_torch_xla_available():
                return SequentialDistributedSampler(
                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
                )
            # With SageMaker model parallelism, return a sequential distributed sampler as well
            elif is_sagemaker_mp_enabled():
                return SequentialDistributedSampler(
                    eval_dataset,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                    batch_size=self.args.per_device_eval_batch_size,
                )
            else:
                # Otherwise return a plain sequential sampler
                return SequentialSampler(eval_dataset)

        # With at most one process, return a sequential sampler
        if self.args.world_size <= 1:
            return SequentialSampler(eval_dataset)
        else:
            # Otherwise return None, meaning no specific sampler is used
            return None
    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation `torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a `datasets.Dataset`, columns not accepted
                by the `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        # If neither eval_dataset nor self.eval_dataset is set, raise a ValueError
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If an _eval_dataloader already exists and dataloader_persistent_workers is set, prepare and return it
        if hasattr(self, "_eval_dataloader") and self.args.dataloader_persistent_workers:
            return self.accelerator.prepare(self._eval_dataloader)

        # Fall back to self.eval_dataset when eval_dataset is not provided
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.data_collator

        # If the datasets library is available and eval_dataset is a datasets.Dataset, remove the columns not accepted by model.forward()
        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            # Otherwise, wrap the collator so that it removes the unused columns
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        # Dataloader parameters
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        # If eval_dataset is not an IterableDataset, configure the sampler, drop_last, and prefetch_factor
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # Build the evaluation DataLoader
        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)

        # With persistent workers, cache the dataloader on the instance
        if self.args.dataloader_persistent_workers:
            self._eval_dataloader = eval_dataloader

        # Prepare the dataloader with the accelerator and return it
        return self.accelerator.prepare(eval_dataloader)

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test [`~torch.utils.data.DataLoader`].

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`, *optional*):
                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. It must implement `__len__`.
        """
        # Get the data collator
        data_collator = self.data_collator

        # If the datasets library is available and test_dataset is a datasets.Dataset, remove the columns the model does not accept
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            # Otherwise, wrap the collator so that it removes the unused columns
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        # Dataloader parameters
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        # If test_dataset is not an IterableDataset, configure the sampler and related parameters
        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # Prepare the DataLoader with the accelerator and return it
        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
        `create_scheduler`) in a subclass.
        """
        # Create the optimizer
        self.create_optimizer()

        # Pick the optimizer to pass to the scheduler
        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
            # On SageMaker MP >= 1.10 with fp16 enabled, unwrap the optimizer
            optimizer = self.optimizer.optimizer
        else:
            optimizer = self.optimizer

        # Create the learning rate scheduler
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

    def get_decay_parameter_names(self, model) -> List[str]:
        """
        Get all parameter names that weight decay will be applied to

        Note that some models implement their own layernorm instead of calling nn.LayerNorm, weight decay could still
        apply to those modules since this function only filter out instance of nn.LayerNorm
        """
        # 获取所有需要应用权重衰减的参数名称
        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
        
        # 过滤掉包含"bias"的参数名
        decay_parameters = [name for name in decay_parameters if "bias" not in name]
        
        return decay_parameters
    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        # Pick the model whose parameters will be optimized
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        # If no optimizer was passed in, build one with grouped parameters
        if self.optimizer is None:
            # Get the names of the parameters that weight decay applies to
            decay_parameters = self.get_decay_parameter_names(opt_model)
            # Group the parameters for the optimizer
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            # Get the optimizer class and its keyword arguments
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # If optimizer_kwargs contains a 'params' key, use it to override the grouped parameters set above
            # (e.g. for the GaLore optimizer)
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # If optimizer_kwargs contains an 'optimizer_dict' key, use it to override the grouped parameters,
            # avoiding parameter conflicts with layer-wise dummy optimizers
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            # Instantiate the optimizer
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            # Special handling when the optimizer is bitsandbytes' Adam8bit
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                # Get the global optimization manager instance
                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                # Register all embedding modules so that their weights are optimized in 32-bit precision
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        # With SageMaker model parallelism enabled, wrap the optimizer in a distributed optimizer
        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        # Return the configured optimizer
        return self.optimizer
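
    # Illustrative aside (not part of trainer.py): the two-group weight-decay setup
    # built above, reproduced on a plain torch model with a common simplification --
    # treat every 1-D parameter (biases, norm weights) as no-decay.
    @staticmethod
    def _weight_decay_groups_sketch():
        toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
        decay, no_decay = [], []
        for _name, p in toy.named_parameters():
            if p.requires_grad:
                (no_decay if p.ndim < 2 else decay).append(p)
        return torch.optim.AdamW(
            [{"params": decay, "weight_decay": 0.01}, {"params": no_decay, "weight_decay": 0.0}],
            lr=1e-3,
        )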

    @staticmethod
    def get_optimizer_cls_and_kwargs(
        args: TrainingArguments, model: Optional[PreTrainedModel] = None
    ):
        """
        Helper function to retrieve the optimizer class and its keyword arguments based on the training arguments and model.
        """
        # Body omitted in this walkthrough
        pass
    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Setup the scheduler. The trainer's optimizer must have been set up either before this method is called or
        passed as an argument.

        Args:
            num_training_steps (int): The number of training steps to do.
        """
        # If no learning rate scheduler has been set up yet, build one from the arguments
        if self.lr_scheduler is None:
            self.lr_scheduler = get_scheduler(
                self.args.lr_scheduler_type,
                optimizer=self.optimizer if optimizer is None else optimizer,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
                num_training_steps=num_training_steps,
                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
            )
            # Remember that the scheduler was created here
            self._created_lr_scheduler = True
        # Return the current learning rate scheduler
        return self.lr_scheduler
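
    # Illustrative usage aside: get_scheduler (imported at the top of this file from
    # .optimization) builds a schedule by name; here a linear warmup/decay schedule
    # with 100 warmup steps over 1000 training steps, on a toy optimizer.
    @staticmethod
    def _scheduler_usage_sketch():
        toy_optimizer = torch.optim.AdamW(nn.Linear(4, 4).parameters(), lr=1e-3)
        return get_scheduler(
            "linear",
            optimizer=toy_optimizer,
            num_warmup_steps=100,
            num_training_steps=1000,
        )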

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get the number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
        dataloader.dataset does not exist or has no length, estimates as best it can.
        """
        try:
            dataset = dataloader.dataset
            # An IterableDatasetShard wraps the real dataset, so go one level deeper for its length
            if isinstance(dataset, IterableDatasetShard):
                return len(dataloader.dataset.dataset)
            # Return the length of the dataloader's dataset
            return len(dataloader.dataset)
        except (NameError, AttributeError, TypeError):  # no dataset or no length, estimate via the dataloader length
            # Estimate the number of samples as the dataloader length times the per-device train batch size
            return len(dataloader) * self.args.per_device_train_batch_size

    def num_tokens(self, train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
        """
        Helper to get the number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating the dataloader.
        """
        train_tokens = 0
        try:
            # Enumerate the steps and batches of the training dataloader
            for step, batch in enumerate(train_dl):
                tokens = batch["input_ids"].numel()  # number of tokens in "input_ids" of the current batch
                if max_steps is not None:
                    return tokens * max_steps  # with a step limit, extrapolate from the first batch
                train_tokens += tokens  # accumulate the tokens of the current batch
            return train_tokens  # total number of tokens in the dataloader
        except KeyError:
            logger.warning("Cannot get num_tokens from dataloader")  # "input_ids" missing from the batch
            return train_tokens  # return the tokens accumulated so far
    # Hyperparameter search setup code
    def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
        """HP search setup code"""
        # Store the trial object on the instance
        self._trial = trial

        # If no hyperparameter search backend is set or there is no trial, return immediately
        if self.hp_search_backend is None or trial is None:
            return

        # Build the parameter dict according to the selected hyperparameter search backend
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            # Generate the parameter dict from the Optuna search space
            params = self.hp_space(trial)
        elif self.hp_search_backend == HPSearchBackend.RAY:
            # With Ray, the trial itself holds the parameters; drop the WandB entry
            params = trial
            params.pop("wandb", None)
        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
            # With SigOpt, convert string-valued assignments to integers
            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
        elif self.hp_search_backend == HPSearchBackend.WANDB:
            # With WandB, the trial itself holds the parameters
            params = trial

        # Apply the parameter dict to the training arguments
        for key, value in params.items():
            # Warn if the parameter has no corresponding field in TrainingArguments
            if not hasattr(self.args, key):
                logger.warning(
                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
                    " `TrainingArguments`."
                )
                continue
            old_attr = getattr(self.args, key, None)
            # Cast the value to the type of the old attribute
            if old_attr is not None:
                value = type(old_attr)(value)

            # Update the value on self.args
            setattr(self.args, key, value)

        # Log the trial parameters for each backend
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            logger.info(f"Trial: {trial.params}")
        if self.hp_search_backend == HPSearchBackend.SIGOPT:
            logger.info(f"SigOpt Assignments: {trial.assignments}")
        if self.hp_search_backend == HPSearchBackend.WANDB:
            logger.info(f"W&B Sweep parameters: {trial}")

        # If DeepSpeed is enabled, args.deepspeed must be set for sweeps
        if self.is_deepspeed_enabled:
            if self.args.deepspeed is None:
                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")

            # Rebuild the DeepSpeed config to reflect the updated training arguments
            from accelerate.utils import DeepSpeedPlugin
            from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

            # Create an HfTrainerDeepSpeedConfig instance
            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
            # Process the trainer configuration
            self.args.hf_deepspeed_config.trainer_config_process(self.args)
            # Create the DeepSpeed plugin
            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)

        # Create the accelerator and run post-processing
        self.create_accelerator_and_postprocess()

    # Report trial results to the hyperparameter search backend (e.g. Optuna or Ray)
    def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
        # If no hyperparameter search backend is set or there is no trial, return immediately
        if self.hp_search_backend is None or trial is None:
            return

        # Copy the metrics so the original dict is not modified
        metrics = metrics.copy()

        # Compute the current objective value
        self.objective = self.compute_objective(metrics)

        # When Optuna is the hyperparameter search backend
        if self.hp_search_backend == HPSearchBackend.OPTUNA:
            import optuna

            # Only report when the study is not multi-objective
            if not trial.study._is_multi_objective():
                # Report the current objective value and step to the Optuna trial
                trial.report(self.objective, step)

                # If the trial should be pruned, end training and raise TrialPruned
                if trial.should_prune():
                    self.callback_handler.on_train_end(self.args, self.state, self.control)
                    raise optuna.TrialPruned()

        # When Ray is the hyperparameter search backend
        elif self.hp_search_backend == HPSearchBackend.RAY:
            import ray.train

            # Use a temporary directory to store the checkpoint
            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                checkpoint = None
                # If a model checkpoint should be saved
                if self.control.should_save:
                    # Save the current checkpoint to the temporary directory
                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)

                # Add the objective value to the metrics
                metrics["objective"] = self.objective

                # Report the metrics and checkpoint to Ray
                ray.train.report(metrics, checkpoint=checkpoint)

    # Save a checkpoint of the current training state
    def _tune_save_checkpoint(self, checkpoint_dir: str):
        # The output directory is the checkpoint directory for the current global step
        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")

        # Save the model to the output directory
        self.save_model(output_dir, _internal_call=True)

        # If the state should be saved
        if self.args.should_save:
            # Save the current trainer state as a JSON file
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

            # Save the optimizer state dict
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

            # Save the learning rate scheduler state dict
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

    # Call the model init function and return the initialized model
    def call_model_init(self, trial=None):
        # Get the number of arguments of the model_init function
        model_init_argcount = number_of_arguments(self.model_init)

        # Call model_init according to its number of arguments
        if model_init_argcount == 0:
            model = self.model_init()
        elif model_init_argcount == 1:
            model = self.model_init(trial)
        else:
            raise RuntimeError("model_init should have 0 or 1 argument.")

        # If model_init returned None, raise an error
        if model is None:
            raise RuntimeError("model_init should not return None.")

        # Return the initialized model
        return model
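
    # Illustrative aside (not part of trainer.py): the arity dispatch performed by
    # call_model_init above, using inspect to count the callable's parameters (the
    # role played by number_of_arguments). Names below are hypothetical.
    @staticmethod
    def _model_init_dispatch_sketch(model_init, trial=None):
        argcount = len(inspect.signature(model_init).parameters)
        return model_init() if argcount == 0 else model_init(trial)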
    # Evaluate the model in PyTorch JIT mode
    def torch_jit_model_eval(self, model, dataloader, training=False):
        # Only applies outside of training
        if not training:
            # If the dataloader is None, log a warning and return the original model
            if dataloader is None:
                logger.warning("failed to use PyTorch jit mode due to current dataloader is none.")
                return model
            # Fetch an example batch
            example_batch = next(iter(dataloader))
            # Prepare the inputs
            example_batch = self._prepare_inputs(example_batch)
            try:
                # Copy the model for JIT compilation
                jit_model = copy.copy(model)
                # Put the model in evaluation mode
                jit_model.eval()
                # Extract the original forward method of the model
                original_forward = jit_model.__dict__.pop("_original_forward", None)
                # Restore the original forward if it exists
                if original_forward:
                    jit_model.forward = original_forward
                # Use the accelerator's autocast with caching disabled, and no gradients
                with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
                    # Pick the JIT tracing call depending on the PyTorch version
                    if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
                        # If the example batch is a dict, trace with keyword inputs
                        if isinstance(example_batch, dict):
                            jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
                        # Otherwise, build suitable keyword inputs
                        else:
                            jit_model = torch.jit.trace(
                                jit_model,
                                example_kwarg_inputs={key: example_batch[key] for key in example_batch},
                                strict=False,
                            )
                    else:
                        # Build suitable positional inputs
                        jit_inputs = []
                        for key in example_batch:
                            example_tensor = torch.ones_like(example_batch[key])
                            jit_inputs.append(example_tensor)
                        jit_inputs = tuple(jit_inputs)
                        jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
                # Freeze the traced model to improve performance
                jit_model = torch.jit.freeze(jit_model)
                # Run the example batch through the traced model twice as warmup
                with torch.no_grad():
                    jit_model(**example_batch)
                    jit_model(**example_batch)
                # Use the traced model from here on and disable CPU AMP
                model = jit_model
                self.use_cpu_amp = False
            # Catch possible exceptions and log a warning
            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
                logger.warning(f"failed to use PyTorch jit mode due to: {e}.")

        # Return the final model
        return model
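
    # Illustrative aside (not part of trainer.py): tracing a dict-input model with
    # example_kwarg_inputs (available from torch 2.0), mirroring the eval path above.
    @staticmethod
    def _jit_trace_sketch():
        class TinyModel(nn.Module):
            def forward(self, input_ids, attention_mask):
                return (input_ids * attention_mask).sum(dim=-1)

        tiny = TinyModel().eval()
        example = {
            "input_ids": torch.ones(2, 8, dtype=torch.long),
            "attention_mask": torch.ones(2, 8, dtype=torch.long),
        }
        with torch.no_grad():
            traced = torch.jit.trace(tiny, example_kwarg_inputs=example, strict=False)
            traced = torch.jit.freeze(traced)
            traced(**example)  # warmup call, as done twice above
        return traced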
    # Optimize the model with IPEX; raise an ImportError if IPEX is unavailable
    def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
        if not is_ipex_available():
            raise ImportError(
                "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
                " to https://github.com/intel/intel-extension-for-pytorch."
            )

        import intel_extension_for_pytorch as ipex

        if not training:
            # Outside of training, put the model in evaluation mode
            model.eval()
            # Use torch.bfloat16 when not in training and bf16 full eval is enabled
            dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
            # Optimize the model at level O1 with the chosen dtype; disable conv_bn_folding to avoid symbolic tracing issues
            model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
        else:
            if not model.training:
                # Put the model in training mode if it is not already
                model.train()
            # Optimize the model and optimizer in place at level O1 with the chosen dtype
            model, self.optimizer = ipex.optimize(
                model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
            )

        return model

    # Training entry point; supports resuming from a checkpoint and hyperparameter search trials
    # (the body is omitted in this walkthrough)
    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        ...

    # Inner training loop; takes the batch size, arguments, resume flag, and trial as inputs
    # (the body is omitted in this walkthrough)
    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        ...

    # Determine the output directory from the hyperparameter search backend and trial
    def _get_output_dir(self, trial):
        if self.hp_search_backend is not None and trial is not None:
            if self.hp_search_backend == HPSearchBackend.OPTUNA:
                run_id = trial.number
            elif self.hp_search_backend == HPSearchBackend.RAY:
                import ray.train

                run_id = ray.train.get_context().get_trial_id()
            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
                run_id = trial.id
            elif self.hp_search_backend == HPSearchBackend.WANDB:
                import wandb

                run_id = wandb.run.id
            # Use the run name function if one was set, otherwise the default name format
            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
            # The run directory combines the configured output directory and the run name
            run_dir = os.path.join(self.args.output_dir, run_name)
        else:
            # Without a hyperparameter search backend or trial, use the default output directory
            run_dir = self.args.output_dir
        return run_dir

    # Issue warnings after loading a model, based on the missing and unexpected keys in the load result
    def _issue_warnings_after_load(self, load_result):
        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
                # If the missing keys match the keys ignored on save, retie the weights
                self.model.tie_weights()
            else:
                # Otherwise warn about the missing keys
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            # Warn about any unexpected keys in the load result
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )
    # The fragment below belongs to the logging/evaluation/saving step of the training loop
    # (its enclosing method definition is omitted in this walkthrough).
        # Check whether logging is due and the global step has advanced past the last logged step
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            # With Torch XLA, mark the current step
            if is_torch_xla_available():
                xm.mark_step()

            # Initialize the logs dict
            logs: Dict[str, float] = {}

            # Average the loss over all processes with all_gather + mean()
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

            # Reset the training loss to zero
            tr_loss -= tr_loss

            # Compute and record the average loss
            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)

            # Record the gradient norm if there is one
            if grad_norm is not None:
                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm

            # Record the current learning rate
            logs["learning_rate"] = self._get_learning_rate()

            # Accumulate the total loss
            self._total_loss_scalar += tr_loss_scalar

            # Update the last logged global step
            self._globalstep_last_logged = self.state.global_step

            # Store the FLOPs (floating point operations)
            self.store_flos()

            # Emit the logs
            self.log(logs)

        metrics = None
        # If evaluation is due
        if self.control.should_evaluate:
            # Run the evaluation and collect the metrics
            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            # Report the current global step and metrics to the hyperparameter search
            self._report_to_hp_search(trial, self.state.global_step, metrics)

            # When using a ReduceLROnPlateau learning rate scheduler
            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Get the metric name used to select the best model
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                # Step the scheduler with the latest metric value
                self.lr_scheduler.step(metrics[metric_to_check])

        # If saving is due
        if self.control.should_save:
            # Save a model checkpoint
            self._save_checkpoint(model, trial, metrics=metrics)
            # Run the on_save callbacks
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
    # Load the RNG states from checkpoint `checkpoint`
    def _load_rng_state(self, checkpoint):
        # If there is no checkpoint, return immediately
        if checkpoint is None:
            return

        # With multiple processes (distributed training), read the RNG state file for this process index
        if self.args.world_size > 1:
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            # If the per-process RNG file does not exist, log a message and return
            if not os.path.isfile(rng_file):
                logger.info(
                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            # In a single process, read the default RNG state file
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            # If the default RNG file does not exist, log a message and return
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        # Load the RNG states from the checkpoint
        checkpoint_rng_state = torch.load(rng_file)
        # Restore the state of Python's built-in random number generator
        random.setstate(checkpoint_rng_state["python"])
        # Restore the NumPy RNG state
        np.random.set_state(checkpoint_rng_state["numpy"])
        # Restore the PyTorch CPU RNG state
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])

        # If CUDA is available
        if torch.cuda.is_available():
            # In distributed training with parallel mode
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # Restore the CUDA RNG states of all GPUs
                torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
            else:
                try:
                    # Restore the CUDA RNG state of the current GPU
                    torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )

        # With Torch XLA
        if is_torch_xla_available():
            # Restore the Torch XLA RNG state
            xm.set_rng_state(checkpoint_rng_state["xla"])

        # With Torch NPU
        if is_torch_npu_available():
            # In distributed training with parallel mode
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # Restore the RNG states of all NPUs
                torch.npu.random.set_rng_state_all(checkpoint_rng_state["npu"])
            else:
                try:
                    # Restore the RNG state of the current NPU
                    torch.npu.random.set_rng_state(checkpoint_rng_state["npu"])
                except Exception as e:
                    logger.info(
                        f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}"
                        "\nThis won't yield the same results as if the training had not been interrupted."
                    )
    # Save a checkpoint: the model plus its associated states and arguments
    def _save_checkpoint(self, model, trial, metrics=None):
        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
        # want to save, except for FullyShardedDDP.
        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

        # Name of the checkpoint folder, including the global step
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

        # If no hyperparameter search is running and there is no trial, store the FLOPs
        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        # Get the run output directory and build the checkpoint output directory from the trial
        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        # Save the model, flagging this as an internal call
        self.save_model(output_dir, _internal_call=True)

        # Unless only the model should be saved, also save the optimizer, scheduler, and RNG states
        if not self.args.save_only_model:
            # Save the optimizer and scheduler states
            self._save_optimizer_and_scheduler(output_dir)
            # Save the RNG states
            self._save_rng_state(output_dir)

        # Determine the new best metric and best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith("eval_"):
                metric_to_check = f"eval_{metric_to_check}"
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            # If the current metric value is better, update the best metric and best model checkpoint path
            if (
                self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)
            ):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # If the Trainer state should be saved, write it to a JSON file
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # If pushing to the Hub was requested, push from the current checkpoint path
        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Possibly delete some older checkpoints
        if self.args.should_save:
            # Rely solely on the numerical checkpoint id for rotation;
            # mtime is not reliable in some cloud environments, especially some fuse filesystems
            self._rotate_checkpoints(use_mtime=False, output_dir=run_dir)
    # Save the RNG states (the body of `_save_rng_state`, called from `_save_checkpoint` above)
    def _save_rng_state(self, output_dir):
        # Collect the RNG states
        rng_states = {
            "python": random.getstate(),  # state of Python's built-in RNG
            "numpy": np.random.get_state(),  # state of the NumPy RNG
            "cpu": torch.random.get_rng_state(),  # state of the PyTorch CPU RNG
        }
        # If CUDA is available
        if torch.cuda.is_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # In distributed mode, save the CUDA RNG states of all devices (this also covers DataParallel)
                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
            else:
                # Otherwise save the RNG state of the current CUDA device
                rng_states["cuda"] = torch.cuda.random.get_rng_state()

        # If PyTorch XLA is available
        if is_torch_xla_available():
            # Save the RNG state of the current XLA device
            rng_states["xla"] = xm.get_rng_state()

        # If PyTorch NPU is available
        if is_torch_npu_available():
            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
                # Save the RNG states of all NPU devices
                rng_states["npu"] = torch.npu.random.get_rng_state_all()
            else:
                # Save the RNG state of the current NPU device
                rng_states["npu"] = torch.npu.random.get_rng_state()

        # Make sure the output directory exists before saving
        os.makedirs(output_dir, exist_ok=True)

        # The file name depends on the number of processes
        if self.args.world_size <= 1:
            # With a single process, save as rng_state.pth
            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
        else:
            # With multiple processes, save as rng_state_{process index}.pth
            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
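
    # Illustrative aside (not part of trainer.py): the RNG save/restore round-trip
    # implemented by `_save_rng_state` and `_load_rng_state`, reduced to CPU-only state.
    @staticmethod
    def _rng_roundtrip_sketch():
        state = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
        }
        a = torch.rand(3)
        random.setstate(state["python"])
        np.random.set_state(state["numpy"])
        torch.random.set_rng_state(state["cpu"])
        b = torch.rand(3)
        assert torch.equal(a, b)  # the stream resumes identically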
    # Save optimizer and scheduler state to the given output directory
    def _save_optimizer_and_scheduler(self, output_dir):
        # If Torch XLA acceleration is available
        if is_torch_xla_available():
            # Synchronize via xm, then save the optimizer state dict to the given path
            xm.rendezvous("saving_optimizer_states")
            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            # Save the LR scheduler state dict via xm, capturing and re-issuing any warnings
            with warnings.catch_warnings(record=True) as caught_warnings:
                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                reissue_pt_warnings(caught_warnings)
        # If SageMaker model parallelism is enabled
        elif is_sagemaker_mp_enabled():
            # Fetch the local optimizer state dict, then hit a barrier
            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
            smp.barrier()
            # Save the optimizer state if this is the first process or the config shards the optimizer state
            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
                smp.save(
                    opt_state_dict,
                    os.path.join(output_dir, OPTIMIZER_NAME),
                    partial=True,
                    v3=smp.state.cfg.shard_optimizer_state,
                )
        # If DeepSpeed is enabled
        elif self.is_deepspeed_enabled:
            # Check whether save_checkpoint accepts `exclude_frozen_parameters` and save accordingly
            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
            )
            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
            else:
                self.model_wrapped.save_checkpoint(output_dir)
        # If FSDP (Fully Sharded Data Parallelism) is enabled
        elif self.is_fsdp_enabled:
            # Save the FSDP-specific model checkpoint and optimizer state to the given path
            save_fsdp_model(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
            )
            save_fsdp_optimizer(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
            )
        # Otherwise, if this process should save
        elif self.args.should_save:
            # Save just the optimizer state dict
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

        # Save the LR scheduler state dict if the saving conditions are met
        is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
            self.lr_scheduler, DeepSpeedSchedulerWrapper
        )
        if (
            self.args.should_save
            and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
            and not is_torch_xla_available()
        ):
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            # Re-issue any warnings caught while saving
            reissue_pt_warnings(caught_warnings)
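# The inverse of the plain-PyTorch branch above, as a hedged sketch: reloading the optimizer
# and scheduler state dicts. OPTIMIZER_NAME and SCHEDULER_NAME resolve to "optimizer.pt" and
# "scheduler.pt" in this file; the helper name and arguments below are illustrative.
import os

import torch


def load_optimizer_and_scheduler(checkpoint_dir, optimizer, lr_scheduler):
    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))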
    def hyperparameter_search(
        self,
        hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
        compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
        n_trials: int = 20,
        direction: Union[str, List[str]] = "minimize",
        backend: Optional[Union[str, HPSearchBackend]] = None,
        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
        **kwargs,
    ):
        """
        Launch a hyperparameter search using one of the supported backends
        (optuna, Ray Tune, SigOpt or Weights & Biases).

        Args:
            hp_space (Optional[Callable[["optuna.Trial"], Dict[str, float]]]):
                Function defining the hyperparameter search space.
            compute_objective (Optional[Callable[[Dict[str, float]], float]]):
                Function to compute the objective given a set of hyperparameters.
            n_trials (int):
                Number of trials (hyperparameter combinations) to evaluate.
            direction (Union[str, List[str]]):
                Direction to optimize the objective, either 'minimize' or 'maximize'.
            backend (Optional[Union[str, HPSearchBackend]]):
                Backend for hyperparameter search.
            hp_name (Optional[Callable[["optuna.Trial"], str]]):
                Function to generate a name for the hyperparameter set.
            **kwargs:
                Additional keyword arguments passed to the hyperparameter search backend.

        Returns:
            The best run (or runs, for multi-objective search) found by the backend.
        """
        pass

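# A usage sketch for the optuna backend of the method above. The search space, objective
# and trial count are illustrative; `trainer` is assumed to be a Trainer constructed with a
# `model_init` so that each trial gets a fresh model.
def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [8, 16, 32]
        ),
    }


def compute_objective(metrics):
    # Minimize the evaluation loss
    return metrics["eval_loss"]


# best_run = trainer.hyperparameter_search(
#     hp_space=optuna_hp_space,
#     compute_objective=compute_objective,
#     n_trials=10,
#     direction="minimize",
#     backend="optuna",
# )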
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

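# The docstring above suggests subclassing to inject custom behavior; a minimal sketch
# (hypothetical `MyTrainer`, illustrative extra field) that tags every logged dict before
# delegating to the parent implementation:
from transformers import Trainer


class MyTrainer(Trainer):
    def log(self, logs):
        logs["run_name"] = "my-experiment"  # custom field added to every log entry
        super().log(logs)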
    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.

        Args:
            data (Union[torch.Tensor, Any]):
                The input data to prepare.

        Returns:
            Union[torch.Tensor, Any]:
                Prepared data ready to be fed into the model.
        """
        if isinstance(data, Mapping):
            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            kwargs = {"device": self.args.device}
            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
                # Adjust dtype for deepspeed enabled models
                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
            return data.to(**kwargs)
        return data
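# To make the recursion above concrete: a standalone re-implementation (not the Trainer
# method itself) applied to a nested batch. Tensors are moved to the target device, while
# non-tensor values pass through untouched. The device string is an assumption.
from collections.abc import Mapping

import torch


def prepare(data, device):
    if isinstance(data, Mapping):
        return type(data)({k: prepare(v, device) for k, v in data.items()})
    if isinstance(data, (tuple, list)):
        return type(data)(prepare(v, device) for v in data)
    if isinstance(data, torch.Tensor):
        return data.to(device=device)
    return data


batch = {"input_ids": torch.tensor([[1, 2]]), "labels": [torch.tensor([0])], "meta": "keep"}
prepared = prepare(batch, "cpu")  # with "cuda:0", the tensors would land on the GPU instead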
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        # Delegate to _prepare_input, which converts values to tensors on the right device and handles state
        inputs = self._prepare_input(inputs)

        # Raise a ValueError on an empty batch, since the model cannot train on it
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )

        # If a past index is configured and past state exists, add it to the inputs
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        # Return the prepared input dictionary
        return inputs

    def compute_loss_context_manager(self):
        """
        A helper wrapper to group together context managers.
        """
        # Delegate to the smart autocast context manager helper
        return self.autocast_smart_context_manager()

    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
        """
        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
        arguments, depending on the situation.
        """
        # Create the appropriate context manager depending on whether CPU AMP is enabled
        if self.use_cpu_amp:
            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
        else:
            ctx_manager = contextlib.nullcontext()

        return ctx_manager

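# What the CPU-AMP branch above amounts to in isolation: a bfloat16 autocast region around
# the forward pass. A minimal standalone sketch, not Trainer code.
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
with torch.cpu.amp.autocast(cache_enabled=True, dtype=torch.bfloat16):
    out = model(x)  # matmuls inside the region run in bfloat16
print(out.dtype)  # torch.bfloat16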
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        # Put the model in training mode
        model.train()

        # Prepare the inputs
        inputs = self._prepare_inputs(inputs)

        # With SageMaker model parallelism enabled, use the dedicated forward-backward function
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        # Compute the loss under the appropriate context manager
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        # Average the loss across GPUs for multi-GPU parallel training
        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        # With Apex mixed precision, apply gradient scaling through Amp
        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.accelerator.backward(loss)

        # Return the detached loss, divided by the number of gradient accumulation steps
        return loss.detach() / self.args.gradient_accumulation_steps
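# In the method above, `accelerator.backward` takes care of scaling for gradient
# accumulation, and the returned detached loss is divided so that the logged value reflects
# the per-update average. The classic manual pattern that this automates looks roughly like
# the toy sketch below (plain PyTorch; model, data and step count are illustrative).
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

optimizer.zero_grad()
for _ in range(accumulation_steps):
    inputs = torch.randn(2, 4)
    loss = model(inputs).pow(2).mean() / accumulation_steps  # scale each micro-batch loss
    loss.backward()  # gradients accumulate in the .grad buffers
optimizer.step()  # one optimizer update for the whole accumulated batch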
    # Compute the model's loss.
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        # If a label smoother is set and the inputs contain labels, pop the labels from the inputs
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        # Run the model on the inputs to get the outputs
        outputs = model(**inputs)
        # If an index for saving past state is configured
        if self.args.past_index >= 0:
            # Store the output at that index in the _past attribute
            self._past = outputs[self.args.past_index]

        # If labels are present, apply label smoothing
        if labels is not None:
            # Unwrap the model (remove all wrappers)
            unwrapped_model = unwrap_model(model)
            # Check whether the model is a PEFT model
            if _is_peft_model(unwrapped_model):
                # Get the name of the underlying base model
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # If the model name is in the causal LM mapping
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                # Compute the loss with the label smoother, shifting the labels
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                # Compute the loss with the label smoother
                loss = self.label_smoother(outputs, labels)
        else:
            # If the model output is a dict without a "loss" key
            if isinstance(outputs, dict) and "loss" not in outputs:
                # Raise a ValueError: the model did not return a loss from the inputs
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # Take the loss from the outputs; index [0] is used because the model may
            # return a plain tuple instead of a ModelOutput, so we cannot rely on .loss
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # Return both the loss and the outputs if requested, otherwise just the loss
        return (loss, outputs) if return_outputs else loss

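# The docstring above invites subclassing; a common override is a custom weighted loss for
# imbalanced classification. A hedged sketch: the subclass name and the class weights are
# illustrative, and the model is assumed to return logits.
import torch
from transformers import Trainer


class WeightedLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Up-weight the rare class (weights are an assumption for a 2-class problem)
        loss_fct = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 3.0], device=logits.device))
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss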
    # Whether this process is the local main process (e.g., the main process on one machine
    # when training in a distributed fashion on several machines)
    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
        machines) main process.
        """
        return self.args.local_process_index == 0

    # Whether this process is the global main process (in distributed training, only one process returns True)
    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be `True` for one process).
        """
        # Special case for SageMaker ModelParallel, where the process index is dp_process_index,
        # not the global process index
        if is_sagemaker_mp_enabled():
            return smp.rank() == 0
        else:
            # Check whether the process index is 0
            return self.args.process_index == 0
    # Save the model; an output directory and an internal-call flag can be given
    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.
        """

        # Fall back to the default output directory if none is given
        if output_dir is None:
            output_dir = self.args.output_dir

        # If Torch XLA is available, save the model via the TPU path
        if is_torch_xla_available():
            self._save_tpu(output_dir)

        # Under SageMaker model parallelism, create the output directory and save the model state dict
        elif is_sagemaker_mp_enabled():
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            # Save if this process should save
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            # With SageMaker MP >= 1.10, touch a flag file to signal that the state dict was saved
            if IS_SAGEMAKER_MP_POST_1_10:
                Path(os.path.join(output_dir, "user_content.pt")).touch()

        # If FSDP (Fully Sharded Data Parallelism) is enabled, save the model state dict
        elif self.is_fsdp_enabled:
            # Requires a full state dict and an accelerate version above 0.24.1
            if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
                version.parse(accelerate_version) > version.parse("0.24.1")
            ):
                state_dict = self.accelerator.get_state_dict(self.model)
                # Save if this process should save
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)

        # If DeepSpeed is enabled, try to fetch the DeepSpeed state dict and save it
        elif self.is_deepspeed_enabled:
            try:
                state_dict = self.accelerator.get_state_dict(self.deepspeed)
                # Save if this process should save
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
                # On failure, warn and save an empty state dict as a placeholder
                logger.warning(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                    " zero_to_fp32.py to recover weights"
                )
                if self.args.should_save:
                    self._save(output_dir, state_dict={})
                # Remove the dummy state dict files
                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                # Save a full checkpoint through the wrapped model
                self.model_wrapped.save_checkpoint(output_dir)

        # Otherwise, save if this process should save
        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when the user calls `save_model` directly and `push_to_hub` is set
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")
    # Save the model on TPU to the given directory (or the default output directory)
    def _save_tpu(self, output_dir: Optional[str] = None):
        # Fall back to self.args.output_dir if no output directory is given
        output_dir = output_dir if output_dir is not None else self.args.output_dir

        # Log that the model checkpoint is being saved to the output directory
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Get the model object
        model = self.model

        # Mark the current step on the TPU
        xm.mark_step()

        # Move the model to CPU for saving
        model.to("cpu")

        # On the master ordinal, create the output directory if needed and save the training arguments
        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Save the trained model and its configuration with `save_pretrained()`
        # so it can be reloaded with `from_pretrained()`
        supported_classes = (PushToHubMixin,)
        xm.rendezvous("saving_checkpoint")

        # If the model is not one of the supported classes, try unwrapping it before saving
        if not isinstance(model, supported_classes):
            if isinstance(unwrap_model(model), supported_classes):
                unwrap_model(model).save_pretrained(
                    output_dir,
                    is_main_process=self.args.should_save,
                    state_dict=model.state_dict(),
                    save_function=xm.save,
                    safe_serialization=self.args.save_safetensors,
                )
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                state_dict = model.state_dict()
                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            # The model is a supported class, so call its save_pretrained method directly
            model.save_pretrained(
                output_dir,
                is_main_process=self.args.should_save,
                save_function=xm.save,
                safe_serialization=self.args.save_safetensors,
            )

        # Save the tokenizer to the output directory if one exists and this process should save
        if self.tokenizer is not None and self.args.should_save:
            self.tokenizer.save_pretrained(output_dir)

        # Move the model back from CPU to the target device (usually GPU or TPU) so that
        # subsequent computation can proceed normally
        model.to(self.args.device)
    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # Use the given output directory if provided, otherwise fall back to self.args.output_dir
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        # Create the output directory; do not fail if it already exists
        os.makedirs(output_dir, exist_ok=True)
        # Log which directory the model checkpoint is being saved to
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Supported model classes: only PreTrainedModel if PEFT is unavailable, otherwise also PeftModel
        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
        # Save the trained model and its configuration with `save_pretrained()`
        # so it can be reloaded with `from_pretrained()`
        if not isinstance(self.model, supported_classes):
            # If no state dict was given, take the current model's state dict
            if state_dict is None:
                state_dict = self.model.state_dict()

            # If the unwrapped model is a supported class, call its save_pretrained method
            if isinstance(unwrap_model(self.model), supported_classes):
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
                )
            else:
                # The model is not a `PreTrainedModel`, so only its state dict is saved
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                # If saving safetensors was requested, write the file with the safetensors library
                if self.args.save_safetensors:
                    safetensors.torch.save_file(
                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
                    )
                else:
                    # Otherwise save the state dict with PyTorch's own torch.save
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            # The model is a supported class, so call its save_pretrained method directly
            self.model.save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )

        # If a tokenizer exists, save it as well
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save the training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

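# A small round-trip sketch for the safetensors branch above: save a plain state dict with
# safetensors.torch.save_file, then load it back. The file name is illustrative.
import torch
import safetensors.torch

state_dict = {"weight": torch.randn(2, 2), "bias": torch.zeros(2)}
safetensors.torch.save_file(state_dict, "model.safetensors", metadata={"format": "pt"})
reloaded = safetensors.torch.load_file("model.safetensors")
assert torch.equal(state_dict["bias"], reloaded["bias"])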
    def store_flos(self):
        # Store the number of floating-point operations that went into the model
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            # In DISTRIBUTED mode, broadcast the current FLOS across devices and sum them up
            self.state.total_flos += (
                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
            )
            # Reset the current FLOS counter
            self.current_flos = 0
        else:
            # Otherwise add the current FLOS to the total directly, then reset the counter
            self.state.total_flos += self.current_flos
            self.current_flos = 0

    # Return the checkpoint paths sorted by modification time or by the step number in the folder name
    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        ordering_and_checkpoint_path = []

        # Collect all checkpoint folders matching the given prefix
        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]

        # Iterate over each checkpoint folder path
        for path in glob_checkpoints:
            if use_mtime:
                # When sorting by modification time, record the timestamp along with the path
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                # Otherwise try to extract the step number from the folder name and record it with the path
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        # Sort the checkpoint paths by timestamp or step number
        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        # Keep only the sorted paths, dropping the sort keys
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]

        # Make sure the best model is never deleted by moving it toward the front of the list
        if (
            self.state.best_model_checkpoint is not None
            and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
        ):
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]

        # Return the sorted list of checkpoint paths
        return checkpoints_sorted

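# Why the method above sorts on the extracted step number rather than the raw path: a plain
# lexicographic sort would put "checkpoint-100" before "checkpoint-99". A standalone
# illustration with made-up paths:
import re

paths = ["run/checkpoint-99", "run/checkpoint-100", "run/checkpoint-5"]
ordered = sorted(paths, key=lambda p: int(re.match(r".*checkpoint-([0-9]+)", p).groups()[0]))
print(ordered)  # ['run/checkpoint-5', 'run/checkpoint-99', 'run/checkpoint-100']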
    # Rotate checkpoints according to the save_total_limit, deleting the excess ones
    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Get the sorted list of checkpoint paths
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        # Nothing to delete if the number of checkpoints is within the limit
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # Bump the limit to 2 when save_total_limit is 1 but the most recent
        # checkpoint is not the best one, so the best model is kept as well
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        # Compute how many checkpoints need to be deleted and which ones
        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]

        # Delete the checkpoint folders that are over the limit
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint, ignore_errors=True)
        """
        # 运行预测并返回预测结果和可能的指标。

        # 根据数据集和使用情况,测试数据集可能包含标签。在这种情况下,该方法还会返回指标,例如在 `evaluate()` 中一样。

        Args:
            test_dataset (`Dataset`):
                要运行预测的数据集。如果是 `datasets.Dataset`,则会自动删除模型 `forward()` 方法不接受的列。必须实现 `__len__` 方法。
            ignore_keys (`List[str]`, *可选*):
                在模型输出中应忽略的键列表(如果是字典)。
            metric_key_prefix (`str`, *可选*, 默认为 `"test"`):
                用作指标键前缀的可选前缀。例如,如果前缀是 "test"(默认),则指标 "bleu" 将命名为 "test_bleu"。

        <Tip>

        如果您的预测或标签具有不同的序列长度(例如,因为您在标记分类任务中进行动态填充),则会对预测进行填充(在右侧),以允许串联到一个数组中。填充索引为 -100。

        </Tip>

        Returns: *NamedTuple* 具有以下键的命名元组:

            - predictions (`np.ndarray`): 对 `test_dataset` 的预测。
            - label_ids (`np.ndarray`, *可选*): 标签(如果数据集包含)。
            - metrics (`Dict[str, float]`, *可选*): 可能包含标签的字典。

        """
        # 内存指标 - 必须尽早设置
        self._memory_tracker.start()

        # 获取测试数据加载器
        test_dataloader = self.get_test_dataloader(test_dataset)
        
        # 记录开始时间
        start_time = time.time()

        # 选择使用旧版预测循环还是评估循环
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        
        # 执行预测/评估循环
        output = eval_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )
        
        # 计算总批量大小
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        
        # 如果存在编译时间指标,则调整开始时间
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        
        # 更新速度指标
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        # 调用预测的回调处理器
        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
        
        # 停止内存跟踪器并更新指标
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        # 返回预测输出
        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ):
        """
        Runs evaluation loop over `dataloader`.

        Args:
            dataloader (DataLoader): The data loader for evaluation.
            description (str): Description of the evaluation loop.
            prediction_loss_only (Optional[bool], optional): Whether to compute only prediction loss. Defaults to None.
            ignore_keys (Optional[List[str]], optional): List of keys to ignore during evaluation. Defaults to None.
            metric_key_prefix (str, optional): Prefix for metric keys. Defaults to "eval".
        """
        # Implementation of evaluation loop code...
        

    def _nested_gather(self, tensors, name=None):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`.
        """
        if tensors is None:
            return
        # Check if running on TPU (Tensor Processing Unit)
        if is_torch_xla_available():
            if name is None:
                name = "nested_gather"
            tensors = nested_xla_mesh_reduce(tensors, name)
        # Check if running on SageMaker's multi-processing
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        # Check if running in distributed setting
        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
            self.args.distributed_state is None and self.args.local_rank != -1
        ):
            tensors = distributed_concat(tensors)
        return tensors

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ):
        """
        Perform a prediction step using `model` on `inputs`.

        Args:
            model (nn.Module): The model for prediction.
            inputs (Dict[str, Union[torch.Tensor, Any]]): Dictionary of inputs for the model.
            prediction_loss_only (bool): Whether to compute only prediction loss.
            ignore_keys (Optional[List[str]], optional): List of keys to ignore during prediction.

        Returns:
            Depends on the model's prediction step implementation.
        """
        # Implementation of prediction step code...
        

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Computes the number of floating point operations for the model.

        Args:
            inputs (Dict[str, Union[torch.Tensor, Any]]): The inputs and targets of the model.

        Returns:
            int: The number of floating-point operations.
        """
        if hasattr(self.model, "floating_point_ops"):
            return self.model.floating_point_ops(inputs)
        else:
            return 0

    def init_hf_repo(self):
        """
        Initializes a git repository in `self.args.hub_model_id`.
        """
        # Only proceed if current process is the zeroth process
        if not self.is_world_process_zero():
            return

        # Determine repository name
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id

        # Create or access the repository
        repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
        self.hub_model_id = repo_url.repo_id
        self.push_in_progress = None
    # Create a draft model card for the model
    def create_model_card(
        self,
        # Optional: the language of the model
        language: Optional[str] = None,
        # Optional: the license of the model
        license: Optional[str] = None,
        # Optional: one tag or a list of tags describing the model
        tags: Union[str, List[str], None] = None,
        # Optional: the name of the model
        model_name: Optional[str] = None,
        # Optional: the model this one was fine-tuned from
        finetuned_from: Optional[str] = None,
        # Optional: one task identifier or a list of task identifiers the model is suited for
        tasks: Union[str, List[str], None] = None,
        # Optional: one dataset tag or a list of dataset tags for the datasets used
        dataset_tags: Union[str, List[str], None] = None,
        # Optional: one dataset identifier or a list of dataset identifiers
        dataset: Union[str, List[str], None] = None,
        # Optional: one dataset argument or a list of dataset arguments
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to the `Trainer` comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to the `Trainer` (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        # Only the main process creates the model card; return immediately otherwise
        if not self.is_world_process_zero():
            return

        # Path where the model card file is saved
        model_card_filepath = os.path.join(self.args.output_dir, "README.md")

        # Check whether a model card file already exists
        is_peft_library = False
        if os.path.exists(model_card_filepath):
            # If it exists, load the model card data and read its library name
            library_name = ModelCard.load(model_card_filepath).data.get("library_name")
            # Check whether the existing model card comes from the PEFT library
            is_peft_library = library_name == "peft"

            # If tags were passed and the existing card already has tags, merge the existing tags in
            existing_tags = ModelCard.load(model_card_filepath).data.tags
            if tags is not None and existing_tags is not None:
                if isinstance(tags, str):
                    tags = [tags]
                for tag in existing_tags:
                    if tag not in tags:
                        tags.append(tag)

        # Build a training summary from the information available to the Trainer
        training_summary = TrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        # Convert the training summary into a model card
        model_card = training_summary.to_model_card()
        # Write the model card to the README.md file at the given path
        with open(model_card_filepath, "w") as f:
            f.write(model_card)

        # For the PEFT library, call the dedicated function to create or update the model card
        if is_peft_library:
            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
    def _push_from_checkpoint(self, checkpoint_folder):
        # Only push from the main process, and not when the hub strategy is END
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # Skip this push if the previous one is not finished yet, unless args.hub_always_push=True
        if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
            return

        output_dir = self.args.output_dir
        # To avoid re-syncing all model weights, copy the relevant modeling files from the
        # checkpoint folder into the output directory
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        # Add the adapter-related files if PEFT is available
        if is_peft_available():
            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
        # Iterate over the modeling files to copy
        for modeling_file in modeling_files:
            # Copy a file to the output directory if it exists in the checkpoint folder
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # If a tokenizer exists, save its current state to the output directory
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Likewise, save the training arguments to the output directory
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Build the commit message according to the save strategy
        if self.args.save_strategy == IntervalStrategy.STEPS:
            commit_message = f"Training in progress, step {self.state.global_step}"
        else:
            commit_message = f"Training in progress, epoch {int(self.state.epoch)}"

        # Upload the whole output directory to the model repo as a new commit
        model_push_job = upload_folder(
            repo_id=self.hub_model_id,
            folder_path=output_dir,
            commit_message=commit_message,
            token=self.args.hub_token,
            run_as_future=True,
            ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
        )

        push_jobs = [model_push_job]

        # With the CHECKPOINT or ALL_CHECKPOINTS hub strategy, also upload the checkpoint folder
        if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
            # Name in the repo: "last-checkpoint" for the CHECKPOINT strategy, otherwise the checkpoint folder name
            path_in_repo = (
                "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
            )
            # Create an upload job that commits the contents of the checkpoint folder
            checkpoint_push = upload_folder(
                repo_id=self.hub_model_id,
                folder_path=checkpoint_folder,
                path_in_repo=path_in_repo,
                commit_message=commit_message + ", checkpoint",
                token=self.args.hub_token,
                run_as_future=True,
            )
            push_jobs.append(checkpoint_push)

        # Create a new push tracker if there is no push in progress or the previous one is done
        if self.push_in_progress is None or self.push_in_progress.is_done():
            self.push_in_progress = PushInProgress(push_jobs)
        else:
            # Otherwise, add the new jobs to the list of jobs already in progress
            self.push_in_progress.jobs.extend(push_jobs)
    def _finish_current_push(self):
        # Return immediately if the "push_in_progress" attribute does not exist
        if not hasattr(self, "push_in_progress"):
            return

        # Check whether a push is currently in progress and not yet finished
        if self.push_in_progress is not None and not self.push_in_progress.is_done():
            # Log that we are waiting for the current checkpoint push, which may take a few minutes
            logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")

            # Wait until the push is done
            self.push_in_progress.wait_until_done()
    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message to commit while pushing.
            blocking (`bool`, *optional*, defaults to `True`):
                Whether the function should return only when the `git push` has finished.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to [`~Trainer.create_model_card`].

        Returns:
            The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking
            the progress of the commit if `blocking=False`.
        """
        model_name = kwargs.pop("model_name", None)
        # If no model name was given and this process should save the model
        if model_name is None and self.args.should_save:
            # Without a hub_model_id, use the name of the output directory as the model name
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                # Otherwise use the last component of hub_model_id as the model name
                model_name = self.args.hub_model_id.split("/")[-1]

        # Initialize hub_model_id if it has not been initialized yet
        if self.hub_model_id is None:
            self.init_hf_repo()

        # Needs to be executed on all processes for TPU training, but will only save on the
        # processes determined by self.args.should_save
        self.save_model(_internal_call=True)

        # Only push from one node
        if not self.is_world_process_zero():
            return

        # If the model already carries tags and the user passed "tags", add the internal tags on top
        if getattr(self.model, "model_tags", None) is not None:
            if "tags" not in kwargs:
                kwargs["tags"] = []

            # If tags is a string, convert it to a list
            if isinstance(kwargs["tags"], str):
                kwargs["tags"] = [kwargs["tags"]]

            # Add each of the model's tags to kwargs["tags"]
            for model_tag in self.model.model_tags:
                if model_tag not in kwargs["tags"]:
                    kwargs["tags"].append(model_tag)

        # Create the model card
        self.create_model_card(model_name=model_name, **kwargs)

        # Wait for the current upload to be finished
        self._finish_current_push()
        # Upload the output folder and return the result
        return upload_folder(
            repo_id=self.hub_model_id,
            folder_path=self.args.output_dir,
            commit_message=commit_message,
            token=self.args.hub_token,
            run_as_future=not blocking,
            ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
        )

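# Typical call sites for the method above, as a hedged usage sketch (the trainer object and
# its Hub repo configuration are assumed to exist):
#
#   trainer.push_to_hub(commit_message="End of training")  # blocks until the upload finishes
#   future = trainer.push_to_hub(commit_message="wip", blocking=False)  # returns a Future-like job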
    #
    # Deprecated code
    #
    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ):
        """
        Perform a prediction loop over the given data loader.

        Args:
            dataloader (DataLoader): The data loader containing the data to predict on.
            description (str): Description of the prediction loop.
            prediction_loss_only (Optional[bool], optional): Whether to calculate only prediction loss.
            ignore_keys (Optional[List[str]], optional): Keys to ignore during prediction.
            metric_key_prefix (str, optional): Prefix for metric keys.

        Returns:
            None
        """

    def _gather_and_numpify(self, tensors, name):
        """
        Gather values of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy arrays.

        Args:
            tensors: Tensor or list/tuple of nested tensors to gather and convert.
            name: Name associated with the gathering operation.

        Returns:
            numpy.ndarray or list/tuple/nested structure of numpy arrays corresponding to `tensors`.
        """
        if tensors is None:
            return

        # If using Torch XLA, perform mesh reduction
        if is_torch_xla_available():
            tensors = nested_xla_mesh_reduce(tensors, name)
        # If using SageMaker multi-processing, gather tensors
        elif is_sagemaker_mp_enabled():
            tensors = smp_gather(tensors)
        # If in distributed training mode, concatenate tensors
        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            tensors = distributed_concat(tensors)

        # Convert gathered tensors to numpy arrays
        return nested_numpify(tensors)

    def _add_sm_patterns_to_gitignore(self) -> None:
        """
        Add SageMaker Checkpointing patterns to .gitignore file if running on the main process.

        Returns:
            None
        """
        # Ensure only the main process performs this operation
        if not self.is_world_process_zero():
            return

        # Patterns to be added to .gitignore
        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]

        # Read current .gitignore content if it exists
        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
                current_content = f.read()
        else:
            current_content = ""

        # Prepare content to write, appending patterns if not already present
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if content.endswith("\n"):
                    content += pattern
                else:
                    content += f"\n{pattern}"

        # Write to .gitignore if there were changes
        if content != current_content:
            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
                logger.debug(f"Writing .gitignore file. Content: {content}")
                f.write(content)

        # Add .gitignore to the git index
        self.repo.git_add(".gitignore")

        # Ensure there's no race condition with git status
        time.sleep(0.5)

        # Commit changes if repository is not clean
        if not self.repo.is_repo_clean():
            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
            self.repo.git_push()
    # Create the accelerator object and do the required post-processing
    def create_accelerator_and_postprocess(self):
        # Build the keyword arguments for the gradient accumulation plugin
        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
        # Do not sync with the dataloader
        grad_acc_kwargs["sync_with_dataloader"] = False
        # Create the gradient accumulation plugin
        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

        # Create the accelerator object
        self.accelerator = Accelerator(
            deepspeed_plugin=self.args.deepspeed_plugin,  # the DeepSpeed plugin, if any
            gradient_accumulation_plugin=gradient_accumulation_plugin,  # the gradient accumulation plugin
            **self.args.accelerator_config.to_dict(),  # parameters from the accelerator config
        )
        # Some Trainer classes need `gather` instead of `gather_for_metrics`, so a reference is stored
        self.gather_function = self.accelerator.gather_for_metrics

        # Check whether the DeepSpeed plugin is enabled
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
        # Check whether the FSDP plugin is enabled
        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

        # Post-accelerator-creation setup
        if self.is_fsdp_enabled:
            fsdp_plugin = self.accelerator.state.fsdp_plugin
            # Configure the FSDP plugin's limit_all_gathers setting
            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
                "limit_all_gathers", fsdp_plugin.limit_all_gathers
            )
            # If the accelerate version supports it, configure activation checkpointing
            if is_accelerate_available("0.23.0"):
                fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
                    "activation_checkpointing", fsdp_plugin.activation_checkpointing
                )
                # Raise an error if both activation checkpointing and gradient checkpointing are set
                if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
                    raise ValueError(
                        "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                        "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                        "when using FSDP."
                    )

        # If DeepSpeed is enabled and no hf_deepspeed_config was provided, propagate the args to DeepSpeed
        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
            self.propagate_args_to_deepspeed()

        # Raise an error if `save_only_model` is combined with DeepSpeed or FSDP plus `load_best_model_at_end`
        if (
            self.args.save_only_model
            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
            and self.args.load_best_model_at_end
        ):
            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
            raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")

        # Raise a NotImplementedError if `auto_find_batch_size` is used together with DeepSpeed or FSDP
        if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size:
            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
            raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.")
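# A standalone sketch of the Accelerator construction above, outside of Trainer: a
# GradientAccumulationPlugin configured for 4 steps with dataloader sync disabled.
# The step count is illustrative.
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

plugin = GradientAccumulationPlugin(num_steps=4, sync_with_dataloader=False)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)
print(accelerator.gradient_accumulation_steps)  # 4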
    # Propagate the Trainer arguments to the DeepSpeed plugin
    def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
        """
        Sets values in the deepspeed plugin based on the Trainer args.
        """
        # Import the DeepSpeed config class
        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig

        # Get the DeepSpeed plugin from the current accelerator state
        ds_plugin = self.accelerator.state.deepspeed_plugin

        # Build an HfTrainerDeepSpeedConfig object from the plugin's current config
        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
        # Point the plugin's deepspeed_config at the new HfTrainerDeepSpeedConfig's config
        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
        # Further process the DeepSpeed config based on the Trainer args
        ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)

    # Update the FSDP plugin with QLoRA-related settings
    def _fsdp_qlora_plugin_updates(self):
        """
        Updates the FSDP plugin with QLoRA related settings if applicable.
        """
        # Check that FSDP is enabled and the model is a PEFT model
        if self.is_fsdp_enabled and _is_peft_model(self.model):
            # Import the PEFT config and the FSDP auto-wrap policy helper
            from peft import LoraConfig
            from peft.utils.other import fsdp_auto_wrap_policy

            # If the model's active_peft_config is a LoraConfig
            if isinstance(self.model.active_peft_config, LoraConfig):
                # Get the FSDP plugin from the accelerator state
                fsdp_plugin = self.accelerator.state.fsdp_plugin
                # Set the FSDP plugin's auto-wrap policy
                fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)

            # If the model is quantized with BITS_AND_BYTES, its 4-bit quant storage dtype is
            # floating point, and the accelerate version is above "0.27.0"
            if (
                getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
                and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
                and version.parse(accelerate_version) > version.parse("0.27.0")
            ):
                # Set mixed precision on the FSDP plugin accordingly
                fsdp_plugin.set_mixed_precision(
                    self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                )