Transformers Source Code Analysis (Part 2)

.\commands\convert.py

# Import the required modules and classes
from argparse import ArgumentParser, Namespace  # command-line argument parsing
from ..utils import logging  # logging utilities
from . import BaseTransformersCLICommand  # base class for CLI commands

# Factory function that creates a ConvertCommand instance
def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a TF 1.0 model checkpoint into a PyTorch checkpoint.

    Returns: ConvertCommand
    """
    return ConvertCommand(
        args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
    )

# Error message shown when the conversion is attempted without TensorFlow installed
IMPORT_ERROR_MESSAGE = """
transformers can only be used from the command line to convert TensorFlow models to PyTorch. In that case, it requires
TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.
"""

# The ConvertCommand class, inheriting from BaseTransformersCLICommand
class ConvertCommand(BaseTransformersCLICommand):

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        # 添加转换命令到参数解析器
        train_parser = parser.add_parser(
            "convert",
            help="CLI tool to run convert model from original author checkpoints to Transformers PyTorch checkpoints.",
        )
        # 添加转换命令的参数
        train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
        train_parser.add_argument(
            "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
        )
        train_parser.add_argument(
            "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output."
        )
        train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
        train_parser.add_argument(
            "--finetuning_task_name",
            type=str,
            default=None,
            help="Optional fine-tuning task name if the TF model was a finetuned model.",
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(
        self,
        model_type: str,
        tf_checkpoint: str,
        pytorch_dump_output: str,
        config: str,
        finetuning_task_name: str,
        *args,
    ):
        # Get a logger named "transformers-cli/converting"
        self._logger = logging.get_logger("transformers-cli/converting")

        # Log which model is being loaded
        self._logger.info(f"Loading model {model_type}")

        # Store the model type
        self._model_type = model_type

        # Store the TensorFlow checkpoint path
        self._tf_checkpoint = tf_checkpoint

        # Store the output path for the converted PyTorch model
        self._pytorch_dump_output = pytorch_dump_output

        # Store the model configuration
        self._config = config

        # Store the fine-tuning task name
        self._finetuning_task_name = finetuning_task_name
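
A minimal usage sketch (the model type and paths below are hypothetical placeholders, not from the source): argparse normally builds the Namespace, and the factory turns it into a runnable command.

# Hypothetical example of how argparse wires into the factory
from argparse import Namespace

args = Namespace(
    model_type="bert",                             # hypothetical model type
    tf_checkpoint="/tmp/bert_model.ckpt",          # hypothetical TF 1.0 checkpoint
    pytorch_dump_output="/tmp/pytorch_model.bin",  # hypothetical output path
    config="/tmp/bert_config.json",                # hypothetical config file
    finetuning_task_name=None,
)
command = convert_command_factory(args)  # a ConvertCommand; `transformers-cli` then calls .run()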

.\commands\download.py

# Import the ArgumentParser class from the argparse module
from argparse import ArgumentParser

# Import the BaseTransformersCLICommand class from this package's __init__.py
from . import BaseTransformersCLICommand

# Factory function that creates and returns a DownloadCommand instance
def download_command_factory(args):
    return DownloadCommand(args.model, args.cache_dir, args.force, args.trust_remote_code)

# The DownloadCommand class, inheriting from BaseTransformersCLICommand
class DownloadCommand(BaseTransformersCLICommand):

    # Static method that registers the command-line arguments
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Add a subcommand parser named "download"
        download_parser = parser.add_parser("download")

        # --cache-dir: where to store the downloaded models
        download_parser.add_argument(
            "--cache-dir", type=str, default=None, help="Path to location to store the models"
        )

        # --force: re-download the model even if it is already in cache-dir
        download_parser.add_argument(
            "--force", action="store_true", help="Force the model to be downloaded even if already in cache-dir"
        )

        # --trust-remote-code: whether to trust custom modeling code from the Hub
        download_parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine",
        )

        # Positional argument: the name of the model to download
        download_parser.add_argument("model", type=str, help="Name of the model to download")

        # Set download_command_factory as the default handler
        download_parser.set_defaults(func=download_command_factory)

    # Initializer taking the model name, cache path, force flag, and trust-remote-code flag
    def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool):
        self._model = model  # model name
        self._cache = cache  # cache path
        self._force = force  # whether to force the download
        self._trust_remote_code = trust_remote_code  # whether to trust remote code

    # Run the download
    def run(self):
        # Import AutoModel and AutoTokenizer from ..models.auto
        from ..models.auto import AutoModel, AutoTokenizer

        # Download (and cache) the pretrained model weights
        AutoModel.from_pretrained(
            self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
        )

        # Download (and cache) the matching tokenizer
        AutoTokenizer.from_pretrained(
            self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
        )
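
A quick sketch of driving this command programmatically (the model name and cache path are illustrative placeholders): it performs the same two downloads the CLI does.

# Hypothetical example: download a model and its tokenizer into a local cache
cmd = DownloadCommand(model="bert-base-uncased", cache="/tmp/hf-cache", force=False, trust_remote_code=False)
cmd.run()  # caches the model weights and tokenizer under /tmp/hf-cache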

.\commands\env.py

# Import the required modules
import importlib.util
import os
import platform
from argparse import ArgumentParser

# Import the Hugging Face Hub library
import huggingface_hub

# Import the package version
from .. import __version__ as version

# Import some availability helpers
from ..utils import (
    is_accelerate_available,
    is_flax_available,
    is_safetensors_available,
    is_tf_available,
    is_torch_available,
)

# Import the base command class
from . import BaseTransformersCLICommand

# Factory function that creates the environment command object
def info_command_factory(_):
    return EnvironmentCommand()

# Factory function that creates the command object when an accelerate config file is passed
def download_command_factory(args):
    return EnvironmentCommand(args.accelerate_config_file)

# Environment command class, inheriting from the base Transformers CLI command class
class EnvironmentCommand(BaseTransformersCLICommand):

    # Static method: register the subcommand
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Add a subcommand parser named "env"
        download_parser = parser.add_parser("env")
        # Set the default handler to info_command_factory
        download_parser.set_defaults(func=info_command_factory)
        # --accelerate-config_file: path to the accelerate config file
        download_parser.add_argument(
            "--accelerate-config_file",
            default=None,
            help="The accelerate config file to use for the default values in the launching script.",
        )
        # Set the default handler again, overriding the previous one with download_command_factory
        download_parser.set_defaults(func=download_command_factory)

    # Initializer taking the accelerate config file as an argument
    def __init__(self, accelerate_config_file, *args) -> None:
        self._accelerate_config_file = accelerate_config_file

    # Static method: format a dict as a string, one "- key: value" line per entry
    @staticmethod
    def format_dict(d):
        return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"

.\commands\lfs.py

"""
Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs.

Inspired by: github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py

Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md


To launch debugger while developing:

``` [lfs "customtransfer.multipart"]
path = /path/to/transformers/.env/bin/python args = -m debugpy --listen 5678 --wait-for-client
/path/to/transformers/src/transformers/commands/transformers_cli.py lfs-multipart-upload ```"""

import json  # JSON handling
import os  # operating-system helpers
import subprocess  # running external commands
import sys  # interpreter interaction (stdin/stdout)
import warnings  # deprecation warnings
from argparse import ArgumentParser  # command-line argument parsing
from contextlib import AbstractContextManager  # base class for context managers
from typing import Dict, List, Optional  # type hints

import requests  # HTTP requests

from ..utils import logging  # package logging module
from . import BaseTransformersCLICommand  # base class for CLI commands

logger = logging.get_logger(__name__)  # module-level logger  # pylint: disable=invalid-name

LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"  # name of the LFS multipart upload command

class LfsCommands(BaseTransformersCLICommand):
    """
    Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload
    large files >5GB 🔥. Spec for LFS custom transfer agent is:
    https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md

    This introduces two commands to the CLI:

    1. $ transformers-cli lfs-enable-largefiles

    This should be executed once for each model repo that contains a model file >5GB. It's documented in the error
    message you get if you just try to git push a 5GB file without having enabled it before.

    2. $ transformers-cli lfs-multipart-upload

    This command is called by lfs directly and is not meant to be called by the user.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        enable_parser = parser.add_parser(
            "lfs-enable-largefiles",
            help=(
                "Deprecated: use `huggingface-cli` instead. Configure your repository to enable upload of files > 5GB."
            ),
        )
        enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
        enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))  # default handler: instantiate LfsEnableCommand

        upload_parser = parser.add_parser(
            LFS_MULTIPART_UPLOAD_COMMAND,
            help=(
                "Deprecated: use `huggingface-cli` instead. "
                "Command will get called by git-lfs, do not call it directly."
            ),
        )
        upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args))  # default handler: instantiate LfsUploadCommand

class LfsEnableCommand:
    def __init__(self, args):
        self.args = args  # store the parsed arguments on the instance

    def run(self):
        # Warn that transformers-cli repo management is deprecated in favor of huggingface-cli
        warnings.warn(
            "Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
        )
        # Resolve the given path to an absolute path
        local_path = os.path.abspath(self.args.path)
        # If the path is not a directory, print an error and exit
        if not os.path.isdir(local_path):
            print("This does not look like a valid git repo.")
            exit(1)
        # Point git-lfs's custom transfer program at transformers-cli, inside the repo
        subprocess.run(
            "git config lfs.customtransfer.multipart.path transformers-cli".split(), check=True, cwd=local_path
        )
        # Set the custom transfer program's argument to LFS_MULTIPART_UPLOAD_COMMAND, inside the repo
        subprocess.run(
            f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
            check=True,
            cwd=local_path,
        )
        # Report that the local repo is now set up for large files
        print("Local repo set up for largefiles")

# Serialize a dict message as JSON and write it to stdout
def write_msg(msg: Dict):
    msg = json.dumps(msg) + "\n"  # serialize the message and append a newline
    sys.stdout.write(msg)  # write the JSON string to stdout
    sys.stdout.flush()  # flush stdout so git-lfs sees the message immediately

# Read one line of JSON from stdin and parse it
def read_msg() -> Optional[Dict]:
    msg = json.loads(sys.stdin.readline().strip())  # read and parse one JSON message

    if "terminate" in (msg.get("type"), msg.get("event")):
        # A "terminate" type or event means git-lfs is done with this process
        return None

    if msg.get("event") not in ("download", "upload"):
        logger.critical("Received unexpected message")  # anything else is a protocol violation
        sys.exit(1)  # exit on unexpected messages

    return msg  # return the parsed message dict
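
# For reference, a sketch of the message shapes these helpers exchange (values are
# illustrative; the shapes follow the custom-transfer spec linked above):
#
#     {"event": "upload", "oid": "abc123", "path": "/repo/large_file.bin",
#      "action": {"href": "https://example.com/complete",
#                 "header": {"chunk_size": "5242880", "00001": "https://example.com/part-1"}}}
#
# A terminate message such as {"event": "terminate"} makes read_msg() return None, and
# write_msg({"event": "progress", "oid": "abc123", "bytesSoFar": 5242880, "bytesSinceLast": 5242880})
# reports progress back to git-lfs.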

# Context manager that reads only a byte range of a file
class FileSlice(AbstractContextManager):
    """
    File-like object that only reads a slice of a file

    Inspired by stackoverflow.com/a/29838711/593036
    """

    def __init__(self, filepath: str, seek_from: int, read_limit: int):
        self.filepath = filepath  # path to the file
        self.seek_from = seek_from  # byte offset to start reading from
        self.read_limit = read_limit  # maximum number of bytes to read
        self.n_seen = 0  # number of bytes read so far

    def __enter__(self):
        self.f = open(self.filepath, "rb")  # open the file for binary reading
        self.f.seek(self.seek_from)  # seek to the starting offset
        return self  # return the FileSlice itself as the context manager

    def __len__(self):
        total_length = os.fstat(self.f.fileno()).st_size  # total size of the file
        return min(self.read_limit, total_length - self.seek_from)  # bytes actually available in this slice

    def read(self, n=-1):
        if self.n_seen >= self.read_limit:
            return b""  # the slice is exhausted

        remaining_amount = self.read_limit - self.n_seen  # bytes left in the slice
        # Read at most the remaining amount, or n bytes if n is given
        data = self.f.read(remaining_amount if n < 0 else min(n, remaining_amount))
        self.n_seen += len(data)  # track how many bytes have been read
        return data  # return the data read

    def __iter__(self):
        yield self.read(n=4 * 1024 * 1024)  # yield up to 4MB per iteration

    def __exit__(self, *args):
        self.f.close()  # close the file
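
# A short usage sketch (the path is a placeholder): FileSlice behaves like a bounded,
# read-only file object, which is exactly what requests.put() consumes below.
#
#     with FileSlice("/tmp/example.bin", seek_from=1024, read_limit=1024) as chunk:
#         data = chunk.read()      # at most 1024 bytes, starting at offset 1024
#         assert len(data) <= 1024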

# The LFS upload command; receives the parsed arguments on construction
class LfsUploadCommand:
    def __init__(self, args):
        self.args = args  # store the parsed arguments

    def run(self):
        # Immediately after invoking a custom transfer process, git-lfs sends initiation
        # data to the process over stdin. This tells the process useful information about
        # the configuration.
        init_msg = json.loads(sys.stdin.readline().strip())
        # If the init message is not an "init" event with operation "upload", write an error and exit.
        if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
            write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
            sys.exit(1)

        # The transfer process should use the information in the initiation structure and
        # perform any one-off setup tasks, then respond on stdout with a simple empty
        # confirmation structure.
        write_msg({})

        # After the initiation exchange, git-lfs sends any number of transfer requests
        # to the stdin of the transfer process, in a serial sequence.
        while True:
            msg = read_msg()
            if msg is None:
                # When all transfers have been processed, git-lfs sends a terminate event
                # to the stdin of the transfer process. On receiving it, the transfer
                # process should clean up and terminate. No response is expected.
                sys.exit(0)

            oid = msg["oid"]
            filepath = msg["path"]
            completion_url = msg["action"]["href"]
            header = msg["action"]["header"]
            chunk_size = int(header.pop("chunk_size"))
            presigned_urls: List[str] = list(header.values())

            parts = []
            for i, presigned_url in enumerate(presigned_urls):
                # Use FileSlice to read one chunk of the file, offset by i * chunk_size.
                with FileSlice(filepath, seek_from=i * chunk_size, read_limit=chunk_size) as data:
                    # PUT the chunk to its presigned URL.
                    r = requests.put(presigned_url, data=data)
                    r.raise_for_status()
                    # Record the ETag and part number of the uploaded chunk.
                    parts.append(
                        {
                            "etag": r.headers.get("etag"),
                            "partNumber": i + 1,
                        }
                    )
                    # In order to support progress reporting while data is uploading or
                    # downloading, the transfer process should post messages to stdout.
                    write_msg(
                        {
                            "event": "progress",
                            "oid": oid,
                            "bytesSoFar": (i + 1) * chunk_size,
                            "bytesSinceLast": chunk_size,
                        }
                    )
                    # Not precise progress reporting, but that's ok.

            # POST the oid and the uploaded parts to the completion URL.
            r = requests.post(
                completion_url,
                json={
                    "oid": oid,
                    "parts": parts,
                },
            )
            r.raise_for_status()

            # Emit the completion event on stdout.
            write_msg({"event": "complete", "oid": oid})

.\commands\pt_to_tf.py

# inspect: introspection utilities for Python objects
import inspect
# os: operating-system helpers
import os
# ArgumentParser and Namespace: command-line argument parsing
from argparse import ArgumentParser, Namespace
# import_module: dynamic module imports
from importlib import import_module

# huggingface_hub: interaction with the Hugging Face Hub
import huggingface_hub
# numpy, aliased np: numerical computation
import numpy as np
# version: version-number handling
from packaging import version

# Import the following objects from the parent package
from .. import (
    FEATURE_EXTRACTOR_MAPPING,
    IMAGE_PROCESSOR_MAPPING,
    PROCESSOR_MAPPING,
    TOKENIZER_MAPPING,
    AutoConfig,
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoProcessor,
    AutoTokenizer,
    is_datasets_available,
    is_tf_available,
    is_torch_available,
)
# Import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, and logging from the parent utils module
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
# Import BaseTransformersCLICommand from this package's __init__.py
from . import BaseTransformersCLICommand

# If TensorFlow is available, import it
if is_tf_available():
    import tensorflow as tf
    # Disable TensorFloat-32 execution in TensorFlow (for numerically reproducible comparisons)
    tf.config.experimental.enable_tensor_float_32_execution(False)

# If PyTorch is available, import torch
if is_torch_available():
    import torch

# If datasets is available, import load_dataset
if is_datasets_available():
    from datasets import load_dataset

# Maximum-error constant used as the test tolerance
MAX_ERROR = 5e-5  # larger error tolerance than in internal tests, to avoid flaky user-facing errors

# Factory function that creates the checkpoint-conversion command
def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a PyTorch model checkpoint into a TensorFlow 2 checkpoint.

    Returns: PTtoTFCommand
    """
    # Return a PTtoTFCommand that performs the model conversion
    return PTtoTFCommand(
        args.model_name,
        args.local_dir,
        args.max_error,
        args.new_weights,
        args.no_pr,
        args.push,
        args.extra_commit_description,
        args.override_model_class,
    )

# The PTtoTFCommand class, inheriting from BaseTransformersCLICommand
class PTtoTFCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        # Create a subcommand parser named "pt-to-tf"
        train_parser = parser.add_parser(
            "pt-to-tf",
            help=(
                "CLI tool to convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
                " Can also be used to validate existing weights without opening PRs, with --no-pr."
            ),
        )
        # --model-name: the model name, required
        train_parser.add_argument(
            "--model-name",
            type=str,
            required=True,
            help="The model name, including owner/organization, as seen on the hub.",
        )
        # --local-dir: optional local directory for the model repository; defaults to /tmp/{model_name}
        train_parser.add_argument(
            "--local-dir",
            type=str,
            default="",
            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
        )
        # --max-error: maximum error tolerance; defaults to MAX_ERROR
        train_parser.add_argument(
            "--max-error",
            type=float,
            default=MAX_ERROR,
            help=(
                f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
            ),
        )
        # --new-weights: create new TensorFlow weights even if they already exist
        train_parser.add_argument(
            "--new-weights",
            action="store_true",
            help="Optional flag to create new TensorFlow weights, even if they already exist.",
        )
        # --no-pr: do NOT open a PR with the converted weights
        train_parser.add_argument(
            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
        )
        # --push: push the weights directly to `main` (requires permissions)
        train_parser.add_argument(
            "--push",
            action="store_true",
            help="Optional flag to push the weights directly to `main` (requires permissions)",
        )
        # --extra-commit-description: additional commit description used when opening a PR
        train_parser.add_argument(
            "--extra-commit-description",
            type=str,
            default="",
            help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
        )
        # --override-model-class: manually override the auto-detected model class
        train_parser.add_argument(
            "--override-model-class",
            type=str,
            default=None,
            help="If you think you know better than the auto-detector, you can specify the model class here. "
            "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
        )
        # Set convert_command_factory as the default handler
        train_parser.set_defaults(func=convert_command_factory)

    @staticmethod
    def find_pt_tf_differences(pt_outputs, tf_outputs):
        """
        Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
        """
        # 1. All output attributes must be the same
        pt_out_attrs = set(pt_outputs.keys())  # attribute names of the PyTorch outputs
        tf_out_attrs = set(tf_outputs.keys())  # attribute names of the TensorFlow outputs
        if pt_out_attrs != tf_out_attrs:  # abort if the attribute names differ
            raise ValueError(
                f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:"
                f" {tf_out_attrs})"
            )

        # 2. For each output attribute, computes the difference
        def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
            # If the current attribute is a tensor, compare it directly; otherwise recurse deeper
            if isinstance(pt_out, torch.Tensor):
                tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
                differences[attr_name] = tensor_difference
            else:
                root_name = attr_name
                for i, pt_item in enumerate(pt_out):
                    # If this is a named attribute, keep the name; otherwise just use the index
                    if isinstance(pt_item, str):
                        branch_name = root_name + pt_item
                        tf_item = tf_out[pt_item]
                        pt_item = pt_out[pt_item]
                    else:
                        branch_name = root_name + f"[{i}]"
                        tf_item = tf_out[i]
                    differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)

            return differences

        return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
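
    # A tiny sketch of the comparison above (assumes torch is installed; the "TF" output is
    # mocked with a torch tensor, which also exposes .numpy()):
    #
    #     pt_out = {"logits": torch.tensor([1.0, 2.0])}
    #     tf_out = {"logits": torch.tensor([1.0, 2.00001])}  # stand-in for the TF output
    #     diffs = PTtoTFCommand.find_pt_tf_differences(pt_out, tf_out)
    #     max(diffs.values()) <= MAX_ERROR  # True: the conversion would be accepted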

    def __init__(
        self,
        model_name: str,
        local_dir: str,
        max_error: float,
        new_weights: bool,
        no_pr: bool,
        push: bool,
        extra_commit_description: str,
        override_model_class: str,
        *args,
    ):
        self._logger = logging.get_logger("transformers-cli/pt_to_tf")  # logger for this command
        self._model_name = model_name  # model name on the Hub
        self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)  # local repo directory
        self._max_error = max_error  # maximum tolerated difference
        self._new_weights = new_weights  # whether to create new TF weights even if they exist
        self._no_pr = no_pr  # whether to skip opening a PR
        self._push = push  # whether to push directly to main
        self._extra_commit_description = extra_commit_description  # extra commit description
        self._override_model_class = override_model_class  # manual model-class override

.\commands\run.py

# Import the required modules and functions
from argparse import ArgumentParser

from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

# Module-level logger
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Infer the input file format from the file extension
def try_infer_format_from_ext(path: str):
    # An empty path means "read from stdin", i.e. the "pipe" format
    if not path:
        return "pipe"

    # Walk through the supported data formats
    for ext in PipelineDataFormat.SUPPORTED_FORMATS:
        # If the path ends with this format's extension, return it
        if path.endswith(ext):
            return ext

    # If no format could be inferred, raise an exception
    raise Exception(
        f"Unable to determine file format from file extension {path}. "
        f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
    )
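
# For illustration (assuming PipelineDataFormat.SUPPORTED_FORMATS covers "json", "csv"
# and "pipe", as in the pipelines module):
#
#     try_infer_format_from_ext("data.csv")   # -> "csv"
#     try_infer_format_from_ext("data.json")  # -> "json"
#     try_infer_format_from_ext("")           # -> "pipe" (read from stdin)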

# Factory for the run command: builds the pipeline and data-format reader from the arguments
def run_command_factory(args):
    # Build the pipeline object from the arguments
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )

    # Infer the data format from the input path, or use the explicitly given format
    format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format

    # Build the data-format reader from the arguments
    reader = PipelineDataFormat.from_str(
        format=format,
        output_path=args.output,
        input_path=args.input,
        column=args.column if args.column else nlp.default_input_names,
        overwrite=args.overwrite,
    )

    # Return the run command, handing it the pipeline and the reader
    return RunCommand(nlp, reader)

# The run command class, inheriting from BaseTransformersCLICommand
class RunCommand(BaseTransformersCLICommand):
    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        self._nlp = nlp  # the pipeline object
        self._reader = reader  # the data-format reader

    def run(self):
        # The pipeline and the list that collects its outputs
        nlp, outputs = self._nlp, []

        # Iterate over every entry produced by the reader
        for entry in self._reader:
            # If the reader yields multiple columns, pass them as keyword arguments
            if self._reader.is_multi_columns:
                output = nlp(**entry)
            else:
                output = nlp(entry)

            # A dict output is appended as a single result
            if isinstance(output, dict):
                outputs.append(output)
            else:
                # Otherwise assume a list and extend the results with it
                outputs += output

        # Persist the results
        if self._nlp.binary_output:
            # If the pipeline requires binary output, save the results as a binary file
            binary_path = self._reader.save_binary(outputs)
            # Warn that this pipeline's output had to be saved in binary format
            logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
        else:
            # Otherwise save the outputs in the reader's regular format
            self._reader.save(outputs)

.\commands\serving.py

# Copyright and license header: this file is distributed under the Apache License, Version 2.0
#
# Import the required modules and functions
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import Pipeline, get_supported_tasks, pipeline
# Logging module
from ..utils import logging
# Base CLI command class
from . import BaseTransformersCLICommand

try:
    # Try to import FastAPI and the related serving dependencies
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    # Mark the serving dependencies as installed
    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # On import failure, fall back to a plain object for BaseModel and a no-op Body
    BaseModel = object

    def Body(*x, **y):
        pass

    # Mark the serving dependencies as missing
    _serve_dependencies_installed = False

# Logger named "transformers-cli/serving"
logger = logging.get_logger("transformers-cli/serving")


def serve_command_factory(args: Namespace):
    """
    Factory function instantiating the serving server from the provided command-line arguments.

    Returns: a ServeCommand instance
    """
    # Build the NLP pipeline object
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    # Return a ServeCommand, passing the pipeline, host, port, and number of workers
    return ServeCommand(nlp, args.host, args.port, args.workers)


class ServeModelInfoResult(BaseModel):
    """
    Expose the model's information
    """

    infos: dict


class ServeTokenizeResult(BaseModel):
    """
    Tokenization result
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]]


class ServeDeTokenizeResult(BaseModel):
    """
    De-tokenization result
    """

    text: str


class ServeForwardResult(BaseModel):
    """
    Forward-pass result
    """

    output: Any


class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        # Create the "serve" subcommand parser, which runs inference requests through REST and GraphQL endpoints
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        # --task: the task to run the pipeline on, chosen from the supported tasks
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        # --host: interface the server listens on, defaulting to localhost
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        # --port: port the server listens on, defaulting to 8888
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        # --workers: number of HTTP workers, defaulting to 1
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        # --model: model name or path to a stored model
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        # --config: model config name or path
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        # --tokenizer: name of the tokenizer to use
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        # --device: device to run on; -1 means CPU, >= 0 selects a GPU (default: -1)
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        # Set serve_command_factory as the default handler
        serve_parser.set_defaults(func=serve_command_factory)

    # Initializer taking the Pipeline object, host, port, and number of workers
    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        # Store the pipeline on the instance
        self._pipeline = pipeline

        # Store the host, port, and worker count
        self.host = host
        self.port = port
        self.workers = workers

        # If the serving dependencies are missing, raise a RuntimeError
        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )
        else:
            # Otherwise log where the model is served and build the FastAPI app with its routes and timeout
            logger.info(f"Serving model over {host}:{port}")
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
            )

    # Start the service by running the FastAPI app
    def run(self):
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    # Return the pipeline model's configuration as a ServeModelInfoResult
    def model_info(self):
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    # Tokenize the given input text; takes text_input and return_ids
    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
        tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
        mapping.
        """
        try:
            # Tokenize the input text with the pipeline's tokenizer
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            # If return_ids is True, also convert the tokens to their integer ids
            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                # Otherwise return only the tokens
                return ServeTokenizeResult(tokens=tokens_txt)

        # On any exception, return HTTP 500 with the error details
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
    # Detokenize: convert token ids back to readable text
    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),  # list of token ids
        skip_special_tokens: bool = Body(False, embed=True),  # whether to skip special tokens
        cleanup_tokenization_spaces: bool = Body(True, embed=True),  # whether to clean up tokenization spaces
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            # Decode the token ids with the tokenizer, honoring the two flags
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            # Return the decoded string as part of a ServeDeTokenizeResult
            return ServeDeTokenizeResult(model="", text=decoded_str)
        except Exception as e:
            # On any exception, return HTTP 500 with the error details
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    async def forward(self, inputs=Body(None, embed=True)):
        """
        **inputs**: **attention_mask**: **tokens_type_ids**:
        """

        # Return early on empty inputs
        if len(inputs) == 0:
            # Return an empty ServeForwardResult: both output and attention are empty lists
            return ServeForwardResult(output=[], attention=[])

        try:
            # Run the inputs through the pipeline
            output = self._pipeline(inputs)
            # Wrap the model output in a ServeForwardResult
            return ServeForwardResult(output=output)
        except Exception as e:
            # On any exception, return HTTP 500 with the error details
            raise HTTPException(500, {"error": str(e)})
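
A sketch of querying the running server with requests (assumes a server started with the defaults above, i.e. localhost:8888):

import requests

resp = requests.post(
    "http://localhost:8888/tokenize",
    json={"text_input": "Hello world", "return_ids": True},
)
print(resp.json())  # e.g. {"tokens": [...], "tokens_ids": [...]}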

.\commands\train.py

# os: operating-system helpers
import os
# ArgumentParser and Namespace: command-line argument parsing
from argparse import ArgumentParser, Namespace

# SingleSentenceClassificationProcessor from ..data, aliased as Processor
from ..data import SingleSentenceClassificationProcessor as Processor
# TextClassificationPipeline for text-classification tasks
from ..pipelines import TextClassificationPipeline
# Availability helpers and logging from ..utils
from ..utils import is_tf_available, is_torch_available, logging
# BaseTransformersCLICommand from this package's __init__.py
from . import BaseTransformersCLICommand

# If neither TensorFlow nor PyTorch is installed, raise a RuntimeError
if not is_tf_available() and not is_torch_available():
    raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# TF training settings
USE_XLA = False  # whether to use XLA acceleration (TensorFlow only)
USE_AMP = False  # whether to use mixed-precision training (TensorFlow only)

def train_command_factory(args: Namespace):
    """
    Factory function instantiating the train command from the given command-line arguments.

    Returns:
        TrainCommand: the instantiated train command
    """
    return TrainCommand(args)

class TrainCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        # Create the "train" subcommand, used to train a model on a task
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

        # Path to the training data
        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with tab separated labels and sentences.",
        )

        # Column of the dataset holding the labels
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )

        # Column of the dataset holding the texts
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )

        # Column of the dataset holding the example ids
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )

        # Whether to skip the first (header) row of the CSV file
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )

        # Path to the validation dataset
        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")

        # Fraction of the training set used for validation when no validation set is given
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
        )

        # Where to save the trained model
        train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")

        # The task to train on
        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )

        # Model name or path to a stored model
        train_parser.add_argument(
            "--model", type=str, default="google-bert/bert-base-uncased", help="Model's name or path to stored model."
        )

        # Training batch size
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")

        # Validation batch size
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")

        # Learning rate
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")

        # Epsilon for the Adam optimizer
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")

        # Set the default handler to the train command factory
        train_parser.set_defaults(func=train_command_factory)

    # Initializer taking a Namespace of parsed arguments
    def __init__(self, args: Namespace):
        # Logger named "transformers-cli/training"
        self.logger = logging.get_logger("transformers-cli/training")

        # Pick the framework: "tf" when TensorFlow is available, otherwise "torch"
        self.framework = "tf" if is_tf_available() else "torch"

        # Create the output directory if it does not already exist
        os.makedirs(args.output, exist_ok=True)
        self.output = args.output  # output directory path

        # Columns for labels, texts, and ids
        self.column_label = args.column_label
        self.column_text = args.column_text
        self.column_id = args.column_id

        # Log which task and model are being loaded
        self.logger.info(f"Loading {args.task} pipeline for {args.model}")

        # Load a different pipeline depending on the task
        if args.task == "text_classification":
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError  # not implemented yet
        elif args.task == "question_answering":
            raise NotImplementedError  # not implemented yet

        # Log where the training data is loaded from
        self.logger.info(f"Loading dataset from {args.train_data}")

        # Build the training dataset from the CSV file, using the configured columns
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )

        # The validation dataset starts out as None
        self.valid_dataset = None

        # If a validation dataset path was given, load it as well
        if args.validation_data:
            self.logger.info(f"Loading validation dataset from {args.validation_data}")

            # Build the validation dataset from the CSV file, using the configured columns
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        # Store the validation split, batch sizes, learning rate, and Adam epsilon
        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
        self.valid_batch_size = args.valid_batch_size
        self.learning_rate = args.learning_rate
        self.adam_epsilon = args.adam_epsilon

    # Dispatch to the framework-specific run method
    def run(self):
        if self.framework == "tf":
            return self.run_tf()  # TensorFlow variant
        return self.run_torch()  # PyTorch variant

    # PyTorch training is not implemented
    def run_torch(self):
        raise NotImplementedError

    # TensorFlow training: fit the pipeline's model and save it
    def run_tf(self):
        # Fit the pipeline on the training set, passing the validation data and training settings
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save the trained pipeline to the output directory
        self.pipeline.save_pretrained(self.output)
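
A minimal launch sketch (values are illustrative placeholders mirroring the defaults above; assumes TensorFlow is installed, since only the TF path is implemented):

from argparse import Namespace

args = Namespace(
    train_data="train.csv", column_label=0, column_text=1, column_id=2,
    skip_first_row=True, validation_data="", validation_split=0.1,
    output="./out", task="text_classification",
    model="google-bert/bert-base-uncased",
    train_batch_size=32, valid_batch_size=64,
    learning_rate=3e-5, adam_epsilon=1e-08,
)
TrainCommand(args).run()  # fits the pipeline and saves it under ./out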

.\commands\transformers_cli.py

# Shebang line pointing at the Python interpreter, followed by the copyright and license header
#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import the command-line argument parser
from argparse import ArgumentParser

# Import the individual command modules
from .add_new_model import AddNewModelCommand
from .add_new_model_like import AddNewModelLikeCommand
from .convert import ConvertCommand
from .download import DownloadCommand
from .env import EnvironmentCommand
from .lfs import LfsCommands
from .pt_to_tf import PTtoTFCommand
from .run import RunCommand
from .serving import ServeCommand
from .user import UserCommands

# The main entry point
def main():
    # Create the parser, with the program name and usage string
    parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
    # Add the subcommand parsers
    commands_parser = parser.add_subparsers(help="transformers-cli command helpers")

    # Register every command as a subcommand
    ConvertCommand.register_subcommand(commands_parser)
    DownloadCommand.register_subcommand(commands_parser)
    EnvironmentCommand.register_subcommand(commands_parser)
    RunCommand.register_subcommand(commands_parser)
    ServeCommand.register_subcommand(commands_parser)
    UserCommands.register_subcommand(commands_parser)
    AddNewModelCommand.register_subcommand(commands_parser)
    AddNewModelLikeCommand.register_subcommand(commands_parser)
    LfsCommands.register_subcommand(commands_parser)
    PTtoTFCommand.register_subcommand(commands_parser)

    # Make sure a valid command was selected
    args = parser.parse_args()
    if not hasattr(args, "func"):
        parser.print_help()
        exit(1)

    # Run the selected command's factory and obtain the service object
    service = args.func(args)
    service.run()

# If this script is run as the main program, call main()
if __name__ == "__main__":
    main()

.\commands\user.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import subprocess
from argparse import ArgumentParser
from typing import List, Union

from huggingface_hub.hf_api import HfFolder, create_repo, whoami
from requests.exceptions import HTTPError

from . import BaseTransformersCLICommand


class UserCommands(BaseTransformersCLICommand):
    # Static method: register the subcommands on the given parser
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # "login" subcommand: log in to huggingface.co
        login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
        login_parser.set_defaults(func=lambda args: LoginCommand(args))

        # "whoami" subcommand: show which huggingface.co account is logged in
        whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))

        # "logout" subcommand: log out
        logout_parser = parser.add_parser("logout", help="Log out")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))

        # New system: git-based repo system
        # "repo" subcommand: commands for interacting with huggingface.co repos (deprecated)
        repo_parser = parser.add_parser(
            "repo",
            help="Deprecated: use `huggingface-cli` instead. Commands to interact with your huggingface.co repos.",
        )
        # Subparsers for the repo subcommand (deprecated)
        repo_subparsers = repo_parser.add_subparsers(
            help="Deprecated: use `huggingface-cli` instead. huggingface.co repos related commands"
        )

        # "create" subcommand: create a new repo on huggingface.co (deprecated)
        repo_create_parser = repo_subparsers.add_parser(
            "create", help="Deprecated: use `huggingface-cli` instead. Create a new repo on huggingface.co"
        )
        # Positional argument: the repo name, namespaced under your username to build the model id
        repo_create_parser.add_argument(
            "name",
            type=str,
            help="Name for your model's repo. Will be namespaced under your username to build the model id.",
        )
        # Optional argument: organization namespace
        repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
        # Optional flag: answer Yes to the prompt
        repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
        repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))


class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = "\u001b[1m"
    _red = "\u001b[31m"
    _gray = "\u001b[90m"
    _reset = "\u001b[0m"

    # Class method: make text bold
    @classmethod
    def bold(cls, s):
        return f"{cls._bold}{s}{cls._reset}"

    # Class method: make text red
    @classmethod
    def red(cls, s):
        return f"{cls._bold}{cls._red}{s}{cls._reset}"

    # Class method: make text gray
    @classmethod
    def gray(cls, s):
        return f"{cls._gray}{s}{cls._reset}"


# Render a list of rows as an aligned text table
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    """
    Inspired by:

    - stackoverflow.com/a/8356620/593036
    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
    """
    # Compute the maximum width of each column, headers included
    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
    # Build a row format string from the column widths
    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
    # Collect the output lines
    lines = []
    # Header row
    lines.append(row_format.format(*headers))
    # Separator row
    lines.append(row_format.format(*["-" * w for w in col_widths]))
    # Format and append each data row
    for row in rows:
        lines.append(row_format.format(*row))
    # Join all lines with newlines
    return "\n".join(lines)


class BaseUserCommand:
    def __init__(self, args):
        self.args = args


class LoginCommand(BaseUserCommand):
    def run(self):
        # Print a red error: this login command is outdated
        print(
            ANSI.red(
                "ERROR! `transformers-cli login` uses an outdated login mechanism "
                "that is not compatible with the Hugging Face Hub backend anymore. "
                "Please use `huggingface-cli login` instead."
            )
        )


class WhoamiCommand(BaseUserCommand):
    def run(self):
        # Print a red warning: the whoami command is deprecated
        print(
            ANSI.red(
                "WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
                "`huggingface-cli whoami` instead."
            )
        )
        # Fetch the stored user token
        token = HfFolder.get_token()
        # If there is no token, report "Not logged in" and exit
        if token is None:
            print("Not logged in")
            exit()
        try:
            # Ask the Hub who this token belongs to
            user, orgs = whoami(token)
            # Print the user name
            print(user)
            # If the user belongs to organizations, print them too
            if orgs:
                print(ANSI.bold("orgs: "), ",".join(orgs))
        except HTTPError as e:
            # On an HTTP error, print the exception and the response body, then exit
            print(e)
            print(ANSI.red(e.response.text))
            exit(1)


class LogoutCommand(BaseUserCommand):
    def run(self):
        # Print a red error: this logout command is outdated
        print(
            ANSI.red(
                "ERROR! `transformers-cli logout` uses an outdated logout mechanism "
                "that is not compatible with the Hugging Face Hub backend anymore. "
                "Please use `huggingface-cli logout` instead."
            )
        )


class RepoCreateCommand(BaseUserCommand):
    def run(self):
        # Warn that managing repositories through transformers-cli is deprecated
        print(
            ANSI.red(
                "WARNING! Managing repositories through transformers-cli is deprecated. "
                "Please use `huggingface-cli` instead."
            )
        )
        # Fetch the stored user token
        token = HfFolder.get_token()
        # Without a token, report "Not logged in" and exit
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            # Check the installed git version
            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            # git was not found; tell the user to install it
            print("Looks like you do not have git installed, please install.")

        try:
            # Check the installed git-lfs version
            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
            print(ANSI.gray(stdout.strip()))
        except FileNotFoundError:
            # git-lfs was not found; tell the user how to install it
            print(
                ANSI.red(
                    "Looks like you do not have git-lfs installed, please install."
                    " You can install from https://git-lfs.github.com/."
                    " Then run `git lfs install` (you only have to do this once)."
                )
            )
        print("")

        # Look up the currently logged-in user
        user, _ = whoami(token)
        # The namespace is the organization given on the command line, or else the user
        namespace = self.args.organization if self.args.organization is not None else user
        # Assemble the full repo name
        full_name = f"{namespace}/{self.args.name}"
        # Announce the repo about to be created, in bold
        print(f"You are about to create {ANSI.bold(full_name)}")

        # Unless --yes was passed, ask the user to confirm
        if not self.args.yes:
            choice = input("Proceed? [Y/n] ").lower()
            if not (choice == "" or choice == "y" or choice == "yes"):
                # The user declined; abort
                print("Abort")
                exit()
        try:
            # Create the repo and get back its URL
            url = create_repo(token, name=self.args.name, organization=self.args.organization)
        except HTTPError as e:
            # On an HTTP error, print the exception and the response body, then exit
            print(e)
            print(ANSI.red(e.response.text))
            exit(1)
        # Print the URL of the newly created repo
        print("\nYour repo now lives at:")
        print(f"  {ANSI.bold(url)}")
        # Tell the user how to clone it locally
        print("\nYou can clone it locally with the command below, and commit/push as usual.")
        print(f"\n  git clone {url}")
        print("")

.\commands\__init__.py

# Use the abstract base class (ABC) machinery to define the base class for
# Transformers CLI commands
from abc import ABC, abstractmethod
# ArgumentParser for parsing command-line arguments
from argparse import ArgumentParser

# Abstract base class for all Transformers CLI commands
class BaseTransformersCLICommand(ABC):
    # Static abstract method: register the subcommand on the given ArgumentParser
    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        raise NotImplementedError()

    # Abstract method run: subclasses implement the command's behavior here
    @abstractmethod
    def run(self):
        raise NotImplementedError()
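
A minimal sketch of a custom subcommand built on this base class (entirely hypothetical, for illustration):

class HelloCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        sub = parser.add_parser("hello")  # hypothetical subcommand
        sub.set_defaults(func=lambda args: HelloCommand())

    def run(self):
        print("hello from transformers-cli")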

.\configuration_utils.py

# coding=utf-8
# Copyright notice and license header

""" Configuration base class and utilities."""
# Import the required libraries and modules
import copy  # deep copies of objects
import json  # JSON handling
import os  # operating-system helpers
import re  # regular expressions
import warnings  # deprecation warnings
from typing import Any, Dict, List, Optional, Tuple, Union  # type hints

from packaging import version  # version-number handling

from . import __version__  # this package's version
from .dynamic_module_utils import custom_object_save  # saving custom objects
from .utils import (  # utility functions and constants
    CONFIG_NAME,  # the configuration file name
    PushToHubMixin,  # mixin providing push-to-Hub functionality
    add_model_info_to_auto_map,  # adds model info to the auto map
    cached_file,  # file caching
    copy_func,  # function copying helper
    download_url,  # downloading URL resources
    extract_commit_hash,  # extracting commit hashes
    is_remote_url,  # checking whether a URL is remote
    is_torch_available,  # checking whether PyTorch is available
    logging,  # logging module
)

logger = logging.get_logger(__name__)  # module-level logger

_re_configuration_file = re.compile(r"config\.(.*)\.json")  # regex matching configuration file names


class PretrainedConfig(PushToHubMixin):
    # no-format
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    <Tip>

    A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
    initialize a model does **not** load the model weights. It only affects the model's configuration.

    </Tip>

    Class attributes (overridden by derived classes):

    - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
      the correct object in [`~transformers.AutoConfig`].
    - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
      config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
      [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
    - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
      outputs of the model during inference.
    - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
      naming of attributes.

    Common attributes (present in all subclasses):

    - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
      embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
    - **hidden_size** (`int`) -- The hidden size of the model.

    """
    # 定义了一个预训练配置类 PretrainedConfig,是所有配置类的基类,包含了通用的模型配置参数和加载/保存配置的方法
    # `model_type` 是模型的类型描述字符串
    model_type: str = ""
    # `is_composition` 表示模型是否是一个组合模型,默认为 False
    is_composition: bool = False
    # `attribute_map` 是一个映射,用于属性重命名
    attribute_map: Dict[str, str] = {}
    # `_auto_class` 是一个私有属性,用于存储自动类的名称,可选为 None
    _auto_class: Optional[str] = None

    def __setattr__(self, key, value):
        # 自定义的属性设置方法,用于根据 `attribute_map` 重命名属性
        if key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        super().__setattr__(key, value)

    def __getattribute__(self, key):
        # 自定义的属性获取方法,用于根据 `attribute_map` 重命名属性
        if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        return super().__getattribute__(key)
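
A quick sketch of what these two hooks buy us in practice: any alias listed in `attribute_map` reads and writes the canonical attribute. `ToyConfig` below is hypothetical, purely for illustration.

# Hypothetical subclass demonstrating attribute_map aliasing
class ToyConfig(PretrainedConfig):
    attribute_map = {"hidden_dim": "hidden_size"}  # alias -> canonical name

    def __init__(self, hidden_size=64, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

cfg = ToyConfig()
cfg.hidden_dim = 128       # __setattr__ reroutes the alias to hidden_size
print(cfg.hidden_size)     # 128
print(cfg.hidden_dim)      # 128, resolved through __getattribute__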

    @property
    def name_or_path(self) -> str:
        # 返回模型的名称或路径,作为 `_name_or_path` 的值
        return getattr(self, "_name_or_path", None)

    @name_or_path.setter
    def name_or_path(self, value):
        # 设置模型的名称或路径,确保为字符串类型(用于 JSON 编码)
        self._name_or_path = str(value)

    @property
    def use_return_dict(self) -> bool:
        """
        `bool`: 是否返回 [`~utils.ModelOutput`] 而不是元组。
        """
        # 如果设置了 torchscript,强制 `return_dict=False` 以避免 JIT 错误
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        `int`: 分类模型的标签数量。
        """
        # 返回模型的标签数量,基于 `id2label` 的长度
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels: int):
        # 设置模型的标签数量,如果 `id2label` 不存在或长度不符合,则重新生成标签映射
        if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
            self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
            self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
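
In practice the getter/setter pair keeps the three label attributes consistent; a small sketch using the default constructor:

cfg = PretrainedConfig()
cfg.num_labels = 3
print(cfg.id2label)    # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}
print(cfg.label2id)    # {'LABEL_0': 0, 'LABEL_1': 1, 'LABEL_2': 2}
print(cfg.num_labels)  # 3, always derived from len(id2label)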

    @property
    def _attn_implementation(self):
        """
        `str`: 注意力机制的实现方式。
        """
        # 私有属性,返回注意力机制的实现方式,默认为 "eager"
        if hasattr(self, "_attn_implementation_internal"):
            if self._attn_implementation_internal is None:
                return "eager"
            else:
                return self._attn_implementation_internal
        else:
            return "eager"

    @_attn_implementation.setter
    def _attn_implementation(self, value):
        # 设置注意力机制的实现方式
        self._attn_implementation_internal = value
    @staticmethod
    def _set_token_in_kwargs(kwargs, token=None):
        """在 kwargs 中设置 `token` 参数。

        这个方法是为了避免在所有模型配置类中重复应用相同的更改,这些类重写了 `from_pretrained` 方法。

        需要在随后的 PR 中清理 `use_auth_token`。
        """
        # 一些模型配置类(如 CLIP)定义了自己的 `from_pretrained` 方法,但还没有新参数 `token`。
        if token is None:
            token = kwargs.pop("token", None)
        use_auth_token = kwargs.pop("use_auth_token", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        # 如果存在 token,则将其添加到 kwargs 中
        if token is not None:
            kwargs["token"] = token

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ):
        # (The body of `from_pretrained` is elided in this walkthrough: it resolves a config
        # dict via `get_config_dict` below, then instantiates the config with `from_dict`.)
        ...

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        从 `pretrained_model_name_or_path` 解析出参数字典,用于通过 `from_dict` 实例化 `PretrainedConfig`。

        参数:
            pretrained_model_name_or_path (`str` 或 `os.PathLike`):
                想要获取参数字典的预训练检查点的标识符。

        返回:
            `Tuple[Dict, Dict]`: 将用于实例化配置对象的字典。

        """
        # 调用 `_set_token_in_kwargs` 方法,设置 `token` 参数
        cls._set_token_in_kwargs(kwargs)

        original_kwargs = copy.deepcopy(kwargs)
        # 获取与基本配置文件关联的配置字典
        config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
        if "_commit_hash" in config_dict:
            original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # 可能会指向另一个要使用的配置文件。
        if "configuration_files" in config_dict:
            configuration_file = get_configuration_file(config_dict["configuration_files"])
            config_dict, kwargs = cls._get_config_dict(
                pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
            )

        return config_dict, kwargs
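
A minimal usage sketch (it requires network access or a local cache of the checkpoint; the model id is only an example):

config_dict, unused_kwargs = PretrainedConfig.get_config_dict("google-bert/bert-base-cased")
print(config_dict["model_type"])  # "bert"
config = PretrainedConfig.from_dict(config_dict)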

    @classmethod
    # 定义类方法 `_get_config_dict`,用于获取配置信息的字典
    def _get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ):
        ...  # (body elided in this walkthrough)

    # Class method: instantiate a [`PretrainedConfig`] from a Python dictionary of parameters.
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from those parameters.
        """
        # 从 kwargs 中弹出 "return_unused_kwargs" 参数,如果没有则默认为 False
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # 从 kwargs 中移除 "_from_auto" 和 "_from_pipeline" 参数,避免它们出现在 `return_unused_kwargs` 中
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        
        # 如果配置字典中包含 "_commit_hash",则更新 kwargs 中的 "_commit_hash",以防 kwargs 覆盖这个更新
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # Pop "attn_implementation" out of kwargs and store it in config_dict, so that it does
        # not show up among the unused kwargs
        config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)

        # 使用 config_dict 实例化一个 cls 类型的配置对象 config
        config = cls(**config_dict)

        # 如果配置对象 config 有 "pruned_heads" 属性,则将其键转换为整数
        if hasattr(config, "pruned_heads"):
            config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}

        # 如果 kwargs 中包含 "num_labels" 和 "id2label",则验证它们是否兼容
        if "num_labels" in kwargs and "id2label" in kwargs:
            num_labels = kwargs["num_labels"]
            id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
            if len(id2label) != num_labels:
                raise ValueError(
                    f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
                    f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
                    "one of them."
                )

        # 准备从配置对象 config 中移除的参数列表
        to_remove = []
        # 遍历 kwargs 中的键值对
        for key, value in kwargs.items():
            # 如果 config 中有对应的属性 key,则将其设置为 value
            if hasattr(config, key):
                current_attr = getattr(config, key)
                # 如果当前属性是 PretrainedConfig 类型且 value 是字典,则将其转换为相应的子配置
                if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
                    value = current_attr.__class__(**value)
                setattr(config, key, value)
                # 将 key 添加到待移除列表中(除了 "torch_dtype")
                if key != "torch_dtype":
                    to_remove.append(key)
        # 从 kwargs 中移除已处理的键值对
        for key in to_remove:
            kwargs.pop(key, None)

        # 记录配置对象 config 的信息
        logger.info(f"Model config {config}")
        # 如果需要返回未使用的 kwargs,则返回配置对象和剩余的 kwargs
        if return_unused_kwargs:
            return config, kwargs
        else:
            # 否则只返回配置对象
            return config
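
A small sketch of the kwargs-consumption behavior: keys matching existing attributes are applied to the config, and everything else is handed back when `return_unused_kwargs=True`.

config, unused = PretrainedConfig.from_dict(
    {}, output_attentions=True, not_a_real_key=123, return_unused_kwargs=True
)
print(config.output_attentions)  # True — consumed, since the attribute exists
print(unused)                    # {'not_a_real_key': 123} — unknown keys are returned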

    @classmethod
    # 从 JSON 文件中读取配置并实例化一个 PretrainedConfig 对象
    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
        """
        Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            [`PretrainedConfig`]: The configuration object instantiated from that JSON file.

        """
        # 从 JSON 文件中读取配置信息并转换成字典形式
        config_dict = cls._dict_from_json_file(json_file)
        # 使用字典中的配置参数实例化一个 PretrainedConfig 对象
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        """
        Reads and parses a JSON file into a dictionary.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file.

        Returns:
            dict: Dictionary containing the parsed JSON content.

        """
        # 打开 JSON 文件,读取其中的文本内容
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        # 将读取的 JSON 文本解析为字典对象
        return json.loads(text)

    # 定义相等性比较方法,用于比较两个 PretrainedConfig 对象是否相等
    def __eq__(self, other):
        return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)

    # 定义对象的字符串表示方法,返回包含 JSON 字符串的对象表示形式
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"
    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        # Serialize current configuration instance to a dictionary
        config_dict = self.to_dict()

        # Get default configuration dictionary
        default_config_dict = PretrainedConfig().to_dict()

        # Get class-specific configuration dictionary
        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}

        serializable_config_dict = {}

        # Iterate over each key-value pair in the current configuration dictionary
        for key, value in config_dict.items():
            # Check if the attribute is a PretrainedConfig instance and differs from class-specific config
            if (
                isinstance(getattr(self, key, None), PretrainedConfig)
                and key in class_config_dict
                and isinstance(class_config_dict[key], dict)
            ):
                # Recursive diff for nested configurations
                diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
                # Ensure model_type is set even if not in the diff
                if "model_type" in value:
                    diff["model_type"] = value["model_type"]
                # Include in serializable dictionary if there are differences
                if len(diff) > 0:
                    serializable_config_dict[key] = diff
            # Include if key not in default config, or values differ from default or class-specific configs
            elif (
                key not in default_config_dict
                or key == "transformers_version"
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        # Handle special case for quantization_config
        if hasattr(self, "quantization_config"):
            if isinstance(self.quantization_config, dict):
                serializable_config_dict["quantization_config"] = self.quantization_config
            else:
                serializable_config_dict["quantization_config"] = self.quantization_config.to_dict()

            # Remove _pre_quantization_dtype as it's not serializable
            _ = serializable_config_dict.pop("_pre_quantization_dtype", None)

        # Convert torch dtypes to strings in the dictionary
        self.dict_torch_dtype_to_str(serializable_config_dict)

        # Remove internal implementation detail if present
        if "_attn_implementation_internal" in serializable_config_dict:
            del serializable_config_dict["_attn_implementation_internal"]

        return serializable_config_dict
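
The effect, in short: a customized config serializes only its deviations from the defaults. A sketch (attribute values are illustrative):

cfg = PretrainedConfig(output_attentions=True)
print(cfg.to_diff_dict())  # roughly {'output_attentions': True, 'transformers_version': '...'}
print("output_hidden_states" in cfg.to_diff_dict())  # False — default values are skipped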
    def to_dict(self) -> Dict[str, Any]:
        """
        将当前实例序列化为一个 Python 字典。

        Returns:
            `Dict[str, Any]`: 包含构成该配置实例的所有属性的字典。
        """
        # 深拷贝实例的所有属性到输出字典
        output = copy.deepcopy(self.__dict__)
        # 如果类定义了 model_type 属性,则将其加入输出字典
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        # 删除输出字典中的特定内部属性
        if "_auto_class" in output:
            del output["_auto_class"]
        if "_commit_hash" in output:
            del output["_commit_hash"]
        if "_attn_implementation_internal" in output:
            del output["_attn_implementation_internal"]

        # 添加 Transformers 的版本信息到输出字典
        output["transformers_version"] = __version__

        # 处理嵌套的配置(例如 CLIP),将其转换为字典形式
        for key, value in output.items():
            if isinstance(value, PretrainedConfig):
                value = value.to_dict()
                # 移除嵌套配置中的 Transformers 版本信息
                del value["transformers_version"]
            output[key] = value

        # 如果实例有 quantization_config 属性,将其转换为字典形式并加入输出字典
        if hasattr(self, "quantization_config"):
            output["quantization_config"] = (
                self.quantization_config.to_dict()
                if not isinstance(self.quantization_config, dict)
                else self.quantization_config
            )

            # 移除输出字典中的 _pre_quantization_dtype 属性,因为 torch.dtypes 不可序列化
            _ = output.pop("_pre_quantization_dtype", None)

        # 对输出字典中的 torch 数据类型进行转换处理
        self.dict_torch_dtype_to_str(output)

        # 返回最终的输出字典
        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        将当前实例序列化为 JSON 字符串。

        Args:
            use_diff (`bool`, *optional*, 默认为 `True`):
                如果设置为 `True`,则只序列化配置实例与默认 `PretrainedConfig()` 之间的差异。

        Returns:
            `str`: 包含构成该配置实例的所有属性的 JSON 格式字符串。
        """
        # 根据 use_diff 参数决定是否只序列化差异部分
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        # 将字典转换为 JSON 字符串,缩进为 2,按键排序
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """
        将当前实例保存为 JSON 文件。

        Args:
            json_file_path (`str` 或 `os.PathLike`):
                保存配置实例参数的 JSON 文件路径。
            use_diff (`bool`, *optional*, 默认为 `True`):
                如果设置为 `True`,则只序列化配置实例与默认 `PretrainedConfig()` 之间的差异。
        """
        # 打开指定路径的 JSON 文件,将实例转换为 JSON 字符串并写入文件
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))
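
A round-trip sketch tying the two JSON helpers together (the path is illustrative):

cfg = PretrainedConfig(output_attentions=True)
cfg.to_json_file("my_config.json", use_diff=False)
restored = PretrainedConfig.from_json_file("my_config.json")
assert restored.output_attentions is True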
    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from `config_dict`.

        Args:
            config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
        """
        # 遍历传入的字典,将每个键值对应用到当前类的属性上
        for key, value in config_dict.items():
            setattr(self, key, value)

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from `update_str`.

        The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
        "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

        The keys to change have to already exist in the config object.

        Args:
            update_str (`str`): String with attributes that should be updated for this class.

        """
        # 将传入的字符串按逗号分割成键值对,构建字典
        d = dict(x.split("=") for x in update_str.split(","))
        # 遍历字典中的每个键值对
        for k, v in d.items():
            # 检查当前类是否存在名为 k 的属性
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            # 获取当前属性的旧值
            old_v = getattr(self, k)
            # 根据旧值的类型转换新值 v 的类型,并设置为当前类的属性
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)
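
A sketch of the type coercion: each value in the string is converted to the type of the attribute it replaces (assuming the default config carries the generation attributes, as in the library):

cfg = PretrainedConfig()
cfg.update_from_string("output_attentions=true,num_return_sequences=3")
print(cfg.output_attentions)     # True — parsed as bool, matching the old value's type
print(cfg.num_return_sequences)  # 3 — parsed as int, matching the old value's type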

    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
        converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
        string, which can then be stored in the json format.
        """
        # 检查传入的字典是否包含名为 torch_dtype 的键,并且其值不为 None
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            # 将 torch.dtype 转换为只包含类型的字符串,例如将 torch.float32 转换为 "float32"
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
        # 递归处理字典中的每个值,如果值是字典,则继续调用该方法
        for value in d.values():
            if isinstance(value, dict):
                self.dict_torch_dtype_to_str(value)
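
A sketch of the conversion (assumes torch is installed); note that it recurses into nested dicts:

import torch

d = {"torch_dtype": torch.float32, "text_config": {"torch_dtype": torch.bfloat16}}
PretrainedConfig().dict_torch_dtype_to_str(d)
print(d)  # {'torch_dtype': 'float32', 'text_config': {'torch_dtype': 'bfloat16'}}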

@classmethod
# 注册自动配置类方法,用于将当前类注册到指定的自动配置类中
def register_for_auto_class(cls, auto_class="AutoConfig"):
    """
    Register this class with a given auto class. This should only be used for custom configurations as the ones in
    the library are already mapped with `AutoConfig`.

    <Tip warning={true}>
    This API is experimental and may have some slight breaking changes in the next releases.
    </Tip>

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
            The auto class to register this new configuration with.
    """
    # 如果 auto_class 不是字符串,将其转换为类名字符串
    if not isinstance(auto_class, str):
        auto_class = auto_class.__name__

    # 导入 transformers.models.auto 模块
    import transformers.models.auto as auto_module

    # 如果 auto_class 在 auto_module 中不存在,抛出 ValueError
    if not hasattr(auto_module, auto_class):
        raise ValueError(f"{auto_class} is not a valid auto class.")

    # 将 auto_class 赋值给当前类的 _auto_class 属性
    cls._auto_class = auto_class

@staticmethod
# 返回默认的生成参数字典
def _get_generation_defaults() -> Dict[str, Any]:
    return {
        "max_length": 20,
        "min_length": 0,
        "do_sample": False,
        "early_stopping": False,
        "num_beams": 1,
        "num_beam_groups": 1,
        "diversity_penalty": 0.0,
        "temperature": 1.0,
        "top_k": 50,
        "top_p": 1.0,
        "typical_p": 1.0,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "no_repeat_ngram_size": 0,
        "encoder_no_repeat_ngram_size": 0,
        "bad_words_ids": None,
        "num_return_sequences": 1,
        "output_scores": False,
        "return_dict_in_generate": False,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "remove_invalid_values": False,
        "exponential_decay_length_penalty": None,
        "suppress_tokens": None,
        "begin_suppress_tokens": None,
    }

# 判断当前实例是否具有非默认生成参数
def _has_non_default_generation_parameters(self) -> bool:
    """
    Whether or not this instance holds non-default generation parameters.
    """
    # 获取默认的生成参数字典
    defaults = self._get_generation_defaults()

    # 遍历生成参数字典,检查当前实例是否有非默认值的生成参数
    for parameter_name, default_value in defaults.items():
        if hasattr(self, parameter_name) and getattr(self, parameter_name) != default_value:
            return True
    return False
# 获取用于此版本 transformers 的配置文件。
def get_configuration_file(configuration_files: List[str]) -> str:
    """
    Get the configuration file to use for this version of transformers.

    Args:
        configuration_files (`List[str]`): The list of available configuration files.

    Returns:
        `str`: The configuration file to use.
    """
    # 初始化一个空字典,用于存储版本号与配置文件名的映射关系
    configuration_files_map = {}
    # 遍历每个配置文件名
    for file_name in configuration_files:
        # 使用正则表达式搜索文件名中的版本号信息
        search = _re_configuration_file.search(file_name)
        # 如果找到匹配项
        if search is not None:
            # 提取版本号信息并存储到字典中
            v = search.groups()[0]
            configuration_files_map[v] = file_name

    # 对版本号进行排序
    available_versions = sorted(configuration_files_map.keys())

    # Default to CONFIG_NAME ("config.json"), then look for progressively newer versioned files
    configuration_file = CONFIG_NAME
    transformers_version = version.parse(__version__)
    # 遍历所有可用版本
    for v in available_versions:
        # 如果当前版本小于等于 transformers 的版本
        if version.parse(v) <= transformers_version:
            # 更新配置文件为对应版本的配置文件
            configuration_file = configuration_files_map[v]
        else:
            # 因为版本已排序,所以不再继续查找
            break

    # 返回选择的配置文件名
    return configuration_file
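
A sketch of the selection logic with hypothetical file names: string versions are extracted by the regex, sorted, and the newest file not exceeding the running transformers version wins.

files = ["config.json", "config.4.0.0.json", "config.42.0.0.json"]
print(get_configuration_file(files))
# -> "config.4.0.0.json" for any transformers >= 4.0.0 and < 42.0.0;
#    falls back to CONFIG_NAME ("config.json") when no versioned file is applicable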


# 递归比较两个嵌套字典的差异,返回仅包含 dict_a 中不同于 dict_b 的值的字典
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
    values from `dict_a` that are different from values in `dict_b`.
    """
    # 初始化一个空字典,用于存储差异
    diff = {}
    # 如果传入了 config_obj 参数,则获取其默认配置的字典表示
    default = config_obj.__class__().to_dict() if config_obj is not None else {}
    # 遍历 dict_a 的每一个键值对
    for key, value in dict_a.items():
        # 尝试从 config_obj 中获取与当前键对应的值
        obj_value = getattr(config_obj, str(key), None)
        # 如果 obj_value 是 PretrainedConfig 类型,并且 dict_b 中存在当前键,并且 dict_b 中的值也是字典
        if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
            # 递归调用自身,比较当前值与 dict_b[key] 的差异
            diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
            # 如果有差异,则将其存储到 diff 字典中
            if len(diff_value) > 0:
                diff[key] = diff_value
        # 如果当前键不在 dict_b 中,或者当前值与 dict_b[key] 的值不同,或者当前键在 default 中但值不同于 default 中的值
        elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
            # 将当前键值对存储到 diff 字典中
            diff[key] = value
    # 返回差异字典
    return diff


# 将 PretrainedConfig 类的 push_to_hub 方法复制给 PretrainedConfig.push_to_hub
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
# 如果 PretrainedConfig.push_to_hub 方法有文档字符串
if PretrainedConfig.push_to_hub.__doc__ is not None:
    # 使用格式化字符串,将文档字符串中的占位符替换为实际值
    PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
        object="config", object_class="AutoConfig", object_files="configuration file"
    )

.\convert_graph_to_onnx.py

# 版权声明和许可信息
# 版权所有 2020 年 HuggingFace 团队保留所有权利。
# 
# 根据 Apache 许可证 2.0 版本(“许可证”)许可;
# 除非符合许可证的规定,否则您不得使用此文件。
# 您可以在以下网址获取许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则依据许可证分发的软件是基于“按原样”基础分发的,
# 没有任何明示或暗示的保证或条件。
# 请查阅许可证以获取具体的法律语言。
#

import warnings  # 导入警告模块
from argparse import ArgumentParser  # 从 argparse 模块导入 ArgumentParser 类
from os import listdir, makedirs  # 从 os 模块导入 listdir 和 makedirs 函数
from pathlib import Path  # 导入 Path 类
from typing import Dict, List, Optional, Tuple  # 导入类型提示

from packaging.version import Version, parse  # 从 packaging.version 模块导入 Version 和 parse 函数

# 导入 transformers 库中的相关模块和类
from transformers.pipelines import Pipeline, pipeline  
from transformers.tokenization_utils import BatchEncoding  
from transformers.utils import ModelOutput, is_tf_available, is_torch_available  

# 定义最小支持的 ONNX Runtime 版本
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")

# 定义支持的 pipeline 类型列表
SUPPORTED_PIPELINES = [
    "feature-extraction",
    "ner",
    "sentiment-analysis",
    "fill-mask",
    "question-answering",
    "text-generation",
    "translation_en_to_fr",
    "translation_en_to_de",
    "translation_en_to_ro",
]

# 定义一个 ArgumentParser 的子类,用于解析 ONNX 转换器的命令行参数
class OnnxConverterArgumentParser(ArgumentParser):
    """
    Wraps all the script arguments supported to export transformers models to ONNX IR
    """

    def __init__(self):
        super().__init__("ONNX Converter")  # 调用父类构造函数,设置解析器的描述信息为 "ONNX Converter"

        # 添加命令行参数
        self.add_argument(
            "--pipeline",
            type=str,
            choices=SUPPORTED_PIPELINES,
            default="feature-extraction",
        )
        self.add_argument(
            "--model",
            type=str,
            required=True,
            help="Model's id or path (ex: google-bert/bert-base-cased)",
        )
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
        self.add_argument(
            "--framework",
            type=str,
            choices=["pt", "tf"],
            help="Framework for loading the model",
        )
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
        self.add_argument(
            "--check-loading",
            action="store_true",
            help="Check ONNX is able to load the model",
        )
        self.add_argument(
            "--use-external-format",
            action="store_true",
            help="Allow exporting model >= than 2Gb",
        )
        self.add_argument(
            "--quantize",
            action="store_true",
            help="Quantize the neural network to be run with int8",
        )
        self.add_argument("output")  # 添加输出参数

# 定义一个函数,生成带有标识符的文件名
def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    # 在提供的文件路径末尾(在扩展名之前,如果有的话)添加一个字符串标识符
    
    Args:
        filename: pathlib.Path 实际的路径对象,我们希望在其末尾添加标识符后缀
        identifier: 要添加的后缀
    
    Returns: 添加了标识符的字符串,连接在文件名的末尾
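
A quick sketch of the resulting path (the file need not exist):

print(generate_identified_filename(Path("onnx/bert.onnx"), "-optimized"))
# -> onnx/bert-optimized.onnx
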
# 检查 onnxruntime 的安装情况及版本是否符合要求
def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check onnxruntime is installed and if the installed version match is recent enough

    Raises:
        ImportError: If onnxruntime is not installed or too old version is found
    """
    try:
        import onnxruntime

        # 解析已安装的 onnxruntime 的版本
        ort_version = parse(onnxruntime.__version__)

        # 要求最低版本为 1.4.0
        if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
            raise ImportError(
                f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
                f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
                "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
            )

    except ImportError:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        )


# 确保输入在正确顺序中,没有非法输入
def ensure_valid_input(model, tokens, input_names):
    """
    Ensure inputs are presented in the correct order, without any None values

    Args:
        model: The model used to forward the input data
        tokens: BatchEncoding holding the input data
        input_names: The name of the inputs

    Returns: Tuple

    """
    print("Ensuring inputs are in correct order")

    # 获取模型前向方法的参数名列表
    model_args_name = model.forward.__code__.co_varnames
    model_args, ordered_input_names = [], []
    for arg_name in model_args_name[1:]:  # 从索引1开始以跳过 "self" 参数
        if arg_name in input_names:
            ordered_input_names.append(arg_name)
            model_args.append(tokens[arg_name])
        else:
            print(f"{arg_name} is not present in the generated input list.")
            break

    # 打印生成的输入顺序
    print(f"Generated inputs order: {ordered_input_names}")
    return ordered_input_names, tuple(model_args)


# 推断模型输入输出张量的静态与动态轴
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model

    Args:
        nlp: The pipeline object holding the model to be exported
        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)

    Returns:

        - List of the inferred input variable names
        - List of the inferred output variable names
        - Dictionary with input/output variables names as key and shape tensor as value
        - a BatchEncoding reference which was used to infer all the above information
    """
    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        # 如果 tensor 是元组或列表,则递归调用 build_shape_dict 处理每个元素
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
        
        else:
            # 假设第一个维度是批处理维度,且只有一个元素
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            # 如果是输入数据,判断维度是否为二维,将第二个维度标记为 "sequence"
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
            else:
                # 找到与指定序列长度相匹配的维度,并将其标记为 "sequence"
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update({dim: "sequence" for dim in seq_axes})

        # 打印找到的输入或输出的名称、形状信息
        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes

    # 使用 NLP 模型的分词器生成 tokens,并返回张量表示
    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    # 获取序列长度
    seq_len = tokens.input_ids.shape[-1]
    # 根据框架类型调用 NLP 模型
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
    # 如果输出是 ModelOutput 类型,则转换为元组
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    # 如果输出不是列表或元组,则将其包装成元组
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # 生成输入变量的名称及其动态轴信息
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}

    # 将可能包含分组输出(例如 gpt2 中的过去状态或注意力)展平
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # 生成输出变量的名称及其动态轴信息
    output_names = [f"output_{i}" for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # 创建汇总的动态轴表示
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens
def load_graph_from_args(
    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
) -> Pipeline:
    """
    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)

    Args:
        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
        framework: The actual model to convert the pipeline from ("pt" or "tf")
        model: The model name which will be loaded by the pipeline
        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value

    Returns: Pipeline object
    """
    # 如果未提供 tokenizer,则使用 model 作为 tokenizer
    if tokenizer is None:
        tokenizer = model

    # 检查所需的 framework 是否可用
    if framework == "pt" and not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    if framework == "tf" and not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")

    # 分配 tokenizer 和 model
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)


def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model
        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB

    Returns:
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print(f"Using framework PyTorch: {torch.__version__}")

    # 通过 infer_shapes 推断输入、输出和动态轴
    with torch.no_grad():
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        # 确保输入名称有效,并按顺序提供模型参数
        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)

        # 导出模型到 ONNX
        export(
            nlp.model,
            model_args,
            f=output.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
    """
    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model
    """
    # 检查是否安装了 TensorFlow
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
    
    # 提示用户注意:TensorFlow 不支持导出超过2GB的模型
    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
    
    try:
        # 尝试导入 TensorFlow 和 tf2onnx
        import tensorflow as tf
        import tf2onnx
        from tf2onnx import __version__ as t2ov
        
        # 打印当前使用的框架和 tf2onnx 的版本信息
        print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
    
        # 推断模型输入形状等信息,并获取 tokens
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
    
        # 使用模型进行前向推断
        nlp.model.predict(tokens.data)
        
        # 根据 tokens 的数据创建输入签名
        input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
        
        # 使用 tf2onnx 将 Keras 模型转换为 ONNX 格式
        model_proto, _ = tf2onnx.convert.from_keras(
            nlp.model, input_signature, opset=opset, output_path=output.as_posix()
        )
    
    except ImportError as e:
        # 若导入出错,引发异常提示缺少必要的包
        raise Exception(
            f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
        )
# 定义一个函数 convert,用于将管道对象转换为 ONNX 中间表示(IR)格式
def convert(
    framework: str,
    model: str,
    output: Path,
    opset: int,
    tokenizer: Optional[str] = None,
    use_external_format: bool = False,
    pipeline_name: str = "feature-extraction",
    **model_kwargs,
):
    """
    Convert the pipeline object to the ONNX Intermediate Representation (IR) format

    Args:
        framework: 管道所使用的框架 ("pt" 或 "tf")
        model: 管道加载的模型名称
        output: 存储 ONNX 图的路径
        opset: 使用的 ONNX 运算集的实际版本
        tokenizer: 管道所使用的分词器名称,如果未提供则默认使用模型名称
        use_external_format:
            是否将模型定义与其参数分离,以允许超过 2GB 的模型大小(仅适用于 PyTorch)
        pipeline_name: 实例化的管道类型(ner、question-answering 等)
        model_kwargs: 转发给模型构造函数的关键字参数

    Returns:

    """
    # 发出警告,指示 `transformers.convert_graph_to_onnx` 包已过时,并将在 Transformers 的第五个版本中移除
    warnings.warn(
        "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
        " Transformers",
        FutureWarning,
    )
    # 打印设置的 ONNX 运算集版本号
    print(f"ONNX opset version set to: {opset}")

    # 加载管道对象
    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)

    # 检查输出路径的父目录是否存在,若不存在则创建
    if not output.parent.exists():
        print(f"Creating folder {output.parent}")
        makedirs(output.parent.as_posix())
    # 若输出路径的父目录非空,则抛出异常
    elif len(listdir(output.parent.as_posix())) > 0:
        raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")

    # 根据不同的框架导出图
    if framework == "pt":
        convert_pytorch(nlp, opset, output, use_external_format)
    else:
        convert_tensorflow(nlp, opset, output)
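
Beyond the CLI entry point below, the function can be driven from Python directly; a sketch with an illustrative model id and output path (the output folder must be empty):

from pathlib import Path

convert(
    framework="pt",
    model="google-bert/bert-base-cased",
    output=Path("onnx/bert-base-cased.onnx"),
    opset=11,
    pipeline_name="feature-extraction",
)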


# 定义一个函数 optimize,用于优化 ONNX 模型
def optimize(onnx_model_path: Path) -> Path:
    """
    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
    optimizations possible

    Args:
        onnx_model_path: 模型二进制描述文件的路径

    Returns: 优化后的模型二进制描述文件保存的路径

    """
    from onnxruntime import InferenceSession, SessionOptions

    # 生成带有后缀 "-optimized" 的优化模型文件名
    opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
    sess_option = SessionOptions()
    # 设置优化后的模型文件路径
    sess_option.optimized_model_filepath = opt_model_path.as_posix()
    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)

    # 打印优化后的模型写入路径
    print(f"Optimized model has been written at {opt_model_path}: \N{heavy check mark}")
    # 提示优化后的模型包含特定硬件操作符,可能不具备可移植性
    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")

    return opt_model_path


# 定义一个函数 quantize,用于将模型权重从 float32 量化为 int8,以实现在现代 CPU 上高效推断
def quantize(onnx_model_path: Path) -> Path:
    """
    Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU

    Args:
        onnx_model_path: 模型二进制描述文件的路径

    Returns: 量化后的模型二进制描述文件保存的路径

    """
    # 导入必要的库和模块
    import onnx
    import onnxruntime
    from onnx.onnx_pb import ModelProto
    from onnxruntime.quantization import QuantizationMode
    from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
    from onnxruntime.quantization.registry import IntegerOpsRegistry

    # 加载指定路径下的 ONNX 模型
    onnx_model = onnx.load(onnx_model_path.as_posix())

    # 检查 ONNX 版本是否小于 1.5.0,提示模型大小限制问题
    if parse(onnx.__version__) < parse("1.5.0"):
        print(
            "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
            "Please upgrade to onnxruntime >= 1.5.0."
        )

    # 创建 ONNX 模型的副本
    copy_model = ModelProto()
    copy_model.CopyFrom(onnx_model)

    # 构造量化器
    # 检查 ONNX Runtime 版本,根据版本选择合适的量化器参数设置
    if parse(onnxruntime.__version__) < parse("1.13.1"):
        quantizer = ONNXQuantizer(
            model=copy_model,
            per_channel=False,
            reduce_range=False,
            mode=QuantizationMode.IntegerOps,
            static=False,
            weight_qType=True,
            input_qType=False,
            tensors_range=None,
            nodes_to_quantize=None,
            nodes_to_exclude=None,
            op_types_to_quantize=list(IntegerOpsRegistry),
        )
    else:
        quantizer = ONNXQuantizer(
            model=copy_model,
            per_channel=False,
            reduce_range=False,
            mode=QuantizationMode.IntegerOps,
            static=False,
            weight_qType=True,
            activation_qType=False,
            tensors_range=None,
            nodes_to_quantize=None,
            nodes_to_exclude=None,
            op_types_to_quantize=list(IntegerOpsRegistry),
        )

    # 执行模型量化
    quantizer.quantize_model()

    # 生成量化后模型的文件名,并在原模型文件名末尾添加 "-quantized" 后缀
    quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

    # 保存量化后的模型
    print(f"Quantized model has been written at {quantized_model_path}: \N{heavy check mark}")
    onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())

    # 返回量化后模型的路径
    return quantized_model_path
def verify(path: Path):
    # 引入需要的库和模块
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    # 打印正在加载的 ONNX 模型路径
    print(f"Checking ONNX model loading from: {path} ...")
    try:
        # 设置 ONNX 运行时的选项
        onnx_options = SessionOptions()
        # 创建推理会话,加载模型并指定 CPU 执行提供者
        _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
        # 打印模型加载成功的消息
        print(f"Model {path} correctly loaded: \N{heavy check mark}")
    except RuntimeException as re:
        # 捕获模型加载时的异常并打印错误消息
        print(f"Error while loading the model {re}: \N{heavy ballot x}")


if __name__ == "__main__":
    # 解析命令行参数
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # 确保输出路径为绝对路径
    args.output = Path(args.output).absolute()

    try:
        print("\n====== Converting model to ONNX ======")
        # 执行模型转换
        convert(
            args.framework,
            args.model,
            args.output,
            args.opset,
            args.tokenizer,
            args.use_external_format,
            args.pipeline,
        )

        if args.quantize:
            # 确保满足 quantization 在 onnxruntime 上的要求
            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)

            # 对于 TensorFlow 框架,性能优化不如 PyTorch 显著
            if args.framework == "tf":
                print(
                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
                    "\t For more information, please refer to the onnxruntime documentation:\n"
                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
                )

            print("\n====== Optimizing ONNX model ======")

            # Optimize the exported graph first; quantization below runs on the optimized model
            args.optimized_output = optimize(args.output)

            # 在正确的图上执行量化
            args.quantized_output = quantize(args.optimized_output)

        # 验证转换后的模型
        if args.check_loading:
            print("\n====== Check exported ONNX model(s) ======")
            verify(args.output)

            if hasattr(args, "optimized_output"):
                verify(args.optimized_output)

            if hasattr(args, "quantized_output"):
                verify(args.quantized_output)

    except Exception as e:
        # 捕获转换过程中的异常并打印错误消息
        print(f"Error while converting the model: {e}")
        exit(1)

.\convert_pytorch_checkpoint_to_tf2.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert pytorch checkpoints to TensorFlow"""


import argparse  # 导入解析命令行参数的模块
import os  # 导入操作系统功能的模块

from . import (  # 导入当前包中的模块和符号
    ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    BART_PRETRAINED_MODEL_ARCHIVE_LIST,
    BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
    DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
    DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
    ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
    LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
    TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
    WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
    XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    AlbertConfig,
    BartConfig,
    BertConfig,
    CamembertConfig,
    CTRLConfig,
    DistilBertConfig,
    DPRConfig,
    ElectraConfig,
    FlaubertConfig,
    GPT2Config,
    LayoutLMConfig,
    LxmertConfig,
    OpenAIGPTConfig,
    RobertaConfig,
    T5Config,
    TFAlbertForPreTraining,
    TFBartForConditionalGeneration,
    TFBartForSequenceClassification,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
    TFCamembertForMaskedLM,
    TFCTRLLMHeadModel,
    TFDistilBertForMaskedLM,
    TFDistilBertForQuestionAnswering,
    TFDPRContextEncoder,
    TFDPRQuestionEncoder,
    TFDPRReader,
    TFElectraForPreTraining,
    TFFlaubertWithLMHeadModel,
    TFGPT2LMHeadModel,
    TFLayoutLMForMaskedLM,
    TFLxmertForPreTraining,
    TFLxmertVisualFeatureEncoder,
    TFOpenAIGPTLMHeadModel,
    TFRobertaForCausalLM,
    TFRobertaForMaskedLM,
    TFRobertaForSequenceClassification,
    TFT5ForConditionalGeneration,
    TFTransfoXLLMHeadModel,
    TFWav2Vec2Model,
    TFXLMRobertaForMaskedLM,
    TFXLMWithLMHeadModel,
    TFXLNetLMHeadModel,
    TransfoXLConfig,
    Wav2Vec2Config,
    Wav2Vec2Model,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
    is_torch_available,
    load_pytorch_checkpoint_in_tf2_model,
)
# 从当前包的utils模块中导入所需的符号:CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging

# 如果torch可用,导入必要的模块:numpy和torch
if is_torch_available():
    import numpy as np
    import torch

    # 从当前包中导入多个模型类
    from . import (
        AlbertForPreTraining,
        BartForConditionalGeneration,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        CamembertForMaskedLM,
        CTRLLMHeadModel,
        DistilBertForMaskedLM,
        DistilBertForQuestionAnswering,
        DPRContextEncoder,
        DPRQuestionEncoder,
        DPRReader,
        ElectraForPreTraining,
        FlaubertWithLMHeadModel,
        GPT2LMHeadModel,
        LayoutLMForMaskedLM,
        LxmertForPreTraining,
        LxmertVisualFeatureEncoder,
        OpenAIGPTLMHeadModel,
        RobertaForMaskedLM,
        RobertaForSequenceClassification,
        T5ForConditionalGeneration,
        TransfoXLLMHeadModel,
        XLMRobertaForMaskedLM,
        XLMWithLMHeadModel,
        XLNetLMHeadModel,
    )

    # 从pytorch_utils模块中导入is_torch_greater_or_equal_than_1_13函数
    from .pytorch_utils import is_torch_greater_or_equal_than_1_13

# 设置日志记录的详细程度为INFO级别
logging.set_verbosity_info()

# 定义模型类的映射字典,键为模型名称,值为元组,包含相应模型的配置类、TF/PyTorch模型类、预训练模型类以及预训练模型的存档列表
MODEL_CLASSES = {
    "bart": (
        BartConfig,
        TFBartForConditionalGeneration,
        TFBartForSequenceClassification,
        BartForConditionalGeneration,
        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "bert": (
        BertConfig,
        TFBertForPreTraining,
        BertForPreTraining,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
        BertConfig,
        TFBertForQuestionAnswering,
        BertForQuestionAnswering,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "google-bert/bert-base-cased-finetuned-mrpc": (
        BertConfig,
        TFBertForSequenceClassification,
        BertForSequenceClassification,
        BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "dpr": (
        DPRConfig,
        TFDPRQuestionEncoder,
        TFDPRContextEncoder,
        TFDPRReader,
        DPRQuestionEncoder,
        DPRContextEncoder,
        DPRReader,
        DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
        DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
        DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
    ),
    "openai-community/gpt2": (
        GPT2Config,
        TFGPT2LMHeadModel,
        GPT2LMHeadModel,
        GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlnet": (
        XLNetConfig,
        TFXLNetLMHeadModel,
        XLNetLMHeadModel,
        XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlm": (
        XLMConfig,
        TFXLMWithLMHeadModel,
        XLMWithLMHeadModel,
        XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ),
    "xlm-roberta": (
        XLMRobertaConfig,                           # XLMRoberta 模型的配置类
        TFXLMRobertaForMaskedLM,                    # 用于 TensorFlow 的 XLMRoberta 语言模型(MLM)
        XLMRobertaForMaskedLM,                      # XLMRoberta 语言模型(MLM)
        XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,  # XLMRoberta 预训练模型配置文件的映射
    ),
    "transfo-xl": (
        TransfoXLConfig,                            # TransfoXL 模型的配置类
        TFTransfoXLLMHeadModel,                     # 用于 TensorFlow 的 TransfoXL 语言模型头部
        TransfoXLLMHeadModel,                       # TransfoXL 语言模型头部
        TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,   # TransfoXL 预训练模型配置文件的映射
    ),
    "openai-community/openai-gpt": (
        OpenAIGPTConfig,                            # OpenAI GPT 模型的配置类
        TFOpenAIGPTLMHeadModel,                     # 用于 TensorFlow 的 OpenAI GPT 语言模型头部
        OpenAIGPTLMHeadModel,                       # OpenAI GPT 语言模型头部
        OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,   # OpenAI GPT 预训练模型配置文件的映射
    ),
    "roberta": (
        RobertaConfig,                              # Roberta 模型的配置类
        TFRobertaForCausalLM,                       # 用于 TensorFlow 的 Roberta 因果语言模型
        TFRobertaForMaskedLM,                       # 用于 TensorFlow 的 Roberta 语言模型(MLM)
        RobertaForMaskedLM,                         # Roberta 语言模型(MLM)
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,       # Roberta 预训练模型配置文件的映射
    ),
    "layoutlm": (
        LayoutLMConfig,                             # LayoutLM 模型的配置类
        TFLayoutLMForMaskedLM,                      # 用于 TensorFlow 的 LayoutLM 语言模型(MLM)
        LayoutLMForMaskedLM,                        # LayoutLM 语言模型(MLM)
        LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,     # LayoutLM 预训练模型的存档列表
    ),
    "FacebookAI/roberta-large-mnli": (
        RobertaConfig,                              # Roberta 模型的配置类
        TFRobertaForSequenceClassification,          # 用于 TensorFlow 的 Roberta 序列分类模型
        RobertaForSequenceClassification,           # Roberta 序列分类模型
        ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,       # Roberta 预训练模型配置文件的映射
    ),
    "camembert": (
        CamembertConfig,                            # Camembert 模型的配置类
        TFCamembertForMaskedLM,                     # 用于 TensorFlow 的 Camembert 语言模型(MLM)
        CamembertForMaskedLM,                       # Camembert 语言模型(MLM)
        CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,     # Camembert 预训练模型配置文件的映射
    ),
    "flaubert": (
        FlaubertConfig,                             # Flaubert 模型的配置类
        TFFlaubertWithLMHeadModel,                  # 用于 TensorFlow 的 Flaubert 语言模型头部
        FlaubertWithLMHeadModel,                    # Flaubert 语言模型头部
        FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,      # Flaubert 预训练模型配置文件的映射
    ),
    "distilbert": (
        DistilBertConfig,                           # DistilBERT 模型的配置类
        TFDistilBertForMaskedLM,                    # 用于 TensorFlow 的 DistilBERT 语言模型(MLM)
        DistilBertForMaskedLM,                      # DistilBERT 语言模型(MLM)
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,    # DistilBERT 预训练模型配置文件的映射
    ),
    "distilbert-base-distilled-squad": (
        DistilBertConfig,                           # DistilBERT 模型的配置类
        TFDistilBertForQuestionAnswering,           # 用于 TensorFlow 的 DistilBERT 问答模型
        DistilBertForQuestionAnswering,             # DistilBERT 问答模型
        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,    # DistilBERT 预训练模型配置文件的映射
    ),
    "lxmert": (
        LxmertConfig,                               # LXMERT 模型的配置类
        TFLxmertForPreTraining,                     # 用于 TensorFlow 的 LXMERT 预训练模型
        LxmertForPreTraining,                       # LXMERT 预训练模型
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,        # LXMERT 预训练模型配置文件的映射
    ),
    "lxmert-visual-feature-encoder": (
        LxmertConfig,                               # LXMERT 模型的配置类
        TFLxmertVisualFeatureEncoder,               # 用于 TensorFlow 的 LXMERT 视觉特征编码器
        LxmertVisualFeatureEncoder,                 # LXMERT 视觉特征编码器
        LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,        # LXMERT 预训练模型配置文件的映射
    ),
    "Salesforce/ctrl": (
        CTRLConfig,                                 # CTRL 模型的配置类
        TFCTRLLMHeadModel,                          # 用于 TensorFlow 的 CTRL 语言模型头部
        CTRLLMHeadModel,                            # CTRL 语言模型头部
        CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,          # CTRL 预训练模型配置文件的映射
    ),
    "albert": (
        AlbertConfig,                               # ALBERT 模型的配置类
        TFAlbertForPreTraining,                     # 用于 TensorFlow 的 ALBERT 预训练模型
        AlbertForPreTraining,                       # ALBERT 预训练模型
        ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,        # ALBERT 预训练模型配置文件的映射
    ),
    "t5": (
        T5Config,                                   # T5 模型的配置类
        TFT5ForConditionalGeneration,               # 用于 TensorFlow 的 T5 条件生成模型
        T5ForConditionalGeneration,                 # T5 条件生成模型
        T5_PRETRAINED_CONFIG_ARCHIVE_MAP,            # T5 预训练模型配置文件的映射
    ),
    "electra": (
        ElectraConfig,                              # Electra 模型的配置类
        TFElectraForPreTraining,                    # 用于 TensorFlow 的 Electra 预训练模型
        ElectraForPreTraining,                      # Electra 预训练模型
        ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,       # Electra 预训练模型配置文件的映射
    ),
    "wav2vec2": (
        Wav2Vec2Config,                             # Wav2Vec2 模型的配置类
        TFWav2Vec2Model,                            # 用于 TensorFlow 的 Wav2Vec2 模型
        Wav2Vec2Model,                              # Wav2Vec2 模型
        WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,   # Wav2Vec2 预训练模型配置文件的映射
    ),
}

# 将 PyTorch 检查点转换为 TensorFlow 格式
def convert_pt_checkpoint_to_tf(
    model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
):
    # 检查模型类型是否在已知的模型类别中
    if model_type not in MODEL_CLASSES:
        raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")

    # 根据模型类型获取相应的类别信息
    config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

    # 初始化 TensorFlow 模型
    if config_file in aws_config_map:
        # 如果配置文件在 AWS 配置映射中,可能需要缓存或下载配置文件
        config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
    config = config_class.from_json_file(config_file)
    config.output_hidden_states = True
    config.output_attentions = True
    print(f"Building TensorFlow model from configuration: {config}")
    tf_model = model_class(config)

    # 从 PyTorch 检查点加载权重
    if pytorch_checkpoint_path in aws_config_map.keys():
        # 如果 PyTorch 检查点路径在 AWS 配置映射中,可能需要缓存或下载检查点文件
        pytorch_checkpoint_path = cached_file(
            pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
        )
    # 将 PyTorch 检查点加载到 TensorFlow 模型中
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)

    # 如果需要与 PyTorch 模型进行比较
    if compare_with_pt_model:
        # 以 dummy 输入前向一次:既完成 TensorFlow 网络的构建,也得到 TF 侧输出 tfo 供后续比较
        tfo = tf_model(tf_model.dummy_inputs, training=False)

        # 根据 torch 版本决定是否传入 weights_only 参数(1.13 及以上版本才支持)
        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        # 从 PyTorch 检查点加载状态字典
        state_dict = torch.load(
            pytorch_checkpoint_path,
            map_location="cpu",
            **weights_only_kwarg,
        )
        # 使用 PyTorch 模型类从预训练模型名称或路径加载模型
        pt_model = pt_model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
        )

        # 通过禁用梯度计算运行 PyTorch 模型
        with torch.no_grad():
            pto = pt_model(**pt_model.dummy_inputs)

        # 转换为 NumPy 数组并计算模型输出的最大绝对差异
        np_pt = pto[0].numpy()
        np_tf = tfo[0].numpy()
        diff = np.amax(np.abs(np_pt - np_tf))
        print(f"Max absolute difference between models outputs {diff}")
        # 断言模型的最大绝对差异是否在可接受范围内
        assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"

    # 保存 TensorFlow 模型的权重
    print(f"Save TensorFlow model to {tf_dump_path}")
    tf_model.save_weights(tf_dump_path, save_format="h5")
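
一个直接调用该函数的最小示例(路径与模型名均为假设,需按实际情况替换;"bert" 需是 MODEL_CLASSES 中的键):

# 假设性示例:把本地 BERT 的 PyTorch 检查点转换为 TensorFlow 的 h5 权重
convert_pt_checkpoint_to_tf(
    model_type="bert",
    pytorch_checkpoint_path="./bert/pytorch_model.bin",  # 假设的本地检查点
    config_file="./bert/config.json",                    # 假设的配置文件
    tf_dump_path="./bert/tf_model.h5",                   # 输出的 TF 权重文件
    compare_with_pt_model=True,                          # 同时比较两侧输出差异
)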


# 将所有 PyTorch 检查点转换为 TensorFlow 格式
def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    # 如果未提供模型类型参数,则使用所有已知模型类型
    if args_model_type is None:
        model_types = list(MODEL_CLASSES.keys())
    else:
        model_types = [args_model_type]
    # 对于每个模型类型,在循环中进行迭代,使用 enumerate 函数生成索引和元素
    for j, model_type in enumerate(model_types, start=1):
        # 打印分隔线,用于区分不同模型类型的输出
        print("=" * 100)
        # 打印当前转换的模型类型信息,包括总数和当前处理的序号
        print(f" Converting model type {j}/{len(model_types)}: {model_type}")
        # 打印分隔线,用于区分不同模型类型的输出
        print("=" * 100)
        
        # 如果当前模型类型不在预定义的模型类别中,则抛出数值错误异常
        if model_type not in MODEL_CLASSES:
            raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")

        # 从预定义的模型映射中获取配置类、模型类、PyTorch模型类以及AWS相关映射信息
        config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

        # 如果未提供模型路径或名称,则使用AWS模型映射中的名称列表作为默认值
        if model_shortcut_names_or_path is None:
            model_shortcut_names_or_path = list(aws_model_maps.keys())
        # 如果未提供配置路径或名称,则使用模型快捷名称列表作为默认值
        if config_shortcut_names_or_path is None:
            config_shortcut_names_or_path = model_shortcut_names_or_path

        # 对于每个模型快捷名称和配置快捷名称的组合,使用 zip 函数生成索引和元素
        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
            zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
        ):
            # 打印分隔线,用于区分不同模型转换过程的输出
            print("-" * 100)
            
            # 如果模型快捷名称中包含特定字符串(如"-squad"、"-mrpc"、"-mnli"),并且不是仅转换微调模型,则跳过当前模型的转换
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    # 打印跳过信息,指示未转换的微调模型
                    print(f"    Skipping finetuned checkpoint {model_shortcut_name}")
                    continue
                # 将模型类型设为当前模型的名称,用于后续转换过程
                model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                # 如果仅转换微调模型选项为真,则跳过非微调模型的转换
                print(f"    Skipping not finetuned checkpoint {model_shortcut_name}")
                continue
            
            # 打印当前转换的检查点信息,包括总数和当前处理的序号,以及模型快捷名称和模型类型
            print(f"    Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}")
            # 打印分隔线,用于区分不同模型转换过程的输出
            print("-" * 100)

            # 如果配置快捷名称存在于AWS配置映射中,则根据配置快捷名称下载配置文件
            if config_shortcut_name in aws_config_map:
                config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
            else:
                # 否则,将配置快捷名称作为配置文件名
                config_file = config_shortcut_name

            # 如果模型快捷名称存在于AWS模型映射中,则根据模型快捷名称下载模型权重文件
            if model_shortcut_name in aws_model_maps:
                model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
            else:
                # 否则,将模型快捷名称作为模型文件名
                model_file = model_shortcut_name

            # 如果模型快捷名称对应的文件已存在,则将模型快捷名称设为"converted_model"
            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            # 调用转换函数,将PyTorch模型检查点转换为TensorFlow模型
            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            # 如果设定移除缓存文件选项为真,则删除配置文件和模型文件
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # 创建参数解析器

    # 必选参数
    parser.add_argument(
        "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
    )
    # 输出TensorFlow转储文件的路径

    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        help=(
            f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
            "convert all the models from AWS."
        ),
    )
    # 模型类型,可以选择预定义的模型类别或者从AWS下载转换所有模型

    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=None,
        type=str,
        help=(
            "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
            "If not given, will download and convert all the checkpoints from AWS."
        ),
    )
    # PyTorch检查点文件路径或者从AWS下载的快捷名称

    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        help=(
            "The config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture. If not given and "
            "--pytorch_checkpoint_path is not given or is a shortcut name "
            "use the configuration associated to the shortcut name on the AWS"
        ),
    )
    # 预训练模型对应的配置文件,用于指定模型架构

    parser.add_argument(
        "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
    )
    # 比较TensorFlow和PyTorch模型预测结果

    parser.add_argument(
        "--use_cached_models",
        action="store_true",
        help="Use cached models if possible instead of updating to latest checkpoint versions.",
    )
    # 如果可能的话使用缓存的模型,而不是更新到最新的检查点版本

    parser.add_argument(
        "--remove_cached_files",
        action="store_true",
        help="Remove pytorch models after conversion (save memory when converting in batches).",
    )
    # 在转换完成后删除PyTorch模型文件,以节省内存(批量转换时)

    parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
    # 只转换微调过的模型

    args = parser.parse_args()

    # if args.pytorch_checkpoint_path is not None:
    #     convert_pt_checkpoint_to_tf(args.model_type.lower(),
    #                                 args.pytorch_checkpoint_path,
    #                                 args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
    #                                 args.tf_dump_path,
    #                                 compare_with_pt_model=args.compare_with_pt_model,
    #                                 use_cached_models=args.use_cached_models)
    # else:
    # 转换所有的 PyTorch 检查点到 TensorFlow 格式
    convert_all_pt_checkpoints_to_tf(
        # 将模型类型参数转换为小写,如果未提供则为 None
        args.model_type.lower() if args.model_type is not None else None,
        # TensorFlow 转换后的输出路径
        args.tf_dump_path,
        # 模型的快捷方式名称或路径的列表,如果提供了 PyTorch 检查点路径则作为单个元素传递,否则为 None
        model_shortcut_names_or_path=[args.pytorch_checkpoint_path] if args.pytorch_checkpoint_path is not None else None,
        # 配置文件的快捷方式名称或路径的列表,如果提供了配置文件路径则作为单个元素传递,否则为 None
        config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
        # 是否与 PyTorch 模型进行比较
        compare_with_pt_model=args.compare_with_pt_model,
        # 是否使用缓存的模型(如果可用)
        use_cached_models=args.use_cached_models,
        # 是否删除缓存的文件
        remove_cached_files=args.remove_cached_files,
        # 是否仅转换微调过的模型
        only_convert_finetuned_models=args.only_convert_finetuned_models,
    )
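
对应的命令行用法示意(脚本名与参数值均为假设,需按实际路径替换):

python convert_pytorch_checkpoint_to_tf2.py \
    --model_type bert \
    --pytorch_checkpoint_path ./bert/pytorch_model.bin \
    --config_file ./bert/config.json \
    --tf_dump_path ./bert/tf_model.h5 \
    --compare_with_pt_model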

.\convert_slow_tokenizer.py

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to convert slow tokenizers in their fast tokenizers counterparts.

All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
allow to make our dependency on SentencePiece optional.
"""

import warnings
from typing import Dict, List, Tuple

from packaging import version
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece

from .utils import is_protobuf_available, requires_backends
from .utils.import_utils import PROTOBUF_IMPORT_ERROR


def import_protobuf(error_message=""):
    # 检查是否可以导入 protobuf 库
    if is_protobuf_available():
        import google.protobuf

        # 如果 protobuf 版本低于 4.0.0,则使用旧版的 sentencepiece_model_pb2
        if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
            from transformers.utils import sentencepiece_model_pb2
        else:
            from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2

        # 返回 sentencepiece_model_pb2 模块
        return sentencepiece_model_pb2
    else:
        # 如果无法导入 protobuf,则抛出 ImportError 异常
        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))


class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # 检查是否已经导入 sentencepiece 库,如果没有则引发异常
        requires_backends(self, "sentencepiece")
        # 导入 SentencePieceProcessor 类
        from sentencepiece import SentencePieceProcessor

        # 创建 SentencePieceProcessor 实例并加载模型
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)
    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
        order the merges with respect to the piece scores instead.
        """
        # 获取 SentencePiece 对象的实例
        sp = self.sp
        # 创建一个字典,将每个索引映射到对应的 Piece(词汇)
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        # 根据传入的 vocab_scores 是否为 None 来决定使用哪种排序方式
        if vocab_scores is not None:
            # 如果 vocab_scores 不为 None,则将其转换为字典并设置 reverse 为 True
            vocab_scores, reverse = dict(vocab_scores), True
        else:
            # 如果 vocab_scores 为 None,则使用默认的 vocab,并将 reverse 设置为 False
            vocab_scores, reverse = vocab, False

        # Merges(合并操作)
        merges = []
        # 遍历 vocab_scores 中的每个 merge 和对应的 piece_score
        for merge, piece_score in vocab_scores.items():
            local = []
            # 将 merge 分解为 piece_l 和 piece_r 的组合,并检查其在 vocab 中是否存在
            for index in range(1, len(merge)):
                piece_l, piece_r = merge[:index], merge[index:]
                if piece_l in vocab and piece_r in vocab:
                    local.append((piece_l, piece_r, piece_score))
            # 对 local 按照 vocab 中 piece_l 和 piece_r 的索引排序
            local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
            # 将 local 的内容扩展到 merges 列表中
            merges.extend(local)

        # 按照 piece_score 进行降序排序,并转换为 (piece_l, piece_r) 的形式
        merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
        merges = [(val[0], val[1]) for val in merges]
        # 返回 vocab 和 merges
        return vocab, merges
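
用一个纯 Python 的玩具例子说明 merges 是如何从词表推导出来的(示意,不依赖 sentencepiece):

# 玩具词表:token -> id
vocab = {"a": 0, "b": 1, "c": 2, "ab": 3, "abc": 4}

merges = []
for merge in vocab:                        # 对每个词条枚举所有二切分
    for i in range(1, len(merge)):
        left, right = merge[:i], merge[i:]
        if left in vocab and right in vocab:
            merges.append((left, right))

print(merges)  # [('a', 'b'), ('ab', 'c')]:"ab" 可由 a+b 合成,"abc" 可由 ab+c 合成
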
class GemmaSentencePieceExtractor(SentencePieceExtractor):
    # GemmaSentencePieceExtractor 类继承自 SentencePieceExtractor,用于实现定制的 SentencePiece 提取器

    def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
        order the merges with respect to the piece scores instead.
        """
        # extract 方法用于从 SentencePiece 模型中提取词汇表和合并列表
        sp = self.sp  # 获取 SentencePiece 对象
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
        # 根据索引从 SentencePiece 对象中获取词汇,并创建词汇到索引的映射字典

        # 词表中缺少 "\t" 词条:"<0x09>" 正是 "\t" 的字节回退(byte fallback)表示,
        # 这里把该词条的键改名为 "\t",以便后续的合并(merges)推导能够命中
        vocab["\t"] = vocab.pop("<0x09>")
        
        if vocab_scores is not None:
            vocab_scores, reverse = dict(vocab_scores), True
        else:
            vocab_scores, reverse = vocab, False

        # Merges
        merges = []
        for merge, piece_score in vocab_scores.items():
            local = []
            for index in range(1, len(merge)):
                piece_l, piece_r = merge[:index], merge[index:]
                if piece_l in vocab and piece_r in vocab:
                    local.append((piece_l, piece_r, piece_score))
            local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
            merges.extend(local)

        merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
        merges = [(val[0], val[1]) for val in merges]
        return vocab, merges
        # 返回提取后的词汇表和合并列表


def check_number_comma(piece: str) -> bool:
    # check_number_comma 用于判断 piece 是否以「数字后紧跟逗号」结尾:
    # 仅当结尾是逗号且倒数第二个字符是数字(如 "12,")时返回 False,其余情况返回 True
    return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
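
几个调用示例(示意):

print(check_number_comma("12,"))     # False:倒数第二个字符是数字且结尾是逗号
print(check_number_comma("hello,"))  # True:结尾是逗号但倒数第二个字符不是数字
print(check_number_comma("12"))      # True:结尾不是逗号
print(check_number_comma(","))       # True:长度小于 2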


class Converter:
    # Converter 类用作基类,包含一些基本结构但没有实现具体的转换逻辑

    def __init__(self, original_tokenizer):
        # 初始化方法,接受一个原始的 tokenizer 并存储在实例变量中
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        # converted 方法声明但未实现,用于子类重写实现具体的转换逻辑
        raise NotImplementedError()


class BertConverter(Converter):
    # BertConverter 类继承自 Converter 类,用于实现针对 Bert 模型的具体转换逻辑
    # 定义一个方法,用于将原始的 tokenizer 转换为新的 Tokenizer 对象,并返回
    def converted(self) -> Tokenizer:
        # 获取原始 tokenizer 的词汇表
        vocab = self.original_tokenizer.vocab
        # 使用 WordPiece 模型和未知标记符初始化新的 Tokenizer 对象
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # 初始化用于标准化文本的参数,默认为 False
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        
        # 如果原始 tokenizer 中包含 basic_tokenizer 属性,则从中获取相关参数值
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        # 设置新 Tokenizer 对象的标准化器为 BertNormalizer,使用指定的参数
        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        
        # 设置新 Tokenizer 对象的预处理器为 BertPreTokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        # 从原始 tokenizer 中获取特殊标记的字符串表示和标记 ID
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        # 设置新 Tokenizer 对象的后处理器为 TemplateProcessing,指定单句和双句处理模板及特殊标记
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )

        # 设置新 Tokenizer 对象的解码器为 WordPiece 解码器,前缀为 '##'
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        # 返回配置好的新 Tokenizer 对象
        return tokenizer
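
把慢速 BERT 分词器转换为 tokenizers 后端的示意用法(需要能访问预训练权重,模型名仅作示例):

from transformers import BertTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = BertTokenizer.from_pretrained("bert-base-uncased")  # 示例模型名
fast_backend = convert_slow_tokenizer(slow)                # 返回 tokenizers.Tokenizer 实例
print(fast_backend.encode("Hello world").tokens)
# 预期形如:['[CLS]', 'hello', 'world', '[SEP]']
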
# 定义一个名为 SplinterConverter 的类,它继承自 Converter 类
class SplinterConverter(Converter):
    
    # 重写父类方法 converted,返回一个 Tokenizer 对象
    def converted(self) -> Tokenizer:
        
        # 获取原始分词器的词汇表
        vocab = self.original_tokenizer.vocab
        
        # 使用 WordPiece 模型和未知标记初始化 Tokenizer 对象
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
        
        # 初始化变量用于存储是否分词中文字符、是否去除重音符号、是否小写化的标志
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        
        # 如果原始分词器具有 basic_tokenizer 属性,获取其相应的属性值
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
        
        # 设置 Tokenizer 的 normalizer 为 BertNormalizer 对象,配置各项参数
        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        
        # 设置 Tokenizer 的 pre_tokenizer 为 BertPreTokenizer 对象
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        
        # 获取特殊标记(如 CLS、SEP、QUESTION 和 DOT)的字符串形式和对应的标记 ID
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        question = str(self.original_tokenizer.question_token)
        dot = "."
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id
        question_token_id = self.original_tokenizer.question_token_id
        
        # 使用原始分词器将 DOT 转换为其对应的标记 ID
        dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
        
        # 根据原始分词器的填充位置确定 pair 的模板字符串
        if self.original_tokenizer.padding_side == "right":
            pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
        else:
            pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
        
        # 设置 Tokenizer 的 post_processor 为 TemplateProcessing 对象,配置各项参数
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=pair,
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
                (question, question_token_id),
                (dot, dot_token_id),
            ],
        )
        
        # 设置 Tokenizer 的 decoder 为 WordPiece 对象,配置前缀为 "##"
        tokenizer.decoder = decoders.WordPiece(prefix="##")
        
        # 返回配置好的 Tokenizer 对象
        return tokenizer
# 定义一个名为 FunnelConverter 的类,继承自 Converter 类(Funnel Transformer 专用,见下文 token_type_id 为 2 的模板)
class FunnelConverter(Converter):
    # 定义一个方法,用于将原始的分词器转换为新的 Tokenizer 对象
    def converted(self) -> Tokenizer:
        # 获取原始分词器的词汇表
        vocab = self.original_tokenizer.vocab
        # 使用 WordPiece 模型和未知标记来初始化 Tokenizer 对象
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # 初始化用于标准化文本的参数
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        # 检查原始分词器是否有基本分词器属性,设置标志位
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        # 设置 Tokenizer 对象的标准化器为 BertNormalizer,配置参数包括是否清洗文本、处理中文字符、去除重音、小写化
        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        # 设置 Tokenizer 对象的预处理器为 BertPreTokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        # 获取原始分词器的特殊标记(例如 [CLS] 和 [SEP])
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        # 设置 Tokenizer 对象的后处理器为 TemplateProcessing,配置单句和双句模板及特殊标记
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:2 $A:0 {sep}:0",  # token_type_id is 2 for Funnel transformer
            pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        # 设置 Tokenizer 对象的解码器为 WordPiece 解码器,前缀为 "##"
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        # 返回配置完成的 Tokenizer 对象
        return tokenizer
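
TemplateProcessing 模板里冒号后的数字就是 token_type_id,Funnel 把 [CLS] 标成 2。下面用一个玩具词表单独演示(示意):

from tokenizers import Tokenizer, processors
from tokenizers.models import WordPiece

toy_vocab = {"[CLS]": 0, "[SEP]": 1, "hello": 2, "[UNK]": 3}
tok = Tokenizer(WordPiece(toy_vocab, unk_token="[UNK]"))
tok.post_processor = processors.TemplateProcessing(
    single="[CLS]:2 $A:0 [SEP]:0",
    special_tokens=[("[CLS]", 0), ("[SEP]", 1)],
)
enc = tok.encode("hello")
print(enc.tokens)    # ['[CLS]', 'hello', '[SEP]']
print(enc.type_ids)  # [2, 0, 0]:[CLS] 的 token_type_id 为 2
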
class MPNetConverter(Converter):
    # MPNetConverter 类继承自 Converter 类,用于将原始 tokenizer 转换为 Tokenizer 对象
    def converted(self) -> Tokenizer:
        # 获取原始 tokenizer 的词汇表
        vocab = self.original_tokenizer.vocab
        # 创建一个 Tokenizer 对象,使用 WordPiece 模型,设置未知标记为原始 tokenizer 的未知标记
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # 初始化一些变量用于记录是否执行特定的文本清洗和处理操作
        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = False
        
        # 检查原始 tokenizer 是否具有 basic_tokenizer 属性,如果有,则更新相关变量
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        # 设置 tokenizer 的文本清洗器为 BertNormalizer,配置其参数
        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        # 设置 tokenizer 的预处理器为 BertPreTokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        # 获取特殊标记的字符串形式
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        # 设置 tokenizer 的后处理器为 TemplateProcessing,配置其参数
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",  # MPNet 使用两个 [SEP] 标记
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        # 设置 tokenizer 的解码器为 WordPiece 解码器,前缀为 "##"
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        # 返回转换后的 Tokenizer 对象
        return tokenizer


class OpenAIGPTConverter(Converter):
    # OpenAIGPTConverter 类继承自 Converter 类,用于将原始 tokenizer 转换为 Tokenizer 对象
    def converted(self) -> Tokenizer:
        # 获取原始 tokenizer 的编码器和 BPE 合并列表
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        unk_token = self.original_tokenizer.unk_token

        # 创建一个 Tokenizer 对象,使用 BPE 模型,设置未知标记为原始 tokenizer 的未知标记
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=str(unk_token),
                end_of_word_suffix="</w>",
                fuse_unk=False,
            )
        )

        # 如果 tokenizer 中已经包含原始 tokenizer 的未知标记,则添加到特殊标记列表中
        if tokenizer.token_to_id(str(unk_token)) is not None:
            tokenizer.add_special_tokens([str(unk_token)])

        # 设置 tokenizer 的文本清洗器为 BertNormalizer,只设置小写处理
        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
        # 设置 tokenizer 的预处理器为 BertPreTokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        # 设置 tokenizer 的解码器为 BPEDecoder,后缀为 "</w>"
        tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")

        # 返回转换后的 Tokenizer 对象
        return tokenizer


class GPT2Converter(Converter):
    # GPT2Converter 类继承自 Converter 类,用于将原始 tokenizer 转换为 Tokenizer 对象
    # 定义一个方法 converted,返回类型为 Tokenizer
    def converted(self) -> Tokenizer:
        # 获取原始分词器的词汇表
        vocab = self.original_tokenizer.encoder
        # 获取原始分词器的 BPE 合并列表
        merges = list(self.original_tokenizer.bpe_ranks.keys())

        # 创建一个新的 Tokenizer 对象,使用 BPE 分词器
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,  # 设置词汇表
                merges=merges,  # 设置合并列表
                dropout=None,  # 没有 dropout
                continuing_subword_prefix="",  # 子词前缀为空字符串
                end_of_word_suffix="",  # 词尾后缀为空字符串
                fuse_unk=False,  # 不融合未知标记
            )
        )

        # 设置 Tokenizer 的预分词器为 ByteLevel,并根据原始分词器设置是否添加前缀空格
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
        # 设置 Tokenizer 的解码器为 ByteLevel
        tokenizer.decoder = decoders.ByteLevel()

        # 如果原始分词器会添加起始标记(add_bos_token 为 True),则设置后处理器为 TemplateProcessing
        if self.original_tokenizer.add_bos_token:
            bos = self.original_tokenizer.bos_token
            bos_token_id = self.original_tokenizer.bos_token_id
            tokenizer.post_processor = processors.TemplateProcessing(
                single=f"{bos}:0 $A:0",  # 单句模板,以 bos 开始
                pair=f"{bos}:0 $A:0 $B:1",  # 双句模板,以 bos 开始
                special_tokens=[  # 特殊标记列表,包括 bos 和其对应的 id
                    (bos, bos_token_id),
                ],
            )
        else:
            # 如果没有设置开始词头,设置后处理器为 ByteLevel,trim_offsets=False 表示不修剪偏移量
            tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
        
        # 返回创建的 Tokenizer 对象
        return tokenizer
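
ByteLevel 预分词的效果可以这样单独观察(示意,输出形式以实际版本为准):

from tokenizers import pre_tokenizers

pt = pre_tokenizers.ByteLevel(add_prefix_space=False)
print(pt.pre_tokenize_str("Hello world"))
# 形如 [('Hello', (0, 5)), ('Ġworld', (5, 11))]:空格被编码进下一个词片的首字节 'Ġ'
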
# HerbertConverter 类,继承自 Converter 类
class HerbertConverter(Converter):
    # 覆盖了 converted 方法,返回一个 Tokenizer 对象
    def converted(self) -> Tokenizer:
        # tokenizer_info_str 字符串,用于版本信息
        tokenizer_info_str = "#version:"
        # token_suffix 字符串,用于表示词尾
        token_suffix = "</w>"

        # 获取原始分词器的编码器字典
        vocab = self.original_tokenizer.encoder
        # 获取原始分词器的 BPE 合并操作列表
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        
        # 如果 merges 的第一个元素包含 tokenizer_info_str,则从 merges 中移除该元素
        if tokenizer_info_str in merges[0][0]:
            merges = merges[1:]

        # 创建 Tokenizer 对象,使用 BPE 分词器
        tokenizer = Tokenizer(
            BPE(
                vocab,
                merges,
                dropout=None,
                unk_token=self.original_tokenizer.unk_token,
                end_of_word_suffix=token_suffix,
            )
        )

        # 设置 Tokenizer 对象的正规化器为 BertNormalizer
        tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        # 设置 Tokenizer 对象的预处理器为 BertPreTokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        # 设置 Tokenizer 对象的解码器为 BPEDecoder,指定词尾后缀
        tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
        # 设置 Tokenizer 对象的后处理器为 BertProcessing,指定特殊标记
        tokenizer.post_processor = processors.BertProcessing(
            sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
            cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
        )

        # 返回创建的 Tokenizer 对象
        return tokenizer


# Qwen2Converter 类,继承自 Converter 类
class Qwen2Converter(Converter):
    # 覆盖了 converted 方法,返回一个 Tokenizer 对象
    def converted(self) -> Tokenizer:
        # 获取原始分词器的编码器字典
        vocab = self.original_tokenizer.encoder
        # 获取原始分词器的 BPE 合并操作列表
        merges = list(self.original_tokenizer.bpe_ranks.keys())

        # 创建 Tokenizer 对象,使用 BPE 分词器
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                byte_fallback=False,
            )
        )

        # 设置 Tokenizer 对象的正规化器为 NFC(Unicode 标准化)
        tokenizer.normalizer = normalizers.NFC()

        # 设置 Tokenizer 对象的预处理器为 Sequence,包含两个预处理步骤
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                # 第一个预处理步骤:使用正则表达式拆分,匹配单词和数字,以及特定标点
                pre_tokenizers.Split(
                    Regex(
                        r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
                    ),
                    behavior="isolated",
                    invert=False,
                ),
                # 第二个预处理步骤:使用 ByteLevel 拆分字节级别的处理
                pre_tokenizers.ByteLevel(
                    add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
                    use_regex=False,
                ),
            ]
        )

        # 设置 Tokenizer 对象的解码器为 ByteLevel
        tokenizer.decoder = decoders.ByteLevel()
        # 设置 Tokenizer 对象的后处理器为 ByteLevel,不修剪偏移量
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        # 返回创建的 Tokenizer 对象
        return tokenizer
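
Qwen2 第一步预分词依赖正则 Split,其行为可以用一个更简单的正则单独体会(示意):

from tokenizers import Regex, pre_tokenizers

split = pre_tokenizers.Split(Regex(r"\p{N}+"), behavior="isolated", invert=False)
print(split.pre_tokenize_str("abc123def"))
# [('abc', (0, 3)), ('123', (3, 6)), ('def', (6, 9))]:数字串被切成独立片段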


# RobertaConverter 类,继承自 Converter 类
class RobertaConverter(Converter):
    # 定义一个方法 `converted`,返回一个 Tokenizer 对象
    def converted(self) -> Tokenizer:
        # 获取原始的分词器对象
        ot = self.original_tokenizer
        # 获取原始分词器的词汇表
        vocab = ot.encoder
        # 获取原始分词器的合并列表
        merges = list(ot.bpe_ranks.keys())

        # 创建一个新的 Tokenizer 对象,使用 BPE 分词器
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,  # 设置词汇表
                merges=merges,  # 设置合并列表
                dropout=None,  # 不使用 dropout
                continuing_subword_prefix="",  # 设置持续子词前缀为空字符串
                end_of_word_suffix="",  # 设置词尾后缀为空字符串
                fuse_unk=False,  # 不融合未知标记
            )
        )

        # 设置 Tokenizer 的预分词器为 ByteLevel,并保留原始分词器的前缀空格设置
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        # 设置 Tokenizer 的解码器为 ByteLevel
        tokenizer.decoder = decoders.ByteLevel()
        # 设置 Tokenizer 的后处理器为 RobertaProcessing,设置分隔符、CLS和SEP标记及其ID,同时修剪偏移量
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(ot.sep_token, ot.sep_token_id),  # 设置分隔符及其ID
            cls=(ot.cls_token, ot.cls_token_id),  # 设置CLS标记及其ID
            add_prefix_space=ot.add_prefix_space,  # 保留原始分词器的前缀空格设置
            trim_offsets=True,  # 在Roberta上默认为True(历史遗留)
        )

        # 返回配置好的 Tokenizer 对象
        return tokenizer
class RoFormerConverter(Converter):
    # RoFormerConverter 类继承自 Converter 类,用于转换器功能
    def converted(self) -> Tokenizer:
        # 返回类型为 Tokenizer 的 converted 方法
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        # 导入 RoFormer 所需的 JiebaPreTokenizer

        # 获取原始分词器的词汇表
        vocab = self.original_tokenizer.vocab
        # 创建一个 Tokenizer 实例,使用 WordPiece 方法和未知标记
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        # 初始化 strip_accents 和 do_lower_case 为 False
        strip_accents = False
        do_lower_case = False
        # 如果原始分词器具有 basic_tokenizer 属性
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            # 获取 strip_accents 和 do_lower_case 的设置
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        # 设置 tokenizer 的 normalizer 为 BertNormalizer 实例
        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )
        # 设置 tokenizer 的 pre_tokenizer 为 JiebaPreTokenizer 实例
        tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        # 获取 cls 和 sep 的字符串表示
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        # 获取 cls_token_id 和 sep_token_id
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        # 设置 tokenizer 的 post_processor 为 TemplateProcessing 实例
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )
        # 设置 tokenizer 的 decoder 为 WordPiece 实例,前缀为 "##"
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        # 返回设置好的 tokenizer 实例
        return tokenizer
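
pre_tokenizers.PreTokenizer.custom 允许用任意实现了 pre_tokenize(self, pretok) 方法的 Python 对象充当预分词器;真实的 JiebaPreTokenizer 在切分函数里调用 rjieba,下面换成一个极简的按空格切分版本作示意:

from tokenizers import Tokenizer, pre_tokenizers, NormalizedString, PreTokenizedString
from tokenizers.models import WordPiece

class SpacePreTokenizer:
    def split_fn(self, i: int, normalized: NormalizedString):
        # 手工按空格切分,返回 NormalizedString 切片的列表
        text = str(normalized)
        splits, start = [], 0
        for j, ch in enumerate(text):
            if ch == " ":
                if j > start:
                    splits.append(normalized[start:j])
                start = j + 1
        if start < len(text):
            splits.append(normalized[start:])
        return splits

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.split_fn)

tok = Tokenizer(WordPiece({"hello": 0, "world": 1, "[UNK]": 2}, unk_token="[UNK]"))
tok.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(SpacePreTokenizer())
print(tok.encode("hello world").tokens)  # ['hello', 'world']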


class DebertaConverter(Converter):
    # DebertaConverter 类继承自 Converter 类,用于转换器功能
    def converted(self) -> Tokenizer:
        # 返回类型为 Tokenizer 的 converted 方法
        ot = self.original_tokenizer
        # 获取原始分词器的 encoder 和 bpe_ranks

        # 创建一个 Tokenizer 实例,使用 BPE 方法和给定的参数
        tokenizer = Tokenizer(
            BPE(
                vocab=ot.encoder,
                merges=list(ot.bpe_ranks.keys()),
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        # 设置 tokenizer 的 pre_tokenizer 为 ByteLevel 实例
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        # 设置 tokenizer 的 decoder 为 ByteLevel 实例
        tokenizer.decoder = decoders.ByteLevel()
        # 设置 tokenizer 的 post_processor 为 TemplateProcessing 实例
        tokenizer.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )

        # 返回设置好的 tokenizer 实例
        return tokenizer


class SpmConverter(Converter):
    # SpmConverter 类继承自 Converter 类,用于转换器功能
    # 初始化方法,接受任意数量参数
    def __init__(self, *args):
        # 检查是否需要后端支持protobuf,如果不支持则抛出异常
        requires_backends(self, "protobuf")

        # 调用父类的初始化方法,传入所有参数
        super().__init__(*args)

        # 导入protobuf模型,此处调用import_protobuf函数,返回的模型对象赋值给model_pb2
        model_pb2 = import_protobuf()

        # 创建一个新的 ModelProto 对象 m,并从 self.original_tokenizer.vocab_file 文件中解析数据到 m 对象
        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())

        # 将解析后的m对象赋值给self.proto
        self.proto = m

        # 如果self.proto.trainer_spec.byte_fallback为True,则进行以下处理
        if self.proto.trainer_spec.byte_fallback:
            # 如果没有定义handle_byte_fallback属性,则发出警告
            if not getattr(self, "handle_byte_fallback", None):
                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

    # 返回proto中pieces属性的列表,每个元素为(piece.piece, piece.score)元组
    def vocab(self, proto):
        return [(piece.piece, piece.score) for piece in proto.pieces]

    # 返回proto中trainer_spec属性的unk_id
    def unk_id(self, proto):
        return proto.trainer_spec.unk_id

    # 根据proto的trainer_spec.model_type选择合适的Tokenizer类型,并返回对应的实例
    def tokenizer(self, proto):
        # 获取model_type值
        model_type = proto.trainer_spec.model_type
        # 获取vocab信息
        vocab_scores = self.vocab(proto)
        # 获取unk_id信息
        unk_id = self.unk_id(proto)

        # 根据model_type的值选择合适的Tokenizer类型
        if model_type == 1:
            # 使用Unigram模型创建Tokenizer实例
            tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
        elif model_type == 2:
            # 从self.original_tokenizer.vocab_file中提取_, merges变量
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            # 创建BPE类型的Tokenizer实例,使用给定的参数
            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                )
            )
        else:
            # 如果model_type不是1或2,则抛出异常
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        # 返回创建的Tokenizer实例
        return tokenizer

    # 根据proto的normalizer_spec属性返回合适的normalizer对象
    def normalizer(self, proto):
        # 获取precompiled_charsmap信息
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # 定义_normalizers列表,包含两种normalizers对象
        _normalizers = [
            normalizers.Strip(left=False, right=True),  # 仅去除右侧空白,左侧保留
            normalizers.Replace(Regex(" {2,}"), "▁"),  # 替换多个空格为特殊字符"▁"
        ]
        # 如果没有预编译字符映射,则返回一个Sequence对象,包含_normalizers中的内容
        if not precompiled_charsmap:
            return normalizers.Sequence(_normalizers)
        else:
            # 否则返回一个Sequence对象,包含precompiled_charsmap映射后的内容和_normalizers
            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)

    # 根据replacement和add_prefix_space创建并返回pre_tokenizers.Metaspace对象
    def pre_tokenizer(self, replacement, add_prefix_space):
        # 初始化prepend_scheme为"always"
        prepend_scheme = "always"
        # 如果self.original_tokenizer存在legacy属性且为False,则设置prepend_scheme为"first"
        if hasattr(self.original_tokenizer, "legacy") and not self.original_tokenizer.legacy:
            prepend_scheme = "first"
        # 返回一个Metaspace对象,使用给定的参数
        return pre_tokenizers.Metaspace(
            replacement=replacement, add_prefix_space=add_prefix_space, prepend_scheme=prepend_scheme
        )
    # 定义一个方法 `post_processor`,返回 `None`
    def post_processor(self):
        return None

    # 定义一个方法 `decoder`,接受 `replacement` 和 `add_prefix_space` 两个参数,返回一个 `decoders.Metaspace` 对象
    def decoder(self, replacement, add_prefix_space):
        return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)

    # 定义一个方法 `converted`,返回一个 `Tokenizer` 对象
    def converted(self) -> Tokenizer:
        # 使用 `self.tokenizer` 类型创建一个 `tokenizer` 对象,使用 `self.proto` 作为参数
        tokenizer = self.tokenizer(self.proto)

        # Tokenizer 组装过程
        # 使用 `self.normalizer` 类型创建一个 `normalizer` 对象,使用 `self.proto` 作为参数
        normalizer = self.normalizer(self.proto)
        if normalizer is not None:
            tokenizer.normalizer = normalizer

        # 设置 `replacement` 和 `add_prefix_space` 的默认值
        replacement = "▁"
        add_prefix_space = True

        # 检查 `self.original_tokenizer` 是否有 `add_prefix_space` 属性,更新 `add_prefix_space` 变量
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            add_prefix_space = self.original_tokenizer.add_prefix_space

        # 使用 `self.pre_tokenizer` 类型创建一个 `pre_tokenizer` 对象,使用 `replacement` 和 `add_prefix_space` 作为参数
        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        if pre_tokenizer is not None:
            tokenizer.pre_tokenizer = pre_tokenizer

        # 使用 `self.decoder` 方法创建一个 `decoder` 对象,使用 `replacement` 和 `add_prefix_space` 作为参数
        tokenizer.decoder = self.decoder(replacement, add_prefix_space)

        # 调用 `self.post_processor` 方法获取 `post_processor` 对象
        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        # 返回最终组装好的 `tokenizer` 对象
        return tokenizer
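
SpmConverter 家族的转换入口同样是 convert_slow_tokenizer,以基于 sentencepiece 的 T5 为例(需安装 sentencepiece 与 protobuf,模型名仅作示例):

from transformers import T5Tokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = T5Tokenizer.from_pretrained("t5-small")  # 基于 sentencepiece 的慢速分词器
fast_backend = convert_slow_tokenizer(slow)     # 内部会走 SpmConverter 的子类
print(fast_backend.encode("Hello world").tokens)
# 预期形如:['▁Hello', '▁world', '</s>']
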
# AlbertConverter 类,继承自 SpmConverter 类
class AlbertConverter(SpmConverter):
    
    # 重写 vocab 方法,返回一个包含单词片段和分数的列表
    def vocab(self, proto):
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    # 重写 normalizer 方法,返回一个正则化序列对象
    def normalizer(self, proto):
        # 列出要应用的正则化器列表
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        # 如果不保留重音符号,添加相应的正则化器
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        # 如果执行小写化,添加小写化正则化器
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())
        
        # 获取预编译字符映射表
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        
        # 如果存在预编译字符映射表,添加预编译正则化器
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
        
        # 添加空格合并的正则化器
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        
        # 返回正则化序列对象
        return normalizers.Sequence(list_normalizers)

    # 重写 post_processor 方法,返回一个模板处理对象
    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )
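
Albert 这组 normalizer 的叠加效果可以单独验证(示意):

from tokenizers import Regex, normalizers

norm = normalizers.Sequence([
    normalizers.Replace("``", '"'),
    normalizers.Replace("''", '"'),
    normalizers.Lowercase(),
    normalizers.Replace(Regex(" {2,}"), " "),
])
print(norm.normalize_str("``Hello''   WORLD"))  # 输出:"hello" world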


# BarthezConverter 类,继承自 SpmConverter 类
class BarthezConverter(SpmConverter):
    
    # 重写 unk_id 方法,返回未知标记的 ID
    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    # 重写 post_processor 方法,返回一个模板处理对象
    def post_processor(self):
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


# CamembertConverter 类,继承自 SpmConverter 类
class CamembertConverter(SpmConverter):
    
    # 重写 vocab 方法,返回一个词汇表,包含词汇和分数的元组列表
    def vocab(self, proto):
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            ("<unk>NOTUSED", -100),
        ]
        # 将 proto.pieces 中的片段和分数添加到词汇表中
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        # 添加 "<mask>" 到词汇表中
        vocab += [("<mask>", 0.0)]
        return vocab

    # 重写 unk_id 方法,返回未知标记的 ID
    def unk_id(self, proto):
        # 见 vocab 方法中的 unk 位置
        return 3

    # 重写 post_processor 方法,返回一个模板处理对象
    def post_processor(self):
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


# DebertaV2Converter 类,继承自 SpmConverter 类
class DebertaV2Converter(SpmConverter):
    # 定义一个预处理器函数,用于生成一个包含预处理器序列的对象
    def pre_tokenizer(self, replacement, add_prefix_space):
        # 初始化一个空列表,用于存储预处理器对象
        list_pretokenizers = []
        # 如果原始分词器支持按标点符号切分,则添加一个按独立标点切分的预处理器
        if self.original_tokenizer.split_by_punct:
            list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
        # 添加一个 Metaspace 预处理器,用于处理元空间
        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
        # 返回一个预处理器序列对象,其中包含以上构建的预处理器列表
        return pre_tokenizers.Sequence(list_pretokenizers)

    # 定义一个正则化器函数,用于生成一个包含正则化器序列的对象
    def normalizer(self, proto):
        # 初始化一个空列表,用于存储正则化器对象
        list_normalizers = []
        # 如果原始分词器需要进行小写处理,则添加一个小写化正则化器
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())
        # 添加一个去除空格的正则化器
        list_normalizers.append(normalizers.Strip())

        # 获取预编译字符映射表
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # 如果存在预编译字符映射表,则添加一个预编译正则化器
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
        # 添加一个替换连续空格为单个空格的正则化器
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))

        # 返回一个正则化器序列对象,其中包含以上构建的正则化器列表
        return normalizers.Sequence(list_normalizers)

    # 定义一个后处理器函数,用于生成一个模板处理器对象
    def post_processor(self):
        return processors.TemplateProcessing(
            # 单文本处理模板,用特定标记替换各个部分
            single="[CLS]:0 $A:0 [SEP]:0",
            # 双文本处理模板,用特定标记替换各个部分,包括两个分隔符
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            # 特殊标记的映射,将特殊标记与其在原始分词器中对应的 ID 关联起来
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )
# 定义一个名为 MBartConverter 的类,继承自 SpmConverter 类
class MBartConverter(SpmConverter):
    
    # 定义一个名为 vocab 的方法,接收 proto 参数,返回一个词汇表列表
    def vocab(self, proto):
        # 初始化词汇表,包括常见特殊 token 和初始权重
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        # 将 proto 对象中的子词片段(从第四个开始)添加到词汇表中
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        # 添加特定语言标识符和对应的初始权重
        vocab += [
            ("ar_AR", 0.0),
            ("cs_CZ", 0.0),
            ("de_DE", 0.0),
            ("en_XX", 0.0),
            ("es_XX", 0.0),
            ("et_EE", 0.0),
            ("fi_FI", 0.0),
            ("fr_XX", 0.0),
            ("gu_IN", 0.0),
            ("hi_IN", 0.0),
            ("it_IT", 0.0),
            ("ja_XX", 0.0),
            ("kk_KZ", 0.0),
            ("ko_KR", 0.0),
            ("lt_LT", 0.0),
            ("lv_LV", 0.0),
            ("my_MM", 0.0),
            ("ne_NP", 0.0),
            ("nl_XX", 0.0),
            ("ro_RO", 0.0),
            ("ru_RU", 0.0),
            ("si_LK", 0.0),
            ("tr_TR", 0.0),
            ("vi_VN", 0.0),
            ("zh_CN", 0.0),
        ]
        # 添加一个特殊的 mask 标识符和初始权重
        vocab += [("<mask>", 0.0)]
        # 返回完整的词汇表
        return vocab
    
    # 定义一个名为 unk_id 的方法,接收 proto 参数,返回未知 token 的 id(这里固定为 3)
    def unk_id(self, proto):
        return 3

    # 定义一个名为 post_processor 的方法,返回一个 TemplateProcessing 的处理器对象
    def post_processor(self):
        return processors.TemplateProcessing(
            # 单文本模板,使用 $A 作为占位符,并以 "</s> en_XX" 结尾
            single="$A </s> en_XX",
            # 双文本模板,使用 $A 和 $B 作为占位符,并以 "</s> en_XX" 结尾
            pair="$A $B </s> en_XX",
            # 特殊标记列表,包括 en_XX 和 </s> 的 token 到 id 映射
            special_tokens=[
                ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


# 定义一个名为 MBart50Converter 的类,继承自 SpmConverter 类
class MBart50Converter(SpmConverter):
    
    # 定义一个名为 vocab 的方法,接收 proto 参数,返回一个词汇表列表
    def vocab(self, proto):
        # 初始化词汇表,包括常见特殊 token 和初始权重
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        # 将 proto 对象中的子词片段(从第四个开始)添加到词汇表中
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        # 添加多种语言的标识符和对应的初始权重
        vocab += [
            ("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0),
            ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0),
            ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0),
            ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0),
            ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0),
            ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0),
            ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0),
            ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)  # fmt: skip
        ]
        # 添加一个特殊的 mask 标识符和初始权重
        vocab += [("<mask>", 0.0)]
        # 返回完整的词汇表
        return vocab
    
    # 定义一个名为 unk_id 的方法,接收 proto 参数,返回未知 token 的 id(这里固定为 3)
    def unk_id(self, proto):
        return 3
    # 定义一个方法 `post_processor`,用于生成处理器对象
    def post_processor(self):
        # 返回一个模板处理器对象,配置了单句和双句模板以及特殊令牌信息
        return processors.TemplateProcessing(
            single="en_XX $A </s>",  # 单句模板,用 `en_XX $A </s>` 表示
            pair="en_XX $A $B </s>",  # 双句模板,用 `en_XX $A $B </s>` 表示
            special_tokens=[
                # 定义特殊令牌列表,包括 ("en_XX", en_XX 对应的 ID) 和 ("</s>", </s> 对应的 ID)
                ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )
# 定义 NllbConverter 类,继承自 SpmConverter 类
class NllbConverter(SpmConverter):

    # 定义 vocab 方法,接受 proto 参数
    def vocab(self, proto):
        # 初始化词汇表,包括四个特殊标记和 proto 中的 piece 的内容与得分(从第四个 piece 开始)
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]  # 添加 proto 中的 piece 的内容与得分
        return vocab  # 返回词汇表

    # 定义 unk_id 方法,接受 proto 参数
    def unk_id(self, proto):
        return 3  # 返回未知标记的 id,这里始终为 3

    # 定义 post_processor 方法
    def post_processor(self):
        # 返回 TemplateProcessing 处理器的实例,用于后处理文本
        return processors.TemplateProcessing(
            single="eng_Latn $A </s>",  # 单句模板
            pair="eng_Latn $A $B </s>",  # 双句模板
            special_tokens=[
                ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),  # 特殊标记:eng_Latn 的 id
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),  # 特殊标记:</s> 的 id
            ],
        )


# 定义 SeamlessM4TConverter 类,继承自 SpmConverter 类
class SeamlessM4TConverter(SpmConverter):

    # 定义 vocab 方法,接受 proto 参数
    def vocab(self, proto):
        # 初始化词汇表,包括四个特殊标记和 proto 中的 piece 的内容与得分(从第四个 piece 开始)
        vocab = [
            ("<pad>", 0.0),
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]  # 添加 proto 中的 piece 的内容与得分
        return vocab  # 返回词汇表

    # 定义 unk_id 方法,接受 proto 参数
    def unk_id(self, proto):
        return self.original_tokenizer.unk_token_id  # 返回原始 tokenizer 的未知标记 id

    # 定义 post_processor 方法
    def post_processor(self):
        # 返回 TemplateProcessing 处理器的实例,用于后处理文本
        return processors.TemplateProcessing(
            single="__eng__ $A </s>",  # 单句模板
            pair="__eng__ $A $B </s>",  # 双句模板
            special_tokens=[
                ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),  # 特殊标记:__eng__ 的 id
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),  # 特殊标记:</s> 的 id
            ],
        )


# 定义 XLMRobertaConverter 类,继承自 SpmConverter 类
class XLMRobertaConverter(SpmConverter):

    # 定义 vocab 方法,接受 proto 参数
    def vocab(self, proto):
        # 初始化词汇表,包括四个特殊标记、proto 中的 piece 的内容与得分(从第四个 piece 开始)以及额外的 <mask> 标记
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]  # 添加 proto 中的 piece 的内容与得分
        vocab += [("<mask>", 0.0)]  # 添加 <mask> 标记
        return vocab  # 返回词汇表

    # 定义 unk_id 方法,接受 proto 参数
    def unk_id(self, proto):
        unk_id = 3  # 设置未知标记的 id
        return unk_id  # 返回未知标记的 id

    # 定义 post_processor 方法
    def post_processor(self):
        # 返回 TemplateProcessing 处理器的实例,用于后处理文本
        return processors.TemplateProcessing(
            single="<s> $A </s>",  # 单句模板
            pair="<s> $A </s> </s> $B </s>",  # 双句模板
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),  # 特殊标记:<s> 的 id
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),  # 特殊标记:</s> 的 id
            ],
        )


# XLNetConverter, derived from SpmConverter
class XLNetConverter(SpmConverter):

    # vocab method, takes the sentencepiece proto
    def vocab(self, proto):
        # Penalize pieces that end in a digit followed by a comma by
        # subtracting 100 from their score; keep all other scores as-is
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]
    # Text normalization, driven by the sentencepiece proto
    def normalizer(self, proto):
        # Normalizers applied in order to the raw text
        list_normalizers = [
            normalizers.Replace("``", '"'),  # replace double backticks with a double quote
            normalizers.Replace("''", '"'),  # replace two single quotes with a double quote
        ]
        # If the original tokenizer does not keep accents, decompose to NFKD
        # and strip the combining accent marks
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        # Lowercase if the original tokenizer does
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        # Precompiled character map from the sentencepiece normalizer spec
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap

        # If a precompiled charsmap exists, apply it as well
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))

        # Collapse runs of two or more spaces into a single space
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))

        # Return all normalizers chained into a single Sequence
        return normalizers.Sequence(list_normalizers)
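The normalizer chain can be exercised on its own through normalize_str. A minimal sketch, assuming accents are stripped and no precompiled charsmap is present; the input string is invented:

from tokenizers import Regex, normalizers

norm = normalizers.Sequence([
    normalizers.Replace("``", '"'),
    normalizers.Replace("''", '"'),
    normalizers.NFKD(),
    normalizers.StripAccents(),
    normalizers.Replace(Regex(" {2,}"), " "),
])
print(norm.normalize_str("``café''   here"))  # '"cafe" here'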

    # Post-processor: the trailing :0/:1/:2 suffixes assign token_type_ids
    def post_processor(self):
        return processors.TemplateProcessing(
            # single-sequence template: sequence, <sep>, then <cls> with type id 2
            single="$A:0 <sep>:0 <cls>:2",
            # pair template: each sequence gets its own type id, <cls> gets 2
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            # special tokens and their corresponding ids
            special_tokens=[
                ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
                ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
            ],
        )
class ReformerConverter(SpmConverter):
    pass


class RemBertConverter(SpmConverter):
    # Inspired by AlbertConverter

    # Normalizer, driven by the sentencepiece proto
    def normalizer(self, proto):
        # Normalizers applied in order to the raw text
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
            normalizers.Replace(Regex(" {2,}"), " "),
        ]
        # If accents are not kept, decompose to NFKD and strip them
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        # Lowercase if the original tokenizer does
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        # Apply the precompiled charsmap from the proto if present
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))

        # Return the chained normalizers
        return normalizers.Sequence(list_normalizers)

    # Post-processor: BERT-style [CLS]/[SEP] templates
    def post_processor(self):
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )



class BertGenerationConverter(SpmConverter):
    pass


class PegasusConverter(SpmConverter):
    # Build the vocabulary for the given proto
    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
        ]

        # Add mask_token_sent to the vocabulary if it exists
        if self.original_tokenizer.mask_token_sent is not None:
            vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]

        # Add mask_token if it exists and its id falls below the offset
        if (
            self.original_tokenizer.mask_token is not None
            and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
        ):
            vocab += [(self.original_tokenizer.mask_token, 0.0)]

        # Fill the remaining reserved ids up to the offset with <unk_i>
        # placeholders at a fixed low score
        vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
        # Append all pieces (and scores) from the third proto piece onward
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    # The unk id is the sentencepiece unk id shifted by the reserved offset
    def unk_id(self, proto):
        return proto.trainer_spec.unk_id + self.original_tokenizer.offset

    # Pre-tokenizer: whitespace split followed by Metaspace
    def pre_tokenizer(self, replacement, add_prefix_space):
        return pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
            ]
        )

    # Post-processor: append the eos token (templates given in list form)
    def post_processor(self):
        eos = self.original_tokenizer.eos_token
        special_tokens = [
            (eos, self.original_tokenizer.eos_token_id),
        ]
        return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
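The offset bookkeeping above is the subtle part: the low ids are reserved for pad/eos, the mask tokens and the <unk_i> placeholders, so each remaining sentencepiece id k ends up at k + offset in the converted vocabulary. Spelled out with an assumed offset (Pegasus tokenizers in transformers typically use 103):

sp_unk_id = 2               # unk id from the sentencepiece trainer spec
offset = 103                # assumed count of reserved low ids
print(sp_unk_id + offset)   # 105, the unk id in the converted vocabulary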



class T5Converter(SpmConverter):
    # Build the vocabulary for the given proto
    def vocab(self, proto):
        # Number of extra sentinel ids used by the original tokenizer
        num_extra_ids = self.original_tokenizer._extra_ids
        # Pieces and their scores from the proto
        vocab = [(piece.piece, piece.score) for piece in proto.pieces]
        # Append the sentinel tokens "<extra_id_i>" in descending order of i
        vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
        # Return the assembled vocabulary
        return vocab

    # Build the post-processor
    def post_processor(self):
        # Template processor appending </s>, with its id registered as a special token
        return processors.TemplateProcessing(
            single=["$A", "</s>"],  # single-sequence template: "$A" then "</s>"
            pair=["$A", "</s>", "$B", "</s>"],  # pair template: "$A", "</s>", "$B", "</s>"
            special_tokens=[  # map "</s>" to its id
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )
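Note that the sentinel tokens are appended in descending order of i, so <extra_id_0> ends up with the highest id, matching the slow tokenizer's layout. A small sketch with an assumed (unrealistically small) number of extra ids:

num_extra_ids = 3  # assumed; T5 checkpoints typically use 100
extras = [f"<extra_id_{i}>" for i in range(num_extra_ids - 1, -1, -1)]
print(extras)  # ['<extra_id_2>', '<extra_id_1>', '<extra_id_0>']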
# UdopConverter, derived from SpmConverter
class UdopConverter(SpmConverter):

    # Build the post-processor
    def post_processor(self):
        # Return a TemplateProcessing post-processor configured as follows:
        return processors.TemplateProcessing(
            # single-sequence template: $A followed by </s>
            single=["$A", "</s>"],
            # sequence-pair template: $A and $B, each followed by </s>
            pair=["$A", "</s>", "$B", "</s>"],
            # special tokens: map </s> to its id
            special_tokens=[
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )


# WhisperConverter, derived from Converter
class WhisperConverter(Converter):

    # Build and return the fast Tokenizer
    def converted(self) -> Tokenizer:
        # Vocabulary and BPE merges of the original tokenizer
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())

        # Byte-level BPE tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        # ByteLevel pre-tokenizer and decoder
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
        tokenizer.decoder = decoders.ByteLevel()

        # Prefix token ids of the original tokenizer and their token strings
        prefix_token_ids = self.original_tokenizer.prefix_tokens
        prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
        eos = self.original_tokenizer.eos_token
        eos_token_id = self.original_tokenizer.eos_token_id

        # Build the prefix template string and set the post-processor
        prefix_template = " ".join([f"{token}:0" for token in prefixes])
        tokenizer.post_processor = processors.TemplateProcessing(
            # single-sequence template: prefix tokens, $A, then eos
            single=f"{prefix_template} $A:0 {eos}:0",
            # pair template: prefix tokens, $A, $B, then eos
            pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
            # special tokens: eos plus every prefix token, mapped to their ids
            special_tokens=[
                (eos, eos_token_id),
                *zip(prefixes, prefix_token_ids),
            ],
        )

        # Return the configured tokenizer
        return tokenizer
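The prefix_template string is just the prefix tokens joined with ":0" type-id suffixes. A sketch with illustrative Whisper-style prefix tokens (the real list comes from prefix_tokens on the slow tokenizer):

prefixes = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>"]
prefix_template = " ".join(f"{token}:0" for token in prefixes)
print(prefix_template)
# <|startoftranscript|>:0 <|en|>:0 <|transcribe|>:0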


# BigBirdConverter, derived from SpmConverter
class BigBirdConverter(SpmConverter):

    # Build the post-processor
    def post_processor(self):
        # Return a TemplateProcessing post-processor configured as follows:
        return processors.TemplateProcessing(
            # single-sequence template: [CLS], $A, [SEP]
            single="[CLS]:0 $A:0 [SEP]:0",
            # pair template: [CLS], $A, [SEP], $B, [SEP]
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            # special tokens: ids of [CLS] and [SEP]
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )


class CLIPConverter(Converter):
    # Build and return the fast Tokenizer
    def converted(self) -> Tokenizer:
        # Vocabulary, BPE merges and unk token of the original tokenizer
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
        unk_token = self.original_tokenizer.unk_token

        # BPE tokenizer with a </w> end-of-word suffix
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,  # vocabulary
                merges=merges,  # merge list
                dropout=None,  # no dropout
                continuing_subword_prefix="",  # no continuing-subword prefix
                end_of_word_suffix="</w>",  # end-of-word marker
                fuse_unk=False,  # do not fuse unknown tokens
                unk_token=str(unk_token),  # unk token
            )
        )

        # Normalizer: NFC, collapse whitespace runs to one space, lowercase
        tokenizer.normalizer = normalizers.Sequence(
            [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
        )

        # Pre-tokenizer: keep only spans matching the CLIP pattern, then byte-level
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
                    behavior="removed",  # drop the non-matching spans...
                    invert=True,  # ...i.e. keep the pattern matches themselves
                ),
                pre_tokenizers.ByteLevel(add_prefix_space=False),  # byte-level, no prefix space
            ]
        )

        # Byte-level decoder
        tokenizer.decoder = decoders.ByteLevel()

        # RoBERTa-style post-processing with CLIP's bos/eos as cls/sep
        tokenizer.post_processor = processors.RobertaProcessing(
            sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),  # separator token
            cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),  # class token
            add_prefix_space=False,  # no prefix space
            trim_offsets=False,  # do not trim offsets
        )

        # Return the configured tokenizer
        return tokenizer
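The Split pre-tokenizer with behavior="removed" and invert=True keeps exactly the spans matching the CLIP pattern and drops everything between them, i.e. the whitespace. A minimal sketch on an invented string:

from tokenizers import Regex, pre_tokenizers

pat = Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""")
pre = pre_tokenizers.Sequence([
    pre_tokenizers.Split(pat, behavior="removed", invert=True),
    pre_tokenizers.ByteLevel(add_prefix_space=False),
])
print([piece for piece, _offsets in pre.pre_tokenize_str("it's 2 cats!")])
# expected: ['it', "'s", '2', 'cats', '!'], with the spaces gone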
class LayoutLMv2Converter(Converter):
    # LayoutLMv2Converter: BERT-style WordPiece conversion
    def converted(self) -> Tokenizer:
        # Vocabulary of the original tokenizer
        vocab = self.original_tokenizer.vocab

        # WordPiece tokenizer using the original unk token
        tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))

        tokenize_chinese_chars = False
        strip_accents = False
        do_lower_case = True

        # Take the flags from basic_tokenizer when the original tokenizer has one
        if hasattr(self.original_tokenizer, "basic_tokenizer"):
            tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
            strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
            do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case

        # BertNormalizer configured with those flags
        tokenizer.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            lowercase=do_lower_case,
        )

        # BERT pre-tokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        # String forms and ids of the special tokens
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        # TemplateProcessing with the usual BERT single/pair templates
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=[
                (cls, cls_token_id),
                (sep, sep_token_id),
            ],
        )

        # WordPiece decoder with the "##" continuation prefix
        tokenizer.decoder = decoders.WordPiece(prefix="##")

        # Return the configured tokenizer
        return tokenizer
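The BertNormalizer configured above can be tried in isolation; a minimal sketch with the same flag values and an invented input:

from tokenizers import normalizers

norm = normalizers.BertNormalizer(
    clean_text=True, handle_chinese_chars=False, strip_accents=False, lowercase=True
)
print(norm.normalize_str("Héllo World"))  # 'héllo world' (lowercased, accents kept)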


class BlenderbotConverter(Converter):
    # BlenderbotConverter: byte-level BPE conversion
    def converted(self) -> Tokenizer:
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        # Byte-level BPE tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
            )
        )

        # ByteLevel pre-tokenizer, honoring the original add_prefix_space flag
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)

        # ByteLevel decoder
        tokenizer.decoder = decoders.ByteLevel()

        # Decoder-only model, so only a single-sequence template is needed,
        # appending the eos token
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"$A:0 {ot.eos_token}:0",
            special_tokens=[
                (ot.eos_token, ot.eos_token_id),
            ],
        )

        # Return the configured tokenizer
        return tokenizer


class XGLMConverter(SpmConverter):
    # Build the vocabulary for the given proto
    def vocab(self, proto):
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        # Pieces and scores from the proto, starting from the fourth piece
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        # Append XGLM's made-up filler tokens at the end of the vocabulary
        vocab += [(f"<madeupword{i}>", 0.0) for i in range(7)]

        # Return the assembled vocabulary
        return vocab

    # The unk token sits at index 3 of the vocabulary above
    def unk_id(self, proto):
        unk_id = 3
        return unk_id
    # Build the post-processor
    def post_processor(self):
        # Return a TemplateProcessing post-processor configured as follows:
        return processors.TemplateProcessing(
            # single-sequence template: a leading </s> before the sequence
            single="</s> $A",
            # pair template: sequences joined by a double </s>
            pair="</s> $A </s> </s> $B",
            # special tokens mapped to their ids via the original tokenizer
            special_tokens=[
                ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
                ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
            ],
        )
# GemmaConvert (class name as in the upstream source), derived from
# SpmConverter; customizes the conversion for Gemma models
class GemmaConvert(SpmConverter):
    # Enable byte-fallback handling
    handle_byte_fallback = True

    # Stray string literal carried over from the sentencepiece trainer config;
    # it documents the trainer settings and is never used by the code
    """"
    split_by_unicode_script: true
    split_by_number: true
    split_by_whitespace: true
    treat_whitespace_as_suffix: false
    allow_whitespace_only_pieces: true
    split_digits: true
    byte_fallback: true
    """

    # Normalizer: replace spaces with the ▁ metasymbol
    def normalizer(self, proto):
        return normalizers.Replace(" ", "▁")

    # Build the vocabulary from the proto pieces
    def vocab(self, proto):
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),  # padding token and its score
            (self.original_tokenizer.eos_token, 0.0),  # end-of-sequence token and its score
            (self.original_tokenizer.bos_token, 0.0),  # beginning-of-sequence token and its score
        ]
        # Append the proto pieces starting from the fourth one
        for piece in proto.pieces[3:]:
            if piece.piece == "<0x09>":
                vocab += [("\t", piece.score)]  # map the 0x09 byte piece to a literal tab
            else:
                vocab += [(piece.piece, piece.score)]  # otherwise keep piece and score as-is
        # Return the assembled vocabulary
        return vocab

    # No pre-tokenizer is used
    def pre_tokenizer(self, replacement, add_prefix_space):
        return None

    # The unk token id is fixed at 3
    def unk_id(self, proto):
        unk_id = 3
        return unk_id

    # Decoder: replace ▁ with a space, apply byte fallback, then fuse
    def decoder(self, replacement, add_prefix_space):
        return decoders.Sequence(
            [
                decoders.Replace("▁", " "),  # turn the ▁ metasymbol back into a space
                decoders.ByteFallback(),  # merge <0xXX> byte pieces back into characters
                decoders.Fuse(),  # fuse the remaining pieces into one string
            ]
        )
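The decoder chain runs in order: the metasymbol is turned back into a space, <0xXX> byte pieces are merged into real characters, and the pieces are fused into one string. A minimal sketch with invented sentencepiece-style tokens:

from tokenizers import decoders

dec = decoders.Sequence([
    decoders.Replace("▁", " "),
    decoders.ByteFallback(),
    decoders.Fuse(),
])
print(dec.decode(["▁Hello", "▁world", "<0x21>"]))  # ' Hello world!'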
    # Build the backend tokenizer from the sentencepiece proto
    def tokenizer(self, proto):
        # Model type recorded in the sentencepiece trainer spec
        model_type = proto.trainer_spec.model_type
        # Vocabulary with scores
        vocab_scores = self.vocab(proto)

        if model_type == 1:
            # model_type 1: Unigram
            import tokenizers

            # Older tokenizers versions do not support byte_fallback on Unigram
            if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
                tokenizer = Tokenizer(Unigram(vocab_scores, 0))
            else:
                tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))

        elif model_type == 2:
            # model_type 2: BPE; extract the merges from the sentencepiece model
            _, merges = GemmaSentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            # BPE vocabulary mapping each word to its index
            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}

            # BPE tokenizer with unk fusing and byte fallback enabled
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                    byte_fallback=True,
                    dropout=None,
                )
            )
            # Register the special tokens
            tokenizer.add_special_tokens(
                [
                    AddedToken("<pad>", normalized=False, special=True),
                    AddedToken("<eos>", normalized=False, special=True),
                    AddedToken("<bos>", normalized=False, special=True),
                    AddedToken("<unk>", normalized=False, special=True),
                ]
            )
        else:
            # Any other model_type is unsupported
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        # Wrap the user-defined symbols from the trainer spec as AddedTokens
        user_defined_symbols = [
            AddedToken(token, normalized=False, special=False) for token in proto.trainer_spec.user_defined_symbols
        ]
        # Add them to the tokenizer
        tokenizer.add_tokens(user_defined_symbols)

        # Return the built tokenizer
        return tokenizer
class LlamaConverter(SpmConverter):
    # Enable byte-fallback handling
    handle_byte_fallback = True

    # Build the vocabulary from the proto
    def vocab(self, proto):
        # Special tokens first, with default scores
        vocab = [
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        # Pieces and scores from the proto, starting from the fourth piece
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    # The unk token sits at index 0 of the vocabulary above
    def unk_id(self, proto):
        unk_id = 0
        return unk_id

    # Decoder for turning token sequences back into text
    def decoder(self, replacement, add_prefix_space):
        sequence = [
            decoders.Replace("▁", " "),  # replace the ▁ metasymbol with a space
            decoders.ByteFallback(),  # byte-fallback decoder
            decoders.Fuse(),  # fuse the pieces into one string
        ]
        # If a prefix space was added, strip one leading space when decoding
        if add_prefix_space:
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    # Build the backend tokenizer, choosing the model by type
    def tokenizer(self, proto):
        model_type = proto.trainer_spec.model_type
        vocab_scores = self.vocab(proto)

        if model_type == 1:
            # model_type 1: Unigram
            import tokenizers

            # Older tokenizers versions do not support byte_fallback on Unigram
            if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
                tokenizer = Tokenizer(Unigram(vocab_scores, 0))
            else:
                tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))

        elif model_type == 2:
            # model_type 2: BPE; extract the merges and build the vocabulary
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
            bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
            # BPE tokenizer with byte fallback, plus the special tokens
            tokenizer = Tokenizer(
                BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
            )
            tokenizer.add_special_tokens(
                [
                    AddedToken("<unk>", normalized=False, special=True),
                    AddedToken("<s>", normalized=False, special=True),
                    AddedToken("</s>", normalized=False, special=True),
                ]
            )
        else:
            # Any other model_type is unsupported
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        return tokenizer

    # Normalizer for the raw text
    def normalizer(self, proto):
        sequence = []
        # Prepend the ▁ metasymbol when the original tokenizer adds a prefix space
        if hasattr(self.original_tokenizer, "add_prefix_space"):
            if self.original_tokenizer.add_prefix_space:
                sequence += [normalizers.Prepend(prepend="▁")]
        # Rewrite spaces as the ▁ metasymbol
        sequence += [normalizers.Replace(pattern=" ", content="▁")]
        return normalizers.Sequence(sequence)
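The effect of this normalizer is easy to see on a plain string; a minimal sketch assuming add_prefix_space is True:

from tokenizers import normalizers

norm = normalizers.Sequence([
    normalizers.Prepend(prepend="▁"),
    normalizers.Replace(pattern=" ", content="▁"),
])
print(norm.normalize_str("Hello world"))  # '▁Hello▁world'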

    # No pre-tokenizer is used
    def pre_tokenizer(self, replacement, add_prefix_space):
        return None

    # No post-processor here; it is defined on LlamaTokenizerFast itself
    def post_processor(self):
        return None
class MarkupLMConverter(Converter):
    # Build and return the fast Tokenizer
    def converted(self) -> Tokenizer:
        # Original tokenizer, its vocabulary and BPE merges
        ot = self.original_tokenizer
        vocab = ot.encoder
        merges = list(ot.bpe_ranks.keys())

        # Byte-level BPE tokenizer
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,  # vocabulary
                merges=merges,  # merge list
                dropout=None,  # no dropout
                continuing_subword_prefix="",  # no continuing-subword prefix
                end_of_word_suffix="",  # no end-of-word suffix
                fuse_unk=False,  # do not fuse unknown tokens
                unk_token=self.original_tokenizer.unk_token,  # unk token
            )
        )

        # ByteLevel pre-tokenizer, honoring the original add_prefix_space flag
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
        # ByteLevel decoder
        tokenizer.decoder = decoders.ByteLevel()

        # Special tokens (cls and sep) and their ids
        cls = str(self.original_tokenizer.cls_token)
        sep = str(self.original_tokenizer.sep_token)
        cls_token_id = self.original_tokenizer.cls_token_id
        sep_token_id = self.original_tokenizer.sep_token_id

        # TemplateProcessing built from those special tokens
        tokenizer.post_processor = processors.TemplateProcessing(
            single=f"{cls} $A {sep}",  # single-sequence template
            pair=f"{cls} $A {sep} $B {sep}",  # pair template
            special_tokens=[
                (cls, cls_token_id),  # register the cls token
                (sep, sep_token_id),  # register the sep token
            ],
        )

        # Return the configured tokenizer
        return tokenizer
# Mapping from slow tokenizer class names to their fast converter classes
SLOW_TO_FAST_CONVERTERS = {
    "AlbertTokenizer": AlbertConverter,
    "BartTokenizer": RobertaConverter,
    "BarthezTokenizer": BarthezConverter,
    "BertTokenizer": BertConverter,
    "BigBirdTokenizer": BigBirdConverter,
    "BlenderbotTokenizer": BlenderbotConverter,
    "CamembertTokenizer": CamembertConverter,
    "CLIPTokenizer": CLIPConverter,
    "CodeGenTokenizer": GPT2Converter,
    "ConvBertTokenizer": BertConverter,
    "DebertaTokenizer": DebertaConverter,
    "DebertaV2Tokenizer": DebertaV2Converter,
    "DistilBertTokenizer": BertConverter,
    "DPRReaderTokenizer": BertConverter,
    "DPRQuestionEncoderTokenizer": BertConverter,
    "DPRContextEncoderTokenizer": BertConverter,
    "ElectraTokenizer": BertConverter,
    "FNetTokenizer": AlbertConverter,
    "FunnelTokenizer": FunnelConverter,
    "GPT2Tokenizer": GPT2Converter,
    "HerbertTokenizer": HerbertConverter,
    "LayoutLMTokenizer": BertConverter,
    "LayoutLMv2Tokenizer": BertConverter,
    "LayoutLMv3Tokenizer": RobertaConverter,
    "LayoutXLMTokenizer": XLMRobertaConverter,
    "LongformerTokenizer": RobertaConverter,
    "LEDTokenizer": RobertaConverter,
    "LxmertTokenizer": BertConverter,
    "MarkupLMTokenizer": MarkupLMConverter,
    "MBartTokenizer": MBartConverter,
    "MBart50Tokenizer": MBart50Converter,
    "MPNetTokenizer": MPNetConverter,
    "MobileBertTokenizer": BertConverter,
    "MvpTokenizer": RobertaConverter,
    "NllbTokenizer": NllbConverter,
    "OpenAIGPTTokenizer": OpenAIGPTConverter,
    "PegasusTokenizer": PegasusConverter,
    "Qwen2Tokenizer": Qwen2Converter,
    "RealmTokenizer": BertConverter,
    "ReformerTokenizer": ReformerConverter,
    "RemBertTokenizer": RemBertConverter,
    "RetriBertTokenizer": BertConverter,
    "RobertaTokenizer": RobertaConverter,
    "RoFormerTokenizer": RoFormerConverter,
    "SeamlessM4TTokenizer": SeamlessM4TConverter,
    "SqueezeBertTokenizer": BertConverter,
    "T5Tokenizer": T5Converter,
    "UdopTokenizer": UdopConverter,
    "WhisperTokenizer": WhisperConverter,
    "XLMRobertaTokenizer": XLMRobertaConverter,
    "XLNetTokenizer": XLNetConverter,
    "SplinterTokenizer": SplinterConverter,
    "XGLMTokenizer": XGLMConverter,
    "LlamaTokenizer": LlamaConverter,
    "CodeLlamaTokenizer": LlamaConverter,
    "GemmaTokenizer": GemmaConvert,
}

# Convert a slow tokenizer instance into the corresponding fast tokenizer backend
def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
    """
    Utilities to convert a slow tokenizer instance in a fast tokenizer instance.

    Args:
        transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
            Instance of a slow tokenizer to convert in the backend tokenizer for
            [`~tokenization_utils_base.PreTrainedTokenizerFast`].

    Return:
        An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]
    """

    # Class name of the tokenizer to convert
    tokenizer_class_name = transformer_tokenizer.__class__.__name__
    # Check that a converter is registered for this class in SLOW_TO_FAST_CONVERTERS
    if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
        # Otherwise raise, listing the available converters
        raise ValueError(
            f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
            " No converter was found. Currently available slow->fast converters:"
            f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
        )

    # Look up the converter class for this tokenizer class
    converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]

    # Instantiate it on the slow tokenizer and return the converted result
    return converter_class(transformer_tokenizer).converted()
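To tie the pieces together, any slow tokenizer from the table above can be converted into a tokenizers.Tokenizer backend. A usage sketch; the checkpoint name is only illustrative, and the sentencepiece extra must be installed:

from transformers import AlbertTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow = AlbertTokenizer.from_pretrained("albert-base-v2")
fast_backend = convert_slow_tokenizer(slow)  # a tokenizers.Tokenizer instance
print(fast_backend.encode("Hello world").tokens)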