Transformers Source Code Walkthrough (Part 4)

.\deepspeed.py

# Import the warnings module, used to tell users about upcoming changes or deprecations
import warnings

# Warn that the transformers.deepspeed module is deprecated and will be removed in a future release
warnings.warn(
    "transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations",
    FutureWarning,
)

# Backward-compatibility re-exports: every object remains importable from integrations/deepspeed
from .integrations.deepspeed import (  # noqa
    HfDeepSpeedConfig,                  # HfDeepSpeedConfig class
    HfTrainerDeepSpeedConfig,           # HfTrainerDeepSpeedConfig class
    deepspeed_config,                   # deepspeed_config function
    deepspeed_init,                     # deepspeed_init function
    deepspeed_load_checkpoint,          # deepspeed_load_checkpoint function
    deepspeed_optim_sched,              # deepspeed_optim_sched function
    is_deepspeed_available,             # is_deepspeed_available function
    is_deepspeed_zero3_enabled,         # is_deepspeed_zero3_enabled function
    set_hf_deepspeed_config,            # set_hf_deepspeed_config function
    unset_hf_deepspeed_config,          # unset_hf_deepspeed_config function
)
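
# Illustrative sketch: the re-exported helpers behave exactly like their counterparts in
# transformers.integrations.deepspeed (importing from there directly avoids the FutureWarning
# above). What they report depends on the environment this is run in:
print(is_deepspeed_available())      # True only if the `deepspeed` package can be imported
print(is_deepspeed_zero3_enabled())  # False unless a ZeRO-3 HfDeepSpeedConfig is currently active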

.\dependency_versions_check.py

# Import the dependency dictionary from the dependency versions table
from .dependency_versions_table import deps
# Import the version-check helpers from the utils package
from .utils.versions import require_version, require_version_core

# Modules whose versions must be checked at runtime
# (usually the modules defined in install_requires in setup.py)
#
# Order-specific note:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = [
    "python",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "tokenizers",
    "huggingface-hub",
    "safetensors",
    "accelerate",
    "pyyaml",
]

# Iterate over the modules that need a runtime check
for pkg in pkgs_to_check_at_runtime:
    # Only check modules that are present in the dependency dictionary
    if pkg in deps:
        if pkg == "tokenizers":
            # Must be loaded here, otherwise the tqdm check may fail
            from .utils import is_tokenizers_available

            # If tokenizers is not available, skip the version check; it is only enforced at install time
            if not is_tokenizers_available():
                continue
        elif pkg == "accelerate":
            # Must be loaded here, otherwise the tqdm check may fail
            from .utils import is_accelerate_available

            # Maybe switch to is_torch_available here in the future, so that Accelerate becomes a hard dependency of
            # Transformers together with PyTorch
            if not is_accelerate_available():
                continue

        # Enforce the core version requirement recorded for this module in the dependency dictionary
        require_version_core(deps[pkg])
    else:
        # If the module cannot be found in the dependency dictionary, raise an error
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")


def dep_version_check(pkg, hint=None):
    # Enforce the version requirement recorded for this module in the dependency dictionary
    require_version(deps[pkg], hint)
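
# Illustrative sketch: `dep_version_check` looks up the pinned spec in `deps` and delegates to
# `require_version`, which raises an ImportError (with the optional hint appended) when the
# installed version does not satisfy the spec. Assumes tqdm is installed in this environment.
dep_version_check("tqdm", hint="tqdm is a core dependency of transformers.")  # checks the "tqdm>=4.27" spec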

.\dependency_versions_table.py

# Auto-generated dependency dictionary used to configure the project's dependencies and version constraints
deps = {
    "Pillow": "Pillow>=10.0.1,<=15.0",  # image-processing library, between 10.0.1 and 15.0
    "accelerate": "accelerate>=0.21.0",  # acceleration library, at least 0.21.0
    "av": "av==9.2.0",  # multimedia library, pinned to exactly 9.2.0
    "beautifulsoup4": "beautifulsoup4",  # HTML/XML parsing library, unpinned
    "codecarbon": "codecarbon==1.2.0",  # carbon-footprint tracking, pinned to exactly 1.2.0
    "cookiecutter": "cookiecutter==1.7.3",  # project templating tool, pinned to exactly 1.7.3
    "dataclasses": "dataclasses",  # backport of the Python 3.7 dataclasses module, unpinned
    "datasets": "datasets!=2.5.0",  # dataset library, any version except 2.5.0
    "decord": "decord==0.6.0",  # video decoding library, pinned to exactly 0.6.0
    "deepspeed": "deepspeed>=0.9.3",  # distributed-training library, at least 0.9.3
    "diffusers": "diffusers",  # diffusion-models library, unpinned
    "dill": "dill<0.3.5",  # object-serialization library, below 0.3.5
    "evaluate": "evaluate>=0.2.0",  # evaluation library, at least 0.2.0
    "faiss-cpu": "faiss-cpu",  # vector-similarity search, unpinned
    "fastapi": "fastapi",  # high-performance API framework, unpinned
    "filelock": "filelock",  # file-locking library, unpinned
    "flax": "flax>=0.4.1,<=0.7.0",  # JAX neural-network library, between 0.4.1 and 0.7.0
    "fsspec": "fsspec<2023.10.0",  # filesystem abstraction library, below 2023.10.0
    "ftfy": "ftfy",  # Unicode text fixing, unpinned
    "fugashi": "fugashi>=1.0",  # Japanese tokenizer, at least 1.0
    "GitPython": "GitPython<3.1.19",  # Git interaction library, below 3.1.19
    "hf-doc-builder": "hf-doc-builder>=0.3.0",  # Hugging Face documentation builder, at least 0.3.0
    "huggingface-hub": "huggingface-hub>=0.19.3,<1.0",  # Hugging Face Hub client, between 0.19.3 and 1.0
    "importlib_metadata": "importlib_metadata",  # package metadata access, unpinned
    "ipadic": "ipadic>=1.0.0,<2.0",  # Japanese dictionary, between 1.0.0 and 2.0
    "isort": "isort>=5.5.4",  # import-sorting tool, at least 5.5.4
    "jax": "jax>=0.4.1,<=0.4.13",  # numerical-computing library, between 0.4.1 and 0.4.13
    "jaxlib": "jaxlib>=0.4.1,<=0.4.13",  # JAX support library, between 0.4.1 and 0.4.13
    "jieba": "jieba",  # Chinese tokenizer, unpinned
    "kenlm": "kenlm",  # language-model toolkit, unpinned
    "keras": "keras<2.16",  # deep-learning library, below 2.16
    "keras-nlp": "keras-nlp>=0.3.1",  # Keras NLP library, at least 0.3.1
    "librosa": "librosa",  # audio-processing library, unpinned
    "nltk": "nltk",  # natural-language toolkit, unpinned
    "natten": "natten>=0.14.6,<0.15.0",  # neighborhood-attention library, between 0.14.6 and 0.15.0
    "numpy": "numpy>=1.17",  # numerical-computing library, at least 1.17
    "onnxconverter-common": "onnxconverter-common",  # ONNX converter utilities, unpinned
    "onnxruntime-tools": "onnxruntime-tools>=1.4.2",  # ONNX Runtime tooling, at least 1.4.2
    "onnxruntime": "onnxruntime>=1.4.0",  # ONNX Runtime, at least 1.4.0
    "opencv-python": "opencv-python",  # computer-vision library, unpinned
    "optuna": "optuna",  # hyperparameter-optimization library, unpinned
    "optax": "optax>=0.0.8,<=0.1.4",  # JAX optimization library, between 0.0.8 and 0.1.4
    "packaging": "packaging>=20.0",  # packaging utilities, at least 20.0
    "parameterized": "parameterized",  # parameterized testing, unpinned
    "phonemizer": "phonemizer",  # text-to-phoneme conversion, unpinned
    "protobuf": "protobuf",  # Google serialization library, unpinned
    "psutil": "psutil",  # process and system utilities, unpinned
    "pyyaml": "pyyaml>=5.1",  # YAML parser, at least 5.1
    "pydantic": "pydantic",  # data-validation library, unpinned
    "pytest": "pytest>=7.2.0,<8.0.0",  # testing framework, between 7.2.0 and 8.0.0
    "pytest-timeout": "pytest-timeout",  # pytest plugin, unpinned
    "pytest-xdist": "pytest-xdist",  # pytest plugin, unpinned
    "python": "python>=3.8.0",  # Python interpreter, at least 3.8.0
    "ray[t
    "pyctcdecode": "pyctcdecode>=0.4.0",
    # 定义依赖项:pyctcdecode 库,版本需大于或等于 0.4.0

    "tqdm": "tqdm>=4.27",
    # 定义依赖项:tqdm 库,版本需大于或等于 4.27

    "unidic": "unidic>=1.0.2",
    # 定义依赖项:unidic 库,版本需大于或等于 1.0.2

    "unidic_lite": "unidic_lite>=1.0.7",
    # 定义依赖项:unidic_lite 库,版本需大于或等于 1.0.7

    "urllib3": "urllib3<2.0.0",
    # 定义依赖项:urllib3 库,版本需小于 2.0.0

    "uvicorn": "uvicorn",
    # 定义依赖项:uvicorn 库,无指定版本要求
}



.\dynamic_module_utils.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to dynamically load objects from the Hub."""
import filecmp
import importlib
import os
import re
import shutil
import signal
import sys
import typing
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from huggingface_hub import try_to_load_from_cache

from .utils import (
    HF_MODULES_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    cached_file,
    extract_commit_hash,
    is_offline_mode,
    logging,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def init_hf_modules():
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.
    """
    # If HF_MODULES_CACHE is already on the Python path, this function has already run; return early
    if HF_MODULES_CACHE in sys.path:
        return

    # Add HF_MODULES_CACHE to the Python path
    sys.path.append(HF_MODULES_CACHE)
    # Create the HF_MODULES_CACHE directory; do nothing if it already exists
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    # Create an __init__.py file inside HF_MODULES_CACHE; do nothing if it already exists
    init_path = Path(HF_MODULES_CACHE) / "__init__.py"
    if not init_path.exists():
        init_path.touch()
        # Invalidate the importlib caches so the newly created module can be picked up
        importlib.invalidate_caches()


def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Creates a dynamic module in the cache directory for modules.

    Args:
        name (`str` or `os.PathLike`):
            The name of the dynamic module to create.
    """
    # Initialize the HF modules cache, making sure the directory exists and is on the Python path
    init_hf_modules()
    # Resolve the full path of the dynamic module
    dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
    # If the parent directory does not exist, create it recursively
    if not dynamic_module_path.parent.exists():
        create_dynamic_module(dynamic_module_path.parent)
    # Create the dynamic module directory; do nothing if it already exists
    os.makedirs(dynamic_module_path, exist_ok=True)
    # Create an __init__.py file in the dynamic module directory; do nothing if it already exists
    init_path = dynamic_module_path / "__init__.py"
    if not init_path.exists():
        init_path.touch()
        # Invalidate the importlib caches so the newly created module can be picked up
        importlib.invalidate_caches()


def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of relative imports in the module.
    """
    # Open `module_file` with UTF-8 encoding and read its content
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Find relative imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Find relative imports of the form `from .xxx import yyy` and append them
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Deduplicate by converting the list to a set and back to a list
    return list(set(relative_imports))


def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
        of module files a given module needs.
    """
    no_change = False  # flag used to detect whether any new relative imports were found
    files_to_check = [module_file]  # files still to inspect, starting from the given module file
    all_relative_imports = []  # accumulator for every relative-import module file found

    # Let's recurse through all relative imports
    while not no_change:  # loop until no new relative imports are found
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))  # collect the relative imports of file f

        module_path = Path(module_file).parent  # parent directory of the original module file
        new_import_files = [str(module_path / m) for m in new_imports]  # build the paths of the newly found imports
        new_import_files = [f for f in new_import_files if f not in all_relative_imports]  # drop ones already seen
        files_to_check = [f"{f}.py" for f in new_import_files]  # add the ".py" suffix for the next round of inspection

        no_change = len(new_import_files) == 0  # stop when no new relative imports were found
        all_relative_imports.extend(files_to_check)  # record the newly found files

    return all_relative_imports  # every module file reachable through relative imports


def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Extracts all the libraries (not relative imports this time) that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The list of all packages required to use the input module.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()  # read the file content

    # filter out try/except block so in custom code we can have try/except imports
    content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)

    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module of non-relative imports
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
    return list(set(imports))  # deduplicated list of imported package names
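
# Illustrative sketch: given a hypothetical module written to a temporary file, `get_imports`
# keeps only top-level, non-relative package names and ignores imports guarded by try/except.
# The file contents below are made up purely for this example.
import tempfile
import textwrap

_demo_source = textwrap.dedent(
    """\
    import torch
    from numpy import ndarray
    from .configuration_demo import DemoConfig
    try:
        import flash_attn
    except ImportError:
        pass
    """
)
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as _tmp:
    _tmp.write(_demo_source)

print(sorted(get_imports(_tmp.name)))  # ['numpy', 'torch']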


def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
    library is missing.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.
    """
    imports = get_imports(filename)  # every non-relative import found in the file
    missing_packages = []
    for imp in imports:
        try:
            importlib.import_module(imp)  # try to import the module; catch ImportError on failure
        except ImportError:
            missing_packages.append(imp)  # record the missing package

    # If any package is missing, raise an ImportError telling the user what to install
    if len(missing_packages) > 0:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    # Otherwise, return the list of relative imports of the module file
    return get_relative_imports(filename)


# Import a class with the given name from a module file stored in the dynamic-modules cache
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`): The name of the class to import.
        module_path (`str` or `os.PathLike`): The path to the module to import.

    Returns:
        `typing.Type`: The class looked for.
    """
    # Normalize the module path into a module name: drop the .py suffix and turn path separators into dots
    name = os.path.normpath(module_path).replace(".py", "").replace(os.path.sep, ".")
    # Build the full path of the module file inside the modules cache
    module_path = str(Path(HF_MODULES_CACHE) / module_path)
    # Load the module file with SourceFileLoader and return the module object
    module = importlib.machinery.SourceFileLoader(name, module_path).load_module()
    # Fetch the requested class from the loaded module and return it
    return getattr(module, class_name)


def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
) -> str:
    """
    Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
    Transformers module.
    """
    # Pop `use_auth_token` from the kwargs to stay compatible with the old argument name
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    # If `use_auth_token` was given, warn that it will be removed in Transformers v5
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        # If `token` was also given, raise a ValueError since both cannot be set at once
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        # Use the `use_auth_token` value as `token` for backward compatibility
        token = use_auth_token

    # In offline mode, force local_files_only=True unless it is already set
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Make sure pretrained_model_name_or_path is a string
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    # Check whether pretrained_model_name_or_path points to a local directory
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        # For a local directory, the submodule is the directory's base name
        submodule = os.path.basename(pretrained_model_name_or_path)
    else:
        # Otherwise, replace '/' in pretrained_model_name_or_path with the OS path separator
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        # Try to load the module file from the cache
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    # Keep track of files that were newly downloaded
    new_files = []
    try:
        # Load the module file from the URL or from the cache
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        # If this is not a local folder and the cached module differs from the resolved file, record module_file as new
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    # On an environment error, log it and re-raise
    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check that the current environment contains all the packages the module needs
    modules_needed = check_imports(resolved_module_file)

    # Move the module into the cached dynamic modules
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule

    # If the submodule is the base name of pretrained_model_name_or_path (i.e. a local folder)
    if submodule == os.path.basename(pretrained_model_name_or_path):
        # Copy local files into submodule_path to avoid putting too many folders in sys.path; copy only when the
        # file is new or has changed since the last copy
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()

        # Copy every required relative module file into submodule_path as well
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Extract the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will be placed in a subfolder named after the repository's git hash, so that we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        # If the module file does not exist under the submodule path yet, copy the resolved module file there
        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()

        # Make sure we also have every file it imports relatively
        for module_needed in modules_needed:
            # If a needed module file is missing under the submodule path, fetch and cache it
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    token=token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    # If new files were downloaded and no revision was pinned, warn the user
    if len(new_files) > 0 and revision is None:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}s/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    # Return the path of the module file inside the full submodule
    return os.path.join(full_submodule, module_file)


# Extract a class definition from a dynamic module
def get_class_from_dynamic_module(
    # Full reference of the class to load, e.g. "module.submodule.ClassName"
    class_reference: str,
    # Name or path of the pretrained model, as a string or a path-like object
    pretrained_model_name_or_path: Union[str, os.PathLike],
    # Optional cache directory, defaults to None
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    # Whether to force re-downloading the model files, defaults to False
    force_download: bool = False,
    # Whether to resume a previously interrupted download, defaults to False
    resume_download: bool = False,
    # Optional proxy configuration, a dict used for network requests
    proxies: Optional[Dict[str, str]] = None,
    # Optional token used to access the model, either a boolean or a string
    token: Optional[Union[bool, str]] = None,
    # Optional revision (branch name, tag or commit id) of the model repository
    revision: Optional[str] = None,
    # Whether to only use local files, defaults to False
    local_files_only: bool = False,
    # Optional repository type (e.g. dataset or space)
    repo_type: Optional[str] = None,
    # Optional specific revision for the code
    code_revision: Optional[str] = None,
    # Extra keyword arguments, forwarded as-is
    **kwargs,
) -> typing.Type:
    """
    从本地文件夹或模型仓库中提取一个类的定义。

    <Tip warning={true}>

    调用此函数将执行本地或从 Hub 下载的模块文件中的代码。因此,应仅在可信任的仓库中调用。

    </Tip>
    # 加载指定类的配置和模型数据
    Args:
        class_reference (`str`):
            要加载的类的完整名称,包括其模块和可选的存储库。
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            可以是以下之一:

            - 字符串,表示在 huggingface.co 模型仓库中预训练模型配置的 *模型 ID*。
            - 目录路径,包含使用 [`~PreTrainedTokenizer.save_pretrained`] 方法保存的配置文件,例如 `./my_model_directory/`。

            当 `class_reference` 没有指定其他存储库时使用。
        module_file (`str`):
            包含要查找的类的模块文件名。
        class_name (`str`):
            要在模块中导入的类的名称。
        cache_dir (`str` or `os.PathLike`, *optional*):
            下载预训练模型配置时应该缓存的目录路径,如果不想使用标准缓存。
        force_download (`bool`, *optional*, defaults to `False`):
            是否强制下载配置文件,并覆盖已存在的缓存版本。
        resume_download (`bool`, *optional*, defaults to `False`):
            是否删除未完全接收的文件。如果存在这样的文件,则尝试恢复下载。
        proxies (`Dict[str, str]`, *optional*):
            使用的代理服务器字典,按协议或端点分组,例如 `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`。
            代理服务器会在每个请求上使用。
        token (`str` or `bool`, *optional*):
            用作远程文件的 HTTP Bearer 授权令牌。如果是 `True`,将使用运行 `huggingface-cli login` 时生成的令牌(存储在 `~/.huggingface` 中)。
        revision (`str`, *optional*, defaults to `"main"`):
            要使用的特定模型版本。可以是分支名称、标签名称或提交 ID。由于我们在 huggingface.co 上使用基于 Git 的系统存储模型和其他工件,因此 `revision` 可以是 Git 允许的任何标识符。
        local_files_only (`bool`, *optional*, defaults to `False`):
            如果为 `True`,将仅尝试从本地文件加载 tokenizer 配置。
        repo_type (`str`, *optional*):
            指定存储库类型(在下载时特别有用,例如从空间下载)。
        code_revision (`str`, *optional*, defaults to `"main"`):
            在 Hub 上使用的代码的特定版本。如果代码存储在与模型其余部分不同的存储库中,可以是分支名称、标签名称或提交 ID。由于我们在 huggingface.co 上使用基于 Git 的系统存储模型和其他工件,因此 `revision` 可以是 Git 允许的任何标识符。
    Passing `token=True` is required when you want to use a private model.



    </Tip>



    Returns:
        `typing.Type`: The class, dynamically imported from the module.



    Examples:

    ```
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""



    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        token = use_auth_token



    # Catch the name of the repo if it's specified in `class_reference`
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
    else:
        repo_id = pretrained_model_name_or_path
    module_file, class_name = class_reference.split(".")



    if code_revision is None and pretrained_model_name_or_path == repo_id:
        code_revision = revision
    # And lastly we get the class inside our newly created module
    final_module = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=code_revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, final_module)


def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object.

    Returns:
        `List[str]`: The list of files saved.
    """
    # Check if the object is defined in the '__main__' module; issue a warning if true and return.
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        # Get the module name where the object's class is defined.
        module_name = obj.__class__.__module__
        # Extract the last module name from the full module path.
        last_module = module_name.split(".")[-1]
        # Construct the full name of the object's class.
        full_name = f"{last_module}.{obj.__class__.__name__}"

        # Special handling for tokenizers
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # For fast tokenizers, capture the fast tokenizer class and check for a slow tokenizer attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = getattr(obj, "slow_tokenizer_class")
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # For slow tokenizers, only record the slow tokenizer class.
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            # Assign both tokenizer classes to full_name.
            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        # Update the auto_map in the provided config.
        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map based on the type of config provided.
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Get the file path of the module where the object's class is defined.
    object_file = sys.modules[obj.__module__].__file__
    # Copy the file in which the object is defined into the destination folder
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    result.append(dest_file)

    # Gather all files the object's module imports relatively (recursively) and make sure they are copied over too
    for needed_file in get_relative_import_files(object_file):
        # Build the destination path of the relative-import file and copy it there
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    # Return the list of copied files
    return result
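
# Illustrative sketch: for a custom model class registered with
# `DemoModel.register_for_auto_class("AutoModelForCausalLM")` (both names are made up for this
# example), `custom_object_save(model, folder, config=model.config)` copies `modeling_demo.py`
# plus its relative imports into `folder` and leaves an entry similar to
#     {"auto_map": {"AutoModelForCausalLM": "modeling_demo.DemoModel"}}
# in the saved config, so that loading with `trust_remote_code=True` can rebuild the class later.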


# Timeout handler: raise a ValueError when the user does not answer the remote-code prompt in time
def _raise_timeout_error(signum, frame):
    raise ValueError(
        "Loading this model requires you to execute custom code contained in the model repository on your local "
        "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
    )

# Timeout (in seconds) for the remote-code trust prompt
TIME_OUT_REMOTE_CODE = 15

# Decide whether remote code should be trusted, returning the resolved flag depending on the situation
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
    # If the user did not set the trust_remote_code option
    if trust_remote_code is None:
        # If local code exists, default to not trusting remote code
        if has_local_code:
            trust_remote_code = False
        # If remote code exists and a positive timeout is configured, ask the user whether to trust it
        elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
            try:
                # Install _raise_timeout_error as the SIGALRM handler and start the timeout timer
                signal.signal(signal.SIGALRM, _raise_timeout_error)
                signal.alarm(TIME_OUT_REMOTE_CODE)
                # Keep prompting until the user makes a decision
                while trust_remote_code is None:
                    answer = input(
                        f"The repository for {model_name} contains custom code which must be executed to correctly "
                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                        f"Do you wish to run the custom code? [y/N] "
                    )
                    # Interpret the user's answer
                    if answer.lower() in ["yes", "y", "1"]:
                        trust_remote_code = True
                    elif answer.lower() in ["no", "n", "0", ""]:
                        trust_remote_code = False
                # Cancel the timeout timer
                signal.alarm(0)
            except Exception:
                # Catch any failure (e.g. the OS does not support signal.SIGALRM)
                raise ValueError(
                    f"The repository for {model_name} contains custom code which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )
        # If remote code exists but the timeout is set to 0, raise the timeout error directly
        elif has_remote_code:
            _raise_timeout_error(None, None)

    # If remote code exists without local code and the user does not trust it, raise a ValueError
    if has_remote_code and not has_local_code and not trust_remote_code:
        raise ValueError(
            f"Loading {model_name} requires you to execute the configuration file in that"
            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
            " set the option `trust_remote_code=True` to remove this error."
        )

    # Return the resolved trust_remote_code flag
    return trust_remote_code
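
# Illustrative sketch of the decision table above, using made-up model names: an explicit
# user choice is always respected, and when code also exists locally the default
# (trust_remote_code=None) resolves to not trusting the remote code.
assert resolve_trust_remote_code(True, "demo/model", has_local_code=False, has_remote_code=True) is True
assert resolve_trust_remote_code(None, "demo/model", has_local_code=True, has_remote_code=True) is False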

.\feature_extraction_sequence_utils.py

"""
Sequence feature extraction class for common feature extractors to preprocess sequences.
"""
# Import the required modules and libraries
from typing import Dict, List, Optional, Union  # typing helpers

import numpy as np  # NumPy

from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin  # shared feature-extraction utilities
from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy  # internal utilities

logger = logging.get_logger(__name__)  # module-level logger


class SequenceFeatureExtractor(FeatureExtractionMixin):
    """
    This is a general feature extraction class for speech recognition.

    Args:
        feature_size (`int`):
            The feature dimension of the extracted features.
        sampling_rate (`int`):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
        padding_value (`float`):
            The value that is used to fill the padding values / vectors.
    """

    def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
        self.feature_size = feature_size  # feature dimension of the extracted features
        self.sampling_rate = sampling_rate  # sampling rate of the incoming audio
        self.padding_value = padding_value  # value used to fill padded positions

        self.padding_side = kwargs.pop("padding_side", "right")  # which side to pad, defaults to the right
        self.return_attention_mask = kwargs.pop("return_attention_mask", True)  # whether to return an attention mask, defaults to True

        super().__init__(**kwargs)  # forward the remaining kwargs to the parent class

    def pad(
        self,
        processed_features: Union[
            BatchFeature,
            List[BatchFeature],
            Dict[str, BatchFeature],
            Dict[str, List[BatchFeature]],
            List[Dict[str, BatchFeature]],
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        """
        Pad sequences of features to the same length.

        Args:
            processed_features (Union[BatchFeature, List[BatchFeature], Dict[str, BatchFeature], ...]):
                The processed features to be padded.
            padding (Union[bool, str, PaddingStrategy]):
                Strategy for padding. Can be a boolean, string, or enum from PaddingStrategy.
            max_length (Optional[int]):
                Maximum length to pad or truncate the sequences.
            truncation (bool):
                Whether to truncate sequences that exceed `max_length`.
            pad_to_multiple_of (Optional[int]):
                Pad to a multiple of this value.
            return_attention_mask (Optional[bool]):
                Whether to return attention masks.
            return_tensors (Optional[Union[str, TensorType]]):
                The type of tensor(s) to be returned.

        Returns:
            Padded sequences of features.
        """
        pass  # Placeholder for method implementation

    def _pad(
        self,
        processed_features: Union[Dict[str, np.ndarray], BatchFeature],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ):
        """
        Internal method for padding sequences of features.

        Args:
            processed_features (Union[Dict[str, np.ndarray], BatchFeature]):
                The processed features to be padded.
            max_length (Optional[int]):
                Maximum length to pad or truncate the sequences.
            padding_strategy (PaddingStrategy):
                Strategy for padding. Default is DO_NOT_PAD.
            pad_to_multiple_of (Optional[int]):
                Pad to a multiple of this value.
            return_attention_mask (Optional[bool]):
                Whether to return attention masks.
        """
        pass  # Placeholder for method implementation

    def _truncate(
        self,
        processed_features: Union[Dict[str, np.ndarray], BatchFeature],
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        truncation: Optional[bool] = None,
    ):
        """
        Internal method for truncating sequences of features.

        Args:
            processed_features (Union[Dict[str, np.ndarray], BatchFeature]):
                The processed features to be truncated.
            max_length (Optional[int]):
                Maximum length to truncate the sequences.
            pad_to_multiple_of (Optional[int]):
                Pad to a multiple of this value.
            truncation (Optional[bool]):
                Whether to truncate sequences that exceed `max_length`.
        """
        pass  # Placeholder for method implementation
        """
        Truncate inputs to predefined length or max length in the batch

        Args:
            processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
                Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
                of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
            max_length (`int`, *optional*):
                maximum length of the returned list and optionally padding length (see below)
            pad_to_multiple_of (`int`, *optional*) :
                Integer if set will pad the sequence to a multiple of the provided value. This is especially useful to
                enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
                which benefit from having sequence lengths be a multiple of 128.
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
        """
        # If truncation is not requested, return the processed features unchanged
        if not truncation:
            return processed_features
        # If truncation is requested but no max_length was given, raise a ValueError
        elif truncation and max_length is None:
            raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")

        # Grab the main model input that may need truncating
        required_input = processed_features[self.model_input_names[0]]

        # Round max_length up to the next multiple of `pad_to_multiple_of` if needed
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        # Determine whether the input actually needs truncating
        needs_to_be_truncated = len(required_input) > max_length

        # If so, cut the input down to max_length
        if needs_to_be_truncated:
            processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
            # If an attention_mask is present, truncate it to the same length
            if "attention_mask" in processed_features:
                processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]

        # Return the processed features
        return processed_features
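
    # Worked example (hypothetical numbers): with `truncation=True`, `max_length=100` and
    # `pad_to_multiple_of=8`, the effective cut-off becomes ((100 // 8) + 1) * 8 = 104, so an
    # input of length 120 is cut to its first 104 frames (together with its attention_mask),
    # while an input of length 90 is returned unchanged.
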
    def _get_padding_strategies(self, padding=False, max_length=None):
        """
        Find the correct padding strategy
        """

        # Work out the padding strategy
        if padding is not False:
            if padding is True:
                padding_strategy = PaddingStrategy.LONGEST  # default: pad to the longest sequence in the batch
            elif not isinstance(padding, PaddingStrategy):
                padding_strategy = PaddingStrategy(padding)
            elif isinstance(padding, PaddingStrategy):
                padding_strategy = padding
        else:
            padding_strategy = PaddingStrategy.DO_NOT_PAD  # no padding

        # Set max_length if needed
        if max_length is None:
            if padding_strategy == PaddingStrategy.MAX_LENGTH:
                raise ValueError(
                    f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
                )

        # Make sure a padding value is available
        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
            raise ValueError(
                "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
                " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
            )

        return padding_strategy
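
    # Illustrative mapping of the `padding` argument handled above:
    #   padding=True          -> PaddingStrategy.LONGEST    (pad to the longest sequence in the batch)
    #   padding="longest"     -> PaddingStrategy.LONGEST
    #   padding="max_length"  -> PaddingStrategy.MAX_LENGTH (requires `max_length` to be set)
    #   padding=False         -> PaddingStrategy.DO_NOT_PAD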

.\feature_extraction_utils.py

# coding=utf-8
# Copyright notice and Apache License, Version 2.0: this file may not be used except in compliance with the License,
# a copy of which is available at http://www.apache.org/licenses/LICENSE-2.0. The software is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

"""
Feature extraction saving/loading class for common feature extractors.
"""

import copy  # deep-copy support
import json  # JSON handling
import os  # OS utilities
import warnings  # warning support
from collections import UserDict  # dict-like base class
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union  # typing helpers

import numpy as np  # NumPy

from .dynamic_module_utils import custom_object_save  # helper that saves custom object code next to the weights
from .utils import (  # internal utilities
    FEATURE_EXTRACTOR_NAME,
    PushToHubMixin,
    TensorType,
    add_model_info_to_auto_map,
    cached_file,
    copy_func,
    download_url,
    is_flax_available,
    is_jax_tensor,
    is_numpy_array,
    is_offline_mode,
    is_remote_url,
    is_tf_available,
    is_torch_available,
    is_torch_device,
    is_torch_dtype,
    logging,
    requires_backends,
)

if TYPE_CHECKING:  # only evaluated by static type checkers
    if is_torch_available():
        import torch  # imported for type checking only

logger = logging.get_logger(__name__)  # module-level logger

PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"]  # type alias for pretrained feature extractors

class BatchFeature(UserDict):  # batch-feature container, derived from UserDict
    r"""
    Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.

    This class is derived from a python dictionary and can be used as a dictionary.

    Args:
        data (`dict`, *optional*):
            Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
            etc.).
        tensor_type (`Union[None, str, TensorType]`, *optional*):
            You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
            initialization.
    """

    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
        super().__init__(data)  # initialize the underlying dict
        self.convert_to_tensors(tensor_type=tensor_type)  # optionally convert the data to tensors

    def __getitem__(self, item: str) -> Union[Any]:
        """
        If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
        etc.).
        """
        if isinstance(item, str):  # string keys index into the underlying dict
            return self.data[item]
        else:
            raise KeyError("Indexing with integers is not available when using Python based feature extractors")

    def __getattr__(self, item: str):
        try:
            return self.data[item]  # fall back to the underlying dict for attribute access
        except KeyError:
            raise AttributeError  # keep normal attribute-error semantics

    def __getstate__(self):
        return {"data": self.data}  # only the data dict needs to be pickled
    # Restore the object's state; if the pickled state contains a "data" field, assign it to self.data
    def __setstate__(self, state):
        if "data" in state:
            self.data = state["data"]

    # Mirror transformers.tokenization_utils_base.BatchEncoding.keys: return all keys of self.data
    def keys(self):
        return self.data.keys()

    # Mirror transformers.tokenization_utils_base.BatchEncoding.values: return all values of self.data
    def values(self):
        return self.data.values()

    # Mirror transformers.tokenization_utils_base.BatchEncoding.items: return all key/value pairs of self.data
    def items(self):
        return self.data.items()

    # Return the conversion and type-check functions matching the requested tensor_type
    def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
        if tensor_type is None:
            return None, None

        # Coerce tensor_type into a TensorType enum member
        if not isinstance(tensor_type, TensorType):
            tensor_type = TensorType(tensor_type)

        # Pick the framework matching tensor_type and get its conversion / check functions
        if tensor_type == TensorType.TENSORFLOW:
            # TensorFlow was requested: raise an ImportError if it is not installed
            if not is_tf_available():
                raise ImportError(
                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
                )
            import tensorflow as tf

            as_tensor = tf.constant  # TensorFlow conversion function
            is_tensor = tf.is_tensor  # TensorFlow type-check function
        elif tensor_type == TensorType.PYTORCH:
            # PyTorch was requested: raise an ImportError if it is not installed
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch  # noqa

            # PyTorch conversion function
            def as_tensor(value):
                if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
                    value = np.array(value)
                return torch.tensor(value)

            is_tensor = torch.is_tensor  # PyTorch type-check function
        elif tensor_type == TensorType.JAX:
            # JAX was requested: raise an ImportError if it is not installed
            if not is_flax_available():
                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
            import jax.numpy as jnp  # noqa: F811

            as_tensor = jnp.array  # JAX conversion function
            is_tensor = is_jax_tensor  # JAX type-check function
        else:
            # Fall back to NumPy for any other tensor_type
            def as_tensor(value, dtype=None):
                if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
                    value_lens = [len(val) for val in value]
                    if len(set(value_lens)) > 1 and dtype is None:
                        # handle ragged lists
                        value = as_tensor([np.asarray(val) for val in value], dtype=object)
                return np.asarray(value, dtype=dtype)

            is_tensor = is_numpy_array  # NumPy type-check function

        return is_tensor, as_tensor  # return the check function and the conversion function

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        """
        Convert the inner content to tensors.

        Args:
            tensor_type (`str` or [`~utils.TensorType`], *optional*):
                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
                `None`, no modification is done.
        """
        # If tensor_type is None, return the object unchanged
        if tensor_type is None:
            return self

        # Get the functions that check for / convert to the requested tensor type
        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        # Do the tensor conversion over the whole batch
        for key, value in self.items():
            try:
                # If the current value is not already a tensor, convert it
                if not is_tensor(value):
                    tensor = as_tensor(value)

                    # Store the converted tensor back under the same key
                    self[key] = tensor
            except:  # noqa E722
                # Handle failures, in particular overflowing values of different lengths
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                )

        # Return the converted object
        return self
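
    # Usage sketch (illustrative; the "input_values" key and the numbers are made up; pass
    # "pt", "tf" or "jax" instead of "np" to target PyTorch/TensorFlow/JAX when available):
    #
    #     batch = BatchFeature({"input_values": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]}, tensor_type="np")
    #     type(batch["input_values"])   # numpy.ndarray with shape (2, 3)
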
    def to(self, *args, **kwargs) -> "BatchFeature":
        """
        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
        different `dtypes` and sending the `BatchFeature` to a different `device`.

        Args:
            args (`Tuple`):
                Will be passed to the `to(...)` function of the tensors.
            kwargs (`Dict`, *optional*):
                Will be passed to the `to(...)` function of the tensors.

        Returns:
            [`BatchFeature`]: The same instance after modification.
        """
        # Ensure that PyTorch backend is available
        requires_backends(self, ["torch"])
        import torch  # noqa
        
        # Initialize a new dictionary for modified data
        new_data = {}
        
        # Retrieve the device from kwargs if available
        device = kwargs.get("device")
        
        # Check if the first argument in args is a device or dtype
        if device is None and len(args) > 0:
            # If device is not specified, the first argument in args is used
            arg = args[0]
            if is_torch_dtype(arg):
                # If the first argument is a PyTorch dtype
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                # If the first argument is a device or a dtype specifier
                device = arg
            else:
                # If the first argument is of an unsupported type
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
        
        # Iterate over key-value pairs in the current instance
        for k, v in self.items():
            # Check if the value v is a floating point tensor
            if torch.is_floating_point(v):
                # If v is floating point, cast and send it to the specified device or dtype
                new_data[k] = v.to(*args, **kwargs)
            elif device is not None:
                # If a device is specified, send v to that device
                new_data[k] = v.to(device=device)
            else:
                # Otherwise, retain v as it is
                new_data[k] = v
        
        # Update the data attribute of the instance with the modified data
        self.data = new_data
        
        # Return the modified instance of BatchFeature
        return self
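
# Minimal usage sketch (assumes PyTorch is installed; the "input_values" key and the literal
# values below are made up purely for illustration):
import torch

example_batch = BatchFeature({"input_values": [[0.1, 0.2], [0.3, 0.4]]}, tensor_type="pt")
example_batch = example_batch.to(torch.float16)  # floating-point tensors are cast to float16
example_batch = example_batch.to("cpu")          # every tensor is moved to the requested device
print(example_batch["input_values"].dtype)       # torch.float16
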
    """
    # 这是一个特征提取的 Mixin 类,用于为顺序数据和图像特征提取器提供保存和加载功能。
    """

    _auto_class = None

    def __init__(self, **kwargs):
        """Set elements of `kwargs` as attributes."""
        # Pop "processor_class" into its private attribute, used to keep track of the processor class
        self._processor_class = kwargs.pop("processor_class", None)
        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    def _set_processor_class(self, processor_class: str):
        """Sets processor class as an attribute."""
        self._processor_class = processor_class

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ):
        """
        Instantiate the class from a pretrained model name or path, configuring it with the related parameters.
        """
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        use_auth_token = kwargs.pop("use_auth_token", None)

        # Handle deprecated `use_auth_token` argument
        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            # Raise an error if both `token` and `use_auth_token` are specified
            if kwargs.get("token", None) is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token

        # Assert that the provided path is a directory, not a file
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        # Create the directory if it does not exist
        os.makedirs(save_directory, exist_ok=True)

        # If push_to_hub is True, prepare to push the model to the model hub
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # Determine the repository ID from the save_directory name
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            # Create or get the repository ID for the model
            repo_id = self._create_repo(repo_id, **kwargs)
            # Get timestamps of files in save_directory for tracking changes
            files_timestamps = self._get_files_timestamps(save_directory)

        # If there's a custom config, save it in the directory
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        # Save the feature extractor JSON file in save_directory
        output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
        self.to_json_file(output_feature_extractor_file)
        logger.info(f"Feature extractor saved in {output_feature_extractor_file}")

        # If push_to_hub is True, upload modified files to the model hub
        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

        # Return the list containing the path to the saved feature extractor file
        return [output_feature_extractor_file]
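
    # Round-trip sketch (illustrative; Wav2Vec2FeatureExtractor is just a convenient concrete
    # subclass, and the local path is made up):
    #
    #     from transformers import Wav2Vec2FeatureExtractor
    #     fe = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0)
    #     fe.save_pretrained("./my_feature_extractor")   # writes preprocessor_config.json
    #     fe2 = Wav2Vec2FeatureExtractor.from_pretrained("./my_feature_extractor")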

    @classmethod
    def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
        """
        Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
        parameters.

        Args:
            feature_extractor_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the
                [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the feature extractor object.

        Returns:
            [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
            parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # Build the feature extractor from the feature_extractor_dict dictionary
        feature_extractor = cls(**feature_extractor_dict)

        # Update feature_extractor with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(feature_extractor, key):
                setattr(feature_extractor, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        # Log the feature extractor that was created
        logger.info(f"Feature extractor {feature_extractor}")

        # Also return the unused kwargs if requested
        if return_unused_kwargs:
            return feature_extractor, kwargs
        else:
            return feature_extractor

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # 深拷贝对象的属性到 output 字典
        output = copy.deepcopy(self.__dict__)
        output["feature_extractor_type"] = self.__class__.__name__

        # 如果存在 "mel_filters" 属性,则从 output 字典中删除
        if "mel_filters" in output:
            del output["mel_filters"]

        # 如果存在 "window" 属性,则从 output 字典中删除
        if "window" in output:
            del output["window"]

        return output
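    # Dict round-trip sketch (assuming `fe` is an existing feature extractor instance):
    #
    #     cfg = fe.to_dict()                 # plain dict, includes "feature_extractor_type"
    #     fe2 = type(fe).from_dict(cfg)      # rebuilds an equivalent instance
    #     # Passing return_unused_kwargs=True to from_dict also returns any kwargs that did
    #     # not correspond to an attribute of the feature extractor.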

    @classmethod
    # 类方法:从 JSON 文件中实例化特征提取器对象
    def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
        """
        Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
        a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
            object instantiated from that JSON file.
        """
        # 从 JSON 文件中读取参数文本
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()

        # 将 JSON 文本解析为字典形式的特征提取器参数
        feature_extractor_dict = json.loads(text)

        # 使用特征提取器参数字典创建特征提取器对象并返回
        return cls(**feature_extractor_dict)
    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
        """
        # 将对象转换为字典表示
        dictionary = self.to_dict()

        # 将所有 numpy 数组转换为 Python 列表
        for key, value in dictionary.items():
            if isinstance(value, np.ndarray):
                dictionary[key] = value.tolist()

        # 确保私有名称 "_processor_class" 保存为 "processor_class"
        _processor_class = dictionary.pop("_processor_class", None)
        if _processor_class is not None:
            dictionary["processor_class"] = _processor_class

        # 将字典转换为带有缩进和排序的 JSON 字符串,最后加上换行符
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this feature_extractor instance's parameters will be saved.
        """
        # 将对象序列化为 JSON 字符串,写入指定路径的文件中
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
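    # JSON round-trip sketch (hypothetical file name):
    #
    #     fe.to_json_file("preprocessor_config.json")
    #     restored = type(fe).from_json_file("preprocessor_config.json")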

    def __repr__(self):
        # 返回类的字符串表示形式,包括其 JSON 序列化的内容
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
        """
        Register this class with a given auto class. This should only be used for custom feature extractors as the ones
        in the library are already mapped with `AutoFeatureExtractor`.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
                The auto class to register this new feature extractor with.
        """
        # 如果 auto_class 不是字符串,则取其类名作为字符串
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        # 导入自动模块,并检查是否存在指定的 auto_class
        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            # 如果找不到对应的 auto_class,则抛出错误
            raise ValueError(f"{auto_class} is not a valid auto class.")

        # 将 auto_class 设置为类属性 _auto_class
        cls._auto_class = auto_class
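    # Registration sketch (the subclass name is hypothetical):
    #
    #     class MyFeatureExtractor(FeatureExtractionMixin):
    #         ...
    #
    #     MyFeatureExtractor.register_for_auto_class()   # defaults to "AutoFeatureExtractor"
    #
    # Afterwards AutoFeatureExtractor can resolve the custom class when the checkpoint is
    # loaded with trust_remote_code=True.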
# 将 FeatureExtractionMixin 类中的 push_to_hub 方法复制一份,使其成为独立的新函数
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)

# 如果 push_to_hub 方法已经有文档字符串(即注释),则对其进行格式化,填充特定的对象信息
if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
    FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
        object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
    )

.\file_utils.py

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File utilities: utilities related to download and cache models

This module should not be updated anymore and is only left for backward compatibility.
"""

# 导入获取完整仓库名称的函数,用于向后兼容
from huggingface_hub import get_full_repo_name  # for backward compatibility
# 导入禁用遥测的常量,用于向后兼容
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY  # for backward compatibility

# 导入当前模块的版本信息
from . import __version__

# 向后兼容的导入,确保所有这些对象在file_utils中可以找到
from .utils import (
    CLOUDFRONT_DISTRIB_PREFIX,
    CONFIG_NAME,
    DUMMY_INPUTS,
    DUMMY_MASK,
    ENV_VARS_TRUE_AND_AUTO_VALUES,
    ENV_VARS_TRUE_VALUES,
    FEATURE_EXTRACTOR_NAME,
    FLAX_WEIGHTS_NAME,
    HF_MODULES_CACHE,
    HUGGINGFACE_CO_PREFIX,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    MODEL_CARD_NAME,
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    PYTORCH_PRETRAINED_BERT_CACHE,
    PYTORCH_TRANSFORMERS_CACHE,
    S3_BUCKET_PREFIX,
    SENTENCEPIECE_UNDERLINE,
    SPIECE_UNDERLINE,
    TF2_WEIGHTS_NAME,
    TF_WEIGHTS_NAME,
    TORCH_FX_REQUIRED_VERSION,
    TRANSFORMERS_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    USE_JAX,
    USE_TF,
    USE_TORCH,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    ContextManagers,
    DummyObject,
    EntryNotFoundError,
    ExplicitEnum,
    ModelOutput,
    PaddingStrategy,
    PushToHubMixin,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    TensorType,
    _LazyModule,
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    cached_property,
    copy_func,
    default_cache_path,
    define_sagemaker_information,
    get_cached_models,
    get_file_from_repo,
    get_torch_version,
    has_file,
    http_user_agent,
    is_apex_available,
    is_bs4_available,
    is_coloredlogs_available,
    is_datasets_available,
    is_detectron2_available,
    is_faiss_available,
    is_flax_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_in_notebook,
    is_ipex_available,
    is_librosa_available,
    is_offline_mode,
    is_onnx_available,
    is_pandas_available,
    is_phonemizer_available,
    is_protobuf_available,
    is_psutil_available,
    is_py3nvml_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytorch_quantization_available,
    is_rjieba_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_scipy_available,  # 检查是否安装了 SciPy 库
    is_sentencepiece_available,  # 检查是否安装了 SentencePiece 库
    is_seqio_available,  # 检查是否安装了 SeqIO 库
    is_sklearn_available,  # 检查是否安装了 Scikit-learn 库
    is_soundfile_availble,  # 检查是否安装了 SoundFile 库
    is_spacy_available,  # 检查是否安装了 spaCy 库
    is_speech_available,  # 检查是否安装了 speech 库
    is_tensor,  # 检查是否是张量(tensor)
    is_tensorflow_probability_available,  # 检查是否安装了 TensorFlow Probability 库
    is_tf2onnx_available,  # 检查是否安装了 tf2onnx 库
    is_tf_available,  # 检查是否安装了 TensorFlow 库
    is_timm_available,  # 检查是否安装了 timm 库
    is_tokenizers_available,  # 检查是否安装了 tokenizers 库
    is_torch_available,  # 检查是否安装了 PyTorch 库
    is_torch_bf16_available,  # 检查是否安装了 PyTorch BF16 库
    is_torch_cuda_available,  # 检查是否安装了 PyTorch CUDA 支持
    is_torch_fx_available,  # 检查是否安装了 PyTorch FX 库
    is_torch_fx_proxy,  # 检查是否安装了 PyTorch FX 代理
    is_torch_mps_available,  # 检查是否安装了 PyTorch MPS 库
    is_torch_tf32_available,  # 检查是否安装了 PyTorch TF32 支持
    is_torch_xla_available,  # 检查是否安装了 PyTorch XLA 支持
    is_torchaudio_available,  # 检查是否安装了 torchaudio 库
    is_training_run_on_sagemaker,  # 检查是否在 SageMaker 上运行训练
    is_vision_available,  # 检查是否安装了视觉相关库
    replace_return_docstrings,  # 替换返回值的文档字符串
    requires_backends,  # 需要的后端库
    to_numpy,  # 转换为 NumPy 格式
    to_py_obj,  # 转换为 Python 对象
    torch_only_method,  # 仅限于 PyTorch 的方法
)

.\generation\beam_constraints.py

# 导入必要的库
from abc import ABC, abstractmethod
from typing import List, Optional

# 定义抽象基类 Constraint
class Constraint(ABC):
    r"""Abstract base class for all constraints that can be applied during generation.
    It must define how the constraint can be satisfied.

    All classes that inherit Constraint must follow the requirement that

    ```
    completed = False
    while not completed:
        _, completed = constraint.update(constraint.advance())
    ```

    will always terminate (halt).
    """

    def __init__(self):
        # 调用 test 方法以测试约束条件
        self.test()

    def test(self):
        """
        Tests whether this constraint has been properly defined.
        """
        # 初始化计数器和完成标志
        counter = 0
        completed = False
        # 进入循环,直到约束条件被满足或超过最大尝试次数
        while not completed:
            # 如果计数器为1,调用 reset 方法
            if counter == 1:
                self.reset()
            # 调用 advance 方法获取进展信息
            advance = self.advance()
            # 检查进展是否符合要求
            if not self.does_advance(advance):
                # 若不符合要求则抛出异常
                raise Exception(
                    "Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
                )

            # 调用 update 方法获取更新后的状态
            stepped, completed, reset = self.update(advance)
            counter += 1

            # 如果超过最大尝试次数,抛出异常
            if counter > 10000:
                raise Exception("update() does not fulfill the constraint.")

        # 检查约束是否全部满足
        if self.remaining() != 0:
            raise Exception("Custom Constraint is not defined correctly.")

    @abstractmethod
    def advance(self):
        """
        When called, returns the token that would take this constraint one step closer to being fulfilled.

        Return:
            token_ids(`torch.tensor`): Must be a tensor of a list of indexable tokens, not some integer.
        """
        # 抽象方法,子类必须实现,返回一个可以使约束条件向满足状态推进的 token
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @abstractmethod
    def does_advance(self, token_id: int):
        """
        Reads in a token and returns whether it creates progress.
        """
        # 抽象方法,子类必须实现,读取一个 token 并返回它是否推进了约束条件
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @abstractmethod
    def update(self, token_id: int):
        """
        Reads in a token and returns booleans that indicate the progress made by it. This function will update the
        state of this object unlike `does_advance(self, token_id: int)`.

        This isn't to test whether a certain token will advance the progress; it's to update its state as if it has
        been generated. This becomes important if token_id != desired token (refer to else statement in
        PhrasalConstraint)

        Args:
            token_id(`int`):
                The id of a newly generated token in the beam search.
        Return:
            stepped(`bool`):
                Whether this constraint has become one step closer to being fulfilled.
            completed(`bool`):
                Whether this constraint has been completely fulfilled by this token being generated.
            reset (`bool`):
                Whether this constraint has reset its progress by this token being generated.
        """
        # 抛出未实现错误,表明这是一个抽象类方法,只能被继承该类的类调用。
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @abstractmethod
    def reset(self):
        """
        Resets the state of this constraint to its initialization. We would call this in cases where the fulfillment of
        a constraint is aborted by an unwanted token.
        """
        # 抛出未实现错误,表明这是一个抽象类方法,只能被继承该类的类调用。
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @abstractmethod
    def remaining(self):
        """
        Returns the number of remaining steps of `advance()` in order to complete this constraint.
        """
        # 抛出未实现错误,表明这是一个抽象类方法,只能被继承该类的类调用。
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @abstractmethod
    def copy(self, stateful=False):
        """
        Creates a new instance of this constraint.

        Args:
            stateful(`bool`): Whether to not only copy the constraint for new instance, but also its state.

        Return:
            constraint(`Constraint`): The same constraint as the one being called from.
        """
        # 抛出未实现错误,表明这是一个抽象类方法,只能被继承该类的类调用。
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )
class PhrasalConstraint(Constraint):
    r"""
    [`Constraint`] enforcing that an ordered sequence of tokens is included in the output.

    Args:
        token_ids (`List[int]`):
            The id of the token that must be generated by the output.
    """

    def __init__(self, token_ids: List[int]):
        super(Constraint, self).__init__()  # 调用父类 Constraint 的构造函数

        if not isinstance(token_ids, list) or len(token_ids) == 0:
            raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
        if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
            raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")

        self.token_ids = token_ids  # 将参数 token_ids 赋给实例变量 self.token_ids

        self.seqlen = len(self.token_ids)  # 记录 token_ids 的长度,即要求的序列的长度
        self.fulfilled_idx = -1  # 当前已满足的步骤的索引,初始为 -1 表示还未开始
        self.completed = False  # 标志变量,指示约束是否已经完成

    def advance(self):
        if self.completed:
            return None  # 如果约束已完成,则返回 None
        return self.token_ids[self.fulfilled_idx + 1]  # 返回下一个需要满足的 token_id

    def does_advance(self, token_id: int):
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")

        if self.completed:
            return False  # 如果约束已完成,则返回 False

        return token_id == self.token_ids[self.fulfilled_idx + 1]  # 检查是否可以满足下一个 token_id

    def update(self, token_id: int):
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")

        stepped = False  # 标志是否成功迈出一步
        completed = False  # 标志是否完成了所有步骤
        reset = False  # 标志是否需要重置状态

        if self.does_advance(token_id):
            self.fulfilled_idx += 1  # 成功满足下一个步骤
            stepped = True
            if self.fulfilled_idx == (self.seqlen - 1):
                completed = True  # 如果已经满足所有步骤,标记为完成
            self.completed = completed
        else:
            # 未能取得进展,需要重置状态
            reset = True
            self.reset()  # 调用 reset 方法重置状态
        return stepped, completed, reset  # 返回操作的结果信息

    def reset(self):
        self.completed = False  # 将完成标志重置为 False
        self.fulfilled_idx = 0  # 将已满足的步骤索引重置为初始状态

    def remaining(self):
        return self.seqlen - (self.fulfilled_idx + 1)  # 返回剩余待满足的步骤数量

    def copy(self, stateful=False):
        new_constraint = PhrasalConstraint(self.token_ids)  # 创建一个新的 PhrasalConstraint 对象

        if stateful:
            new_constraint.seq_len = self.seqlen  # 如果需要复制状态,将状态信息复制到新对象
            new_constraint.fulfilled_idx = self.fulfilled_idx
            new_constraint.completed = self.completed

        return new_constraint  # 返回新创建的约束对象
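
# Minimal sketch of driving a PhrasalConstraint by hand, following the while-loop contract
# described in the Constraint base class (token ids are made up for illustration):
#
#     constraint = PhrasalConstraint([5, 9, 2])
#     completed = False
#     while not completed:
#         token = constraint.advance()            # yields 5, then 9, then 2
#         _, completed, _ = constraint.update(token)
#     assert constraint.remaining() == 0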


class DisjunctiveTrie:
    def __init__(self, nested_token_ids: List[List[int]], no_subsets=True):
        r"""
        A helper class that builds a trie with the words represented in `nested_token_ids`.
        """
        # 计算嵌套列表中每个子列表的最大长度,作为树的最大高度
        self.max_height = max([len(one) for one in nested_token_ids])

        # 初始化树的根节点为空字典
        root = {}
        # 遍历嵌套的token_ids列表
        for token_ids in nested_token_ids:
            level = root
            # 遍历每个token_id构建trie
            for tidx, token_id in enumerate(token_ids):
                if token_id not in level:
                    level[token_id] = {}  # 如果token_id不存在当前层级,创建一个空字典

                level = level[token_id]  # 移动到下一个层级

        # 如果指定了不允许子集,并且存在子集关系,则抛出异常
        if no_subsets and self.has_subsets(root, nested_token_ids):
            raise ValueError(
                "Each list in `nested_token_ids` can't be a complete subset of another list, but is"
                f" {nested_token_ids}."
            )

        self.trie = root  # 将构建好的trie作为对象的trie属性保存

    def next_tokens(self, current_seq):
        """
        The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.
        """
        start = self.trie

        # 遍历当前序列中的每个token,向下移动trie
        for current_token in current_seq:
            start = start[current_token]

        # 获取当前trie节点的所有子节点,作为下一个可能的token
        next_tokens = list(start.keys())

        return next_tokens

    def reached_leaf(self, current_seq):
        # 获取当前序列的下一个可能token集合
        next_tokens = self.next_tokens(current_seq)

        # 如果下一个可能token集合为空,表示已经达到叶子节点
        return len(next_tokens) == 0

    def count_leaves(self, root):
        # 获取当前节点的所有子节点
        next_nodes = list(root.values())

        # 如果当前节点没有子节点,返回1,表示当前节点是叶子节点
        if len(next_nodes) == 0:
            return 1
        else:
            # 否则,递归计算所有子节点的叶子节点总数,并返回
            return sum([self.count_leaves(nn) for nn in next_nodes])

    def has_subsets(self, trie, nested_token_ids):
        """
        Returns whether # of leaves == # of words. Otherwise some word is a subset of another.
        """
        # 计算trie中的叶子节点数目
        leaf_count = self.count_leaves(trie)

        # 如果trie中的叶子节点数不等于嵌套token_ids的长度,则存在子集关系
        return len(nested_token_ids) != leaf_count
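
# Small sketch of the trie behaviour (token ids are made up):
#
#     trie = DisjunctiveTrie([[1, 2, 3], [1, 4]])
#     trie.next_tokens([1])        # -> [2, 4]
#     trie.next_tokens([1, 2])     # -> [3]
#     trie.reached_leaf([1, 4])    # -> True
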
class DisjunctiveConstraint(Constraint):
    r"""
    A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.

    Args:
        nested_token_ids (`List[List[int]]`):
            A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
            the list of words.
    """

    def __init__(self, nested_token_ids: List[List[int]]):
        # 调用父类构造函数初始化
        super(Constraint, self).__init__()

        # 检查输入的 nested_token_ids 是否为非空列表
        if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
            raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
        
        # 检查 nested_token_ids 中的每个元素是否为列表
        if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
            raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
        
        # 检查 nested_token_ids 中的每个元素是否为正整数
        if any(
            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
            for token_ids in nested_token_ids
        ):
            raise ValueError(
                f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
            )

        # 使用 nested_token_ids 创建一个 DisjunctiveTrie 对象
        self.trie = DisjunctiveTrie(nested_token_ids)
        self.token_ids = nested_token_ids  # 将 nested_token_ids 存储到实例变量中

        # 计算 trie 的最大高度并存储到实例变量中
        self.seqlen = self.trie.max_height
        self.current_seq = []  # 初始化当前序列为空列表
        self.completed = False  # 标记约束条件是否已完成

    def advance(self):
        # 获取当前序列可以继续的下一个 token 列表
        token_list = self.trie.next_tokens(self.current_seq)

        if len(token_list) == 0:
            return None  # 如果没有可继续的 token,则返回 None
        else:
            return token_list  # 否则返回可继续的 token 列表

    def does_advance(self, token_id: int):
        # 检查给定的 token_id 是否可以在当前序列中继续
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

        next_tokens = self.trie.next_tokens(self.current_seq)

        return token_id in next_tokens  # 返回 token_id 是否在可继续的 token 列表中

    def update(self, token_id: int):
        # 更新当前序列,并返回是否有步进、是否完成、是否重置的标志
        if not isinstance(token_id, int):
            raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

        stepped = False
        completed = False
        reset = False

        if self.does_advance(token_id):
            self.current_seq.append(token_id)  # 如果可以继续,则将 token_id 添加到当前序列中
            stepped = True
        else:
            reset = True
            self.reset()  # 否则重置当前序列

        completed = self.trie.reached_leaf(self.current_seq)  # 检查当前序列是否达到叶节点
        self.completed = completed  # 更新约束条件是否已完成的状态

        return stepped, completed, reset  # 返回步进、完成和重置的标志

    def reset(self):
        # 重置当前序列和完成状态
        self.completed = False
        self.current_seq = []

    def remaining(self):
        if self.completed:
            return 0  # 如果约束条件已完成,则剩余长度为 0
        else:
            return self.seqlen - len(self.current_seq)  # 否则返回剩余的最大长度与当前序列长度的差值
    # 定义一个方法 `copy`,用于创建当前对象的副本
    def copy(self, stateful=False):
        # 创建一个新的 DisjunctiveConstraint 对象,使用当前对象的 token_ids 初始化
        new_constraint = DisjunctiveConstraint(self.token_ids)

        # 如果 stateful 参数为 True,则复制当前对象的状态到新对象中
        if stateful:
            new_constraint.seq_len = self.seqlen  # 复制当前对象的 seq_len 属性
            new_constraint.current_seq = self.current_seq  # 复制当前对象的 current_seq 属性
            new_constraint.completed = self.completed  # 复制当前对象的 completed 属性

        # 返回新创建的对象副本
        return new_constraint
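
# Usage sketch: a DisjunctiveConstraint is satisfied by generating any one of several word
# variants (token ids are made up, e.g. two alternative tokenizations of the same word):
#
#     constraint = DisjunctiveConstraint([[101, 7], [101, 8, 9]])
#     constraint.advance()       # -> [101]
#     constraint.update(101)     # stepped, not yet completed
#     constraint.advance()       # -> [7, 8]
#     constraint.update(7)       # completed: [101, 7] reaches a leaf of the trie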
class ConstraintListState:
    r"""
    A class for beam scorers to track its progress through a list of constraints.

    Args:
        constraints (`List[Constraint]`):
            A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
    """

    def __init__(self, constraints: List[Constraint]):
        self.constraints = constraints

        # max # of steps required to fulfill a given constraint
        self.max_seqlen = max([c.seqlen for c in constraints])  # 计算所有约束中的最大步数
        self.n_constraints = len(constraints)  # 约束数量
        self.completed = False  # 标志位,表示是否完成

        self.init_state()  # 初始化状态

    def init_state(self):
        self.complete_constraints = []  # 已完成的约束列表
        self.inprogress_constraint = None  # 当前进行中的约束
        self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]  # 待处理的约束列表,复制并标记为非状态化

    def get_bank(self):
        add = 0
        if self.inprogress_constraint:
            # extra points for having a constraint mid-fulfilled
            add += self.max_seqlen - self.inprogress_constraint.remaining()  # 如果存在进行中的约束,计算其剩余步数对应的额外分数

        return (len(self.complete_constraints) * self.max_seqlen) + add  # 返回当前已完成约束的总步数加上额外分数

    def advance(self):
        """The list of tokens to generate such that we can make progress.
        By "list" we don't mean the list of token that will fully fulfill a constraint.

        Given constraints `c_i = {t_ij | j == # of tokens}`, If we're not in the middle of progressing through a
        specific constraint `c_i`, we return:

        `[t_k1 for k in indices of unfulfilled constraints]`

        If we are in the middle of a constraint, then we return:
            `[t_ij]`, where `i` is the index of the inprogress constraint, `j` is the next step for the constraint.

        Though we don't care which constraint is fulfilled first, if we are in the progress of fulfilling a constraint,
        that's the only one we'll return.
        """
        token_list = []
        if self.inprogress_constraint is None:
            for constraint in self.pending_constraints:  # 遍历待处理的约束
                advance = constraint.advance()  # 获取约束的推进状态
                if isinstance(advance, int):
                    token_list.append(advance)  # 如果推进状态是整数,直接添加到 token_list
                elif isinstance(advance, list):
                    token_list.extend(advance)  # 如果推进状态是列表,扩展到 token_list
        else:
            advance = self.inprogress_constraint.advance()  # 获取当前进行中约束的推进状态
            if isinstance(advance, int):
                token_list.append(advance)  # 如果推进状态是整数,直接添加到 token_list
            elif isinstance(advance, list):
                token_list.extend(advance)  # 如果推进状态是列表,扩展到 token_list

        if len(token_list) == 0:
            return None  # 如果 token_list 为空,返回 None
        else:
            return token_list  # 否则返回 token_list
    def reset(self, token_ids: Optional[List[int]]):
        """
        重置对象状态,根据给定的token_ids重新设置约束的进度状态。
        token_ids: 到目前为止生成的令牌,用于重置通过约束的进度状态。
        """
        self.init_state()  # 调用初始化状态方法

        if token_ids is not None:
            for token in token_ids:
                # 添加一个令牌,完成或推进一个约束
                complete, stepped = self.add(token)

                # 如果所有约束已完成,则退出循环
                if self.completed:
                    break

    def copy(self, stateful=True):
        """
        创建并返回一个当前对象的副本,可以选择是否保持状态。
        stateful: 是否保持状态,默认为True。
        """
        new_state = ConstraintListState(self.constraints)  # 使用当前约束列表创建新的状态对象

        if stateful:
            # 复制已完成的约束列表中的每个约束对象
            new_state.complete_constraints = [
                constraint.copy(stateful=True) for constraint in self.complete_constraints
            ]
            # 如果存在正在进行中的约束,则复制该约束对象的状态副本
            if self.inprogress_constraint is not None:
                new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
            # 复制待处理约束列表中的每个约束对象
            new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]

        return new_state
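
# How these constraint objects are typically consumed: they are passed to `generate()`
# together with beam search, e.g. (sketch; the checkpoint name is a placeholder):
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PhrasalConstraint
#
#     tokenizer = AutoTokenizer.from_pretrained("t5-small")
#     model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#     force_ids = tokenizer("la maison", add_special_tokens=False).input_ids
#     inputs = tokenizer("translate English to French: the house", return_tensors="pt")
#     out = model.generate(**inputs, constraints=[PhrasalConstraint(force_ids)], num_beams=4)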

.\generation\beam_search.py

# 导入必要的模块和库
from abc import ABC, abstractmethod  # 导入抽象基类和抽象方法装饰器
from collections import UserDict  # 导入用户自定义字典类
from typing import Dict, List, Optional, Tuple, Union  # 导入类型提示

import numpy as np  # 导入 NumPy 库
import torch  # 导入 PyTorch 库

from ..utils import add_start_docstrings  # 从上级目录的 utils 模块导入 add_start_docstrings 函数
from .beam_constraints import Constraint, ConstraintListState  # 从当前目录的 beam_constraints 模块导入 Constraint 和 ConstraintListState 类

# 定义常量,该常量包含一个多行的文档字符串,用于描述函数 process_inputs 的参数和返回值
PROCESS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
            Current scores of the top `2 * num_beams` non-finished beam hypotheses.
        next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
            `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
        next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
            Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`Union[int, List[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
        beam_indices (`torch.LongTensor`, *optional*):
            Beam indices indicating to which beam hypothesis each token correspond.
        group_index (`int`, *optional*):
            The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].

    Return:
        `UserDict`: A dictionary composed of the fields as defined above:

            - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
              non-finished beams.
            - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
              to the non-finished beam_hypotheses.
            - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
              indicating to which beam the next tokens shall be added.
FINALIZE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
            The final scores of all non-finished beams.
        final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
            The last tokens to be added to the non-finished beam_hypotheses.
        final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
            The beam indices indicating to which beam the `final_beam_tokens` shall be added.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`Union[int, List[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

    Return:
        `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
        The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
        due to the `eos_token_id`.

"""


class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
    [`~PreTrainedModel.beam_sample`].
    """

    @abstractmethod
    @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)  # 添加输入处理方法的文档字符串
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs,
    ) -> Tuple[torch.Tensor]:
        raise NotImplementedError("This is an abstract method.")

    @abstractmethod
    @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)  # 添加最终处理方法的文档字符串
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        max_length: int,
        **kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError("This is an abstract method.")


class BeamSearchScorer(BeamScorer):
    r"""
    [`BeamScorer`] implementing standard beam search decoding.

    Args:
        batch_size (`int`):
            Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
        num_beams (`int`):
            Number of beams for beam search.
        device (`torch.device`):
            Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
            allocated.
        length_penalty (`float`, *optional*, defaults to 1.0):
            Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
            the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
            likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
            `length_penalty` < 0.0 encourages shorter sequences.
        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very unlikely to find better candidates;
            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
            beam search algorithm).
        num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            [`~transformers.BeamSearchScorer.finalize`].
        num_beam_groups (`int`, *optional*, defaults to 1):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
        max_length (`int`, *optional*):
            The maximum length of the sequence to be generated.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[Union[bool, str]] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        max_length: Optional[int] = None,
        ):
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
        # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.group_size,  # 创建 BeamHypotheses 对象,设置每个组的 beam 数量
                length_penalty=self.length_penalty,  # 设置长度惩罚因子
                early_stopping=self.do_early_stopping,  # 设置是否提前停止
                max_length=max_length,  # 设置最大生成长度
            )
            for _ in range(batch_size * self.num_beam_groups)  # 根据 mini-batch 大小和组数创建多个 BeamHypotheses 对象
        ]
        # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
        # in the i-th mini-batch is complete.
        self._done = torch.tensor(
            [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device  # 创建表示生成是否完成的张量
        )

        if not isinstance(num_beams, int) or num_beams <= 1:  # 检查 num_beams 是否为大于1的整数
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

    @property
    def is_done(self) -> bool:
        return self._done.all()  # 返回是否所有生成操作均完成的布尔值
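    # In practice this scorer is constructed and driven internally by `generate()`; the
    # parameters documented above map onto the usual generation arguments, e.g. (sketch):
    #
    #     model.generate(input_ids, num_beams=6, num_beam_groups=3, diversity_penalty=1.0,
    #                    length_penalty=1.2, early_stopping=True)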

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        group_index: Optional[int] = 0,
        decoder_prompt_len: Optional[int] = 0,
    ):  # 定义一个处理生成过程的方法,接受多个参数

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        decoder_prompt_len: Optional[int] = 0,
    ):  # 定义一个完成生成过程的方法,接受多个参数
# Define a new class `ConstrainedBeamSearchScorer` that inherits from `BeamScorer`
class ConstrainedBeamSearchScorer(BeamScorer):
    r"""
    [`BeamScorer`] implementing constrained beam search decoding.
    实现受限束搜索解码的 [`BeamScorer`]。
    

    Args:
        batch_size (`int`):
            Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
            输入 `input_ids` 的批处理大小,用于并行运行标准的束搜索解码。
        num_beams (`int`):
            Number of beams for beam search.
            束搜索的束数。
        constraints (`List[Constraint]`):
            A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
            output. For more information, the documentation of [`Constraint`] should be read.
            表示为 `Constraint` 对象的正约束列表,必须在生成的输出中满足。有关更多信息,请阅读 [`Constraint`] 的文档。
        device (`torch.device`):
            Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
            allocated.
            定义此 `BeamSearchScorer` 实例将分配到的设备类型(例如 `"cpu"` 或 `"cuda"`)。
        length_penalty (`float`, *optional*, defaults to 1.0):
            Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
            the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
            likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
            `length_penalty` < 0.0 encourages shorter sequences.
            用于基于束的生成的长度的指数惩罚。它作为序列长度的指数应用,进而用于分割序列的分数。由于分数是序列的对数似然(即负数),`length_penalty` > 0.0 促进更长的序列,而 `length_penalty` < 0.0 鼓励更短的序列。
        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
            heuristic is applied and the generation stops when is it very unlikely to find better candidates;
            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
            beam search algorithm).
            控制基于束的方法(如束搜索)的停止条件。它接受以下值:`True`,生成在有 `num_beams` 个完整候选时停止;`False`,应用启发式并在很不可能找到更好的候选时停止生成;`"never"`,束搜索过程仅在不能有更好的候选时停止(经典的束搜索算法)。
        num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            [`~transformers.BeamSearchScorer.finalize`].
            在调用 [`~transformers.BeamSearchScorer.finalize`] 时将返回的束假设数。
        num_beam_groups (`int`, *optional*, defaults to 1):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
            为了确保不同组的束之间的多样性,将 `num_beams` 分成的组数。有关更多详细信息,请参见 [此文献](https://arxiv.org/pdf/1610.02424.pdf)。
        max_length (`int`, *optional*):
            The maximum length of the sequence to be generated.
            要生成的序列的最大长度。
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        constraints: List[Constraint],
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[Union[bool, str]] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        max_length: Optional[int] = None,
        ):
        # 初始化 BeamSearch 类的实例
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups
        self.constraints = constraints

        self._is_init = False
        # 初始化 `_beam_hyps` 属性,存储 BeamHypotheses 的列表
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
                max_length=max_length,
            )
            for _ in range(batch_size)
        ]
        # 初始化 `_done` 属性为 torch tensor,表示是否完成的状态
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        # 检查 `num_beams` 是否是正整数且大于 1,否则抛出异常
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

        # 检查 `num_beam_groups` 是否是正整数且满足条件,否则抛出异常
        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

    @property
    def is_done(self) -> bool:
        # 返回 `_done` 属性是否全部为 True
        return self._done.all()

    def make_constraint_states(self, n):
        # 根据约束条件创建状态列表的实例,返回列表
        return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]

    def check_completes_constraints(self, sequence):
        # 创建约束状态的实例,并重置为给定的序列,返回是否完成的布尔值
        new_state = self.make_constraint_states(1)[0]
        new_state.reset(sequence)
        return new_state.completed

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        scores_for_all_vocab: torch.FloatTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        decoder_prompt_len: Optional[int] = 0,
    ):
        # 处理 beam search 的每个步骤,计算下一个可能的 token
        ...

    def step_sentence_constraint(
        self,
        batch_idx: int,
        input_ids: torch.LongTensor,
        vocab_scores: torch.FloatTensor,
        sent_beam_scores: torch.FloatTensor,
        sent_beam_tokens: torch.LongTensor,
        sent_beam_indices: torch.LongTensor,
        push_progress: bool = False,
        ):
        # 执行句子级别的约束步骤,更新相关的输入状态
        ...
    # 定义一个方法 finalize,用于处理束搜索的结果并生成最终的输出序列
    def finalize(
        self,
        # 输入的 token IDs,是一个 LongTensor
        input_ids: torch.LongTensor,
        # 最终束搜索得分,是一个 FloatTensor
        final_beam_scores: torch.FloatTensor,
        # 最终的束搜索 token 序列,是一个 LongTensor
        final_beam_tokens: torch.LongTensor,
        # 最终的束搜索索引,是一个 LongTensor,指示每个最终结果的束索引
        final_beam_indices: torch.LongTensor,
        # 最大生成长度,一个整数值
        max_length: int,
        # 填充 token 的 ID,可选参数,默认为 None
        pad_token_id: Optional[int] = None,
        # 结束 token 的 ID,可以是一个整数或整数列表,可选参数,默认为 None
        eos_token_id: Optional[Union[int, List[int]]] = None,
        # 生成结果时每个 token 序列对应的束索引,可选的 LongTensor,默认为 None
        beam_indices: Optional[torch.LongTensor] = None,
        # 解码器提示长度,可选的整数,默认为 0
        decoder_prompt_len: Optional[int] = 0,
    ):
# 定义一个类 BeamHypotheses,用于存储 Beam Search 算法生成的假设列表
class BeamHypotheses:
    # 初始化方法,设置各种参数和初始值
    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
        """
        Initialize n-best list of hypotheses.

        Args:
            num_beams (int): Beam size, i.e., number of beams to keep.
            length_penalty (float): Length penalty to be applied to scores.
            early_stopping (bool): Whether to stop generation early based on conditions.
            max_length (Optional[int]): Optional maximum length for generated hypotheses.
        """
        self.length_penalty = length_penalty  # 设置长度惩罚参数
        self.early_stopping = early_stopping  # 是否启用提前停止
        self.max_length = max_length  # 最大生成长度限制
        self.num_beams = num_beams  # Beam 的数量
        self.beams = []  # 用于存储假设的列表
        self.worst_score = 1e9  # 初始设置一个极大值作为最差分数的初始值

        # 检查 early_stopping 参数类型,如果不是布尔值且 max_length 未定义,则引发错误
        if not isinstance(self.early_stopping, bool) and self.max_length is None:
            raise ValueError(
                "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
                " BeamScorer class instance at initialization time."
            )

    # 返回当前假设列表中假设的数量
    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    # 向假设列表中添加新的假设
    def add(
        self,
        hyp: torch.LongTensor,
        sum_logprobs: float,
        beam_indices: Optional[torch.LongTensor] = None,
        generated_len: Optional[int] = None,
    ):
        """
        Add a new hypothesis to the list.

        Args:
            hyp (torch.LongTensor): Tensor representing the hypothesis.
            sum_logprobs (float): Sum of log probabilities associated with the hypothesis.
            beam_indices (Optional[torch.LongTensor]): Optional tensor of beam indices.
            generated_len (Optional[int]): Optional length of the generated sequence.
        """
        # 根据生成的序列长度或者假设的最后一个维度计算得分
        if generated_len is not None:
            score = sum_logprobs / (generated_len**self.length_penalty)
        else:
            score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)

        # 如果假设列表中假设数量小于 Beam 数量或者当前分数大于最差分数,则添加新假设
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp, beam_indices))
            # 如果假设列表超过了 Beam 数量,则删除分数最低的假设
            if len(self) > self.num_beams:
                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)
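    # Worked example of the score computed in `add` (made-up numbers): with
    # sum_logprobs = -6.0, generated_len = 4 and length_penalty = 1.0 the hypothesis score
    # is -6.0 / 4**1.0 = -1.5; with length_penalty = 2.0 it becomes -6.0 / 16 = -0.375,
    # so a larger penalty divides the negative log-likelihood by a bigger number and
    # therefore favours longer sequences.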
    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
        """
        If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
        one in the heap, then we are done with this sentence.
        """

        # 如果当前堆中的假设数量小于要求的最大堆大小(num_beams),则返回 False
        if len(self) < self.num_beams:
            return False

        # 如果设定了 early_stopping 为 True,则立即停止,即使未满足其他条件
        if self.early_stopping is True:
            return True
        
        # 如果 early_stopping 设为 False,则根据当前长度计算最高可达分数,并检查是否达到最低分数标准
        elif self.early_stopping is False:
            highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
            ret = self.worst_score >= highest_attainable_score
            return ret
        
        # 如果 early_stopping 设为 "never",则根据 length_penalty 的值计算最高可达分数
        else:
            # 当 length_penalty 大于 0.0 时,从 max_length 而不是 cur_len 计算最高可达分数
            if self.length_penalty > 0.0:
                if self.max_length <= decoder_prompt_len:
                    raise ValueError("max_length is not larger than decoder prompt length")
                highest_attainable_score = (
                    best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
                )
            # 当 length_penalty 小于等于 0.0 时,从 cur_len 计算最高可达分数
            else:
                highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
            
            ret = self.worst_score >= highest_attainable_score
            return ret

.\generation\candidate_generator.py

# coding=utf-8
# 版权所有 2023 年 HuggingFace Inc. 团队。
#
# 根据 Apache 许可证 2.0 版本进行许可;
# 除非符合许可证要求,否则不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则按“原样”分发软件
# 软件没有任何形式的明示或暗示担保或条件。
# 有关详细信息,请参阅许可证。
#

import copy  # 导入深拷贝模块
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple  # 导入类型提示模块

import torch  # 导入PyTorch模块


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel  # 导入预训练模型类型提示
    from .configuration_utils import GenerationConfig  # 导入生成配置类型提示
    from .logits_process import LogitsProcessorList  # 导入logits处理列表类型提示


class CandidateGenerator:
    """所有候选生成器的抽象基类,可在辅助生成过程中应用。"""

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        获取当前输入的候选生成序列。

        Args:
            input_ids (`torch.LongTensor`,形状为 `(batch_size, sequence_length)`):
                输入序列标记在词汇表中的索引。[什么是输入ID?](../glossary#input-ids)

        Return:
            `torch.LongTensor`,形状为 `(batch_size, candidate_length)`,包含模型评估的候选序列,
            以及一个可选的 `torch.FloatTensor`,形状为 `(batch_size, candidate_length, vocabulary_size)`,
            包含与每个候选相关的logits。
        """
        raise NotImplementedError(
            f"{self.__class__} 是一个抽象类。只有继承此类的类才能调用 `get_candidates`。"
        )

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        根据结果更新候选生成策略。

        Args:
            input_ids (`torch.LongTensor`,形状为 `(batch_size, sequence_length)`):
                输入序列标记在词汇表中的索引。[什么是输入ID?](../glossary#input-ids)
            scores (`torch.FloatTensor`,形状为 `(batch_size, candidate_length, config.vocab_size)`):
                语言建模头部的预测分数。当不使用beam搜索时,这些可以是每个词汇的logits,或者在使用beam搜索时,每个词汇token的log softmax。
            num_matches (`int`):
                候选序列与模型预测之间的匹配数。
        """
        raise NotImplementedError(
            f"{self.__class__} 是一个抽象类。只有继承此类的类才能调用 `update_candidate_strategy`。"
        )
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        logits_processor: "LogitsProcessorList",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
    ):
        """
        Initialize the `AssistedCandidateGenerator` with necessary parameters.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            assistant_model (`PreTrainedModel`):
                The model used for generating candidates, which is smaller than the main model.
            generation_config (`~generation.GenerationConfig`, *optional*):
                Configuration for the generation process.
            logits_processor (`LogitsProcessorList`):
                List of processors to modify prediction scores of the language modeling head during generation.
            model_kwargs (`Dict`):
                Keyword arguments passed to the main model and the assistant model.
            inputs_tensor (`torch.Tensor`, *optional*):
                The input tensor for the model, typically the encoder input in encoder-decoder models.
        """
        # 调用父类的初始化方法,传入输入的参数
        super().__init__(input_ids, assistant_model, generation_config, logits_processor, model_kwargs)
        # 将输入的张量赋值给实例变量
        self.inputs_tensor = inputs_tensor
    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate.
        """
        # Move input_ids tensor to the device of the assistant model
        input_ids = input_ids.to(self.assistant_model.device)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        if max_new_tokens == 0:
            return input_ids, None

        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)

        # Check if there are past key values for the assistant model
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            # Calculate the new cache size based on current length minus one
            new_cache_size = new_cur_len - 1
            # Crop the past key values to match the new cache size
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            # Prepare attention mask based on the new current length and model configuration
            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )

            # Prepare token type IDs based on the new current length
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

        # Generate candidate sequences and logits using the assistant model
        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        # 3. Update variables for the next round of candidate generation
        # Update past key values for the assistant model with the latest output
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        # 4. Prepare variables for output
        # Stack candidate scores along the sequence dimension
        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        # Get candidate sequence IDs
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits
    # 更新候选生成策略基于结果的函数
    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # 调整下一个迭代中使用的助手标记的最大数量。这是一个简单的启发式方法,可能可以改进 -- 我们希望在获取正确的助手标记的好处与预测错误的代价之间取得平衡。
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            # 如果匹配数等于当前助手标记数量,则增加助手标记数量
            if num_matches == int(self.num_assistant_tokens):
                self.num_assistant_tokens += 2.0
            else:
                # 否则,减少助手标记数量,但不低于1.0
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
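
# ------------------------------------------------------------------
# 演示示例(非库中代码):用纯 Python 模拟上面 "heuristic" 调度下 num_assistant_tokens
# 的变化轨迹。simulate_schedule 为本文为演示而假设的名字,规则与上方实现一致:
# 整批候选全部命中则 +2,否则 -1(但不低于 1.0)。
def simulate_schedule(initial_tokens: float, match_history: list) -> list:
    num_assistant_tokens = initial_tokens
    trace = [num_assistant_tokens]
    for num_matches in match_history:
        if num_matches == int(num_assistant_tokens):
            num_assistant_tokens += 2.0  # 全部命中:下一轮多试 2 个候选
        else:
            num_assistant_tokens = max(1.0, num_assistant_tokens - 1.0)  # 否则回退 1 个
        trace.append(num_assistant_tokens)
    return trace

# 假设前两轮候选全部命中,第三轮只命中一部分
print(simulate_schedule(5.0, [5, 7, 3]))  # [5.0, 7.0, 9.0, 8.0]
# ------------------------------------------------------------------
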
# 定义一个候选生成器类 `PromptLookupCandidateGenerator`,继承自 `CandidateGenerator` 类。
# 该类用于生成基于提示查找的候选结果。它通过查找在提供的提示(input_ids)中可能的延续来生成候选结果。
# 更多信息请查阅以下博客文章:https://github.com/apoorvumang/prompt-lookup-decoding
class PromptLookupCandidateGenerator(CandidateGenerator):
    
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
    likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: int = None,
    ):
        # 初始化方法,设置候选结果输出的 token 数量和最大匹配的 ngram 大小。
        self.num_output_tokens = num_output_tokens
        # 如果未指定最大匹配的 ngram 大小,则默认为 2
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2

        # 如果最大匹配的 ngram 大小或者输出的 token 数量小于等于 0,则抛出数值错误异常。
        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)  # 获取输入的序列长度

        chosen_ids = None  # 初始化 chosen_ids 为 None
        match_found = False  # 初始化 match_found 为 False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):  # 从大到小遍历 ngram 大小
            # 创建大小为 ngram_size 的滑动窗口
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # 取出序列结尾的 ngram,作为比较用的张量
            ngram_tensor = input_ids[0, -ngram_size:]

            # 查找窗口与 ngram 匹配的位置
            matches = (windows == ngram_tensor).all(dim=2)

            # 获取匹配的索引
            match_indices = matches.nonzero(as_tuple=True)[1]

            # 遍历匹配索引以找到有效的延续
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break
            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # 如果没有找到匹配,则返回未更改的输入序列,恢复自回归解码
            return input_ids, None

        # 现在需要用 chosen_ids 扩展 input_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation 预期也返回 logits,但这里我们没有,所以返回 None
        return candidate_input_ids, None
    # 更新候选生成策略,根据结果进行调整

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # 当前函数暂未实现任何功能,仅作为占位符
        return
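
# ------------------------------------------------------------------
# 演示示例(非库中代码):用一个玩具 token 序列演示上面的 ngram 匹配逻辑。
# 序列中结尾的 "3 4" 此前出现过一次,其后跟着 "5 6 9",因此可以直接把这段
# 后续 token 作为候选续写(batch_size 假设为 1,数值均为虚构)。
import torch

input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 9, 9, 3, 4]])
ngram_size = 2
num_output_tokens = 3

windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # 所有长度为 2 的滑动窗口
ngram_tensor = input_ids[0, -ngram_size:]                         # 结尾 ngram: [3, 4]
matches = (windows == ngram_tensor).all(dim=2)                    # 各窗口是否与结尾 ngram 相同
match_indices = matches.nonzero(as_tuple=True)[1]

for idx in match_indices:
    start_idx = idx + ngram_size
    end_idx = min(start_idx + num_output_tokens, input_ids.size(1))
    if start_idx < end_idx:
        print(input_ids[0, start_idx:end_idx])  # tensor([5, 6, 9]),即候选续写
        break
# ------------------------------------------------------------------
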
# 将过去的键值对裁剪到指定的最大长度
def _crop_past_key_values(model, past_key_values, maximum_length):
    new_past = []
    # 如果模型是编码-解码模型
    if model.config.is_encoder_decoder:
        # 遍历过去的键值对
        for idx in range(len(past_key_values)):
            # 裁剪过去的键值对的内容,保留最大长度内的部分
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                    past_key_values[idx][2],
                    past_key_values[idx][3],
                )
            )
        past_key_values = tuple(new_past)
    # 如果模型类名中包含 "bloom",或者模型配置 architectures 列表中的第一个架构名包含 "bloom"
    elif "bloom" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
    ):
        # 遍历过去的键值对
        for idx in range(len(past_key_values)):
            # 根据不同的维度裁剪过去的键值对的内容,保留最大长度内的部分
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length],
                    past_key_values[idx][1][:, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    # 如果模型类名中包含"gptbigcode",或者模型架构中的第一个类名中包含"gptbigcode"
    elif "gptbigcode" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
    ):
        # 如果是多重查询模型
        if model.config.multi_query:
            # 遍历过去的键值对,裁剪为最大长度的内容
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
        else:
            # 遍历过去的键值对,裁剪为最大长度的内容
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
    else:
        # 遍历过去的键值对
        for idx in range(len(past_key_values)):
            # 裁剪过去的键值对的内容,保留最大长度内的部分
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    return past_key_values
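
# ------------------------------------------------------------------
# 演示示例(非库中代码):示意最后一个 else 分支(标准 decoder-only 缓存布局,
# 形状为 (batch, num_heads, seq_len, head_dim))裁剪前后的形状变化;层数、头数等均为假设值。
import torch

past_key_values = tuple(
    (torch.randn(1, 4, 10, 8), torch.randn(1, 4, 10, 8)) for _ in range(2)  # 2 层、已缓存 10 个位置
)
maximum_length = 6

cropped = tuple(
    (k[:, :, :maximum_length, :], v[:, :, :maximum_length, :]) for k, v in past_key_values
)
print(cropped[0][0].shape)  # torch.Size([1, 4, 6, 8]):序列维度被裁剪到 6
# ------------------------------------------------------------------
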


# 扩展或裁剪模型的注意力掩码,以用于解码目的,调整到指定的长度
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    # 如果模型参数中不包含指定的掩码键值,则直接返回模型参数
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    # 如果掩码长度超出了需要的长度,则裁剪掩码
    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    # 如果掩码长度不足需要的长度,则扩展掩码
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs
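
# ------------------------------------------------------------------
# 演示示例(非库中代码):展示上面的 _prepare_attention_mask 对 decoder-only 模型
# (使用 "attention_mask" 键)的扩展与裁剪效果;假设该函数在当前作用域可直接调用,输入形状为虚构。
import torch

model_kwargs = {"attention_mask": torch.ones(1, 5, dtype=torch.long)}

longer = _prepare_attention_mask(dict(model_kwargs), new_length=8, is_encoder_decoder=False)
print(longer["attention_mask"].shape)   # torch.Size([1, 8]):末尾补了 3 列 1

shorter = _prepare_attention_mask(dict(model_kwargs), new_length=3, is_encoder_decoder=False)
print(shorter["attention_mask"].shape)  # torch.Size([1, 3]):只保留前 3 列
# ------------------------------------------------------------------
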


# 扩展或裁剪模型的token_type_ids,以用于解码目的,调整到指定的长度
def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    # 如果模型参数中不包含token_type_ids或者其值为空,则直接返回模型参数
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs
    # 获取模型参数字典中的 token_type_ids
    token_type_ids = model_kwargs["token_type_ids"]
    
    # 获取 token_type_ids 的最后一个元素,并在最后增加一个维度
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    
    # 计算新长度与当前 token_type_ids 的长度之差
    type_length_diff = new_length - token_type_ids.shape[1]
    
    # 根据长度差进行条件判断和处理
    if type_length_diff < 0:
        # 如果长度差小于零,截取 token_type_ids 的前 type_length_diff 部分
        token_type_ids = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        # 如果长度差大于零,复制 final_token_type,使其与长度差匹配,并将其拼接到 token_type_ids 后面
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    
    # 返回更新后的模型参数字典
    return model_kwargs
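
# ------------------------------------------------------------------
# 演示示例(非库中代码):展示上面的 _prepare_token_type_ids 在扩展时会复制最后一个
# token type 进行填充;假设该函数在当前作用域可直接调用,数值为虚构。
import torch

model_kwargs = {"token_type_ids": torch.tensor([[0, 0, 1, 1]])}
extended = _prepare_token_type_ids(dict(model_kwargs), new_length=7)
print(extended["token_type_ids"])  # tensor([[0, 0, 1, 1, 1, 1, 1]])
# ------------------------------------------------------------------
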

.\generation\configuration_utils.py

# coding=utf-8
# 声明编码格式为UTF-8,确保文件中可以包含非ASCII字符
# Copyright 2022 The HuggingFace Inc. team.
# 版权声明,指出代码的版权归属于HuggingFace Inc.团队。

# Licensed under the Apache License, Version 2.0 (the "License");
# 根据Apache License, Version 2.0许可证授权,使用该文件需要遵守许可证规定。
# you may not use this file except in compliance with the License.
# 除非符合许可证的规定,否则不得使用此文件。
# You may obtain a copy of the License at
# 可以在以下网址获取许可证的副本
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 除非适用法律要求或书面同意,否则依照“原样”分发本软件,不附带任何明示或暗示的担保或条件。
# 可以在许可证下查看特定语言的权限和限制。

""" Generation configuration class and utilities."""
# 生成配置类和实用程序的说明文档。

import copy
# 导入copy模块,用于复制对象
import json
# 导入json模块,用于JSON数据的处理
import os
# 导入os模块,提供与操作系统交互的功能
import warnings
# 导入warnings模块,用于管理警告信息
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
# 导入类型提示相关的模块和类型

from .. import __version__
# 从父级目录导入__version__,用于获取当前模块的版本信息
from ..configuration_utils import PretrainedConfig
# 从父级目录导入PretrainedConfig类,用于处理预训练配置相关的功能
from ..utils import (
    GENERATION_CONFIG_NAME,
    ExplicitEnum,
    PushToHubMixin,
    cached_file,
    download_url,
    extract_commit_hash,
    is_remote_url,
    logging,
)
# 从父级目录的utils模块导入各种工具函数和类

if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    # 如果在类型检查模式下,导入预训练模型相关的模块

logger = logging.get_logger(__name__)
# 获取当前模块的日志记录器对象
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
# 元数据字段元组,包含模型配置来源、提交哈希、原始对象哈希和transformers版本信息


class GenerationMode(ExplicitEnum):
    """
    Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
    """
    # 生成模式枚举类,表示`generate`方法的可能生成模式

    # Non-beam methods
    CONTRASTIVE_SEARCH = "contrastive_search"
    # 对比搜索方法
    GREEDY_SEARCH = "greedy_search"
    # 贪婪搜索方法
    SAMPLE = "sample"
    # 随机采样方法
    ASSISTED_GENERATION = "assisted_generation"
    # 辅助生成方法

    # Beam methods
    BEAM_SEARCH = "beam_search"
    # Beam搜索方法
    BEAM_SAMPLE = "beam_sample"
    # Beam采样方法
    CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
    # 限制Beam搜索方法
    GROUP_BEAM_SEARCH = "group_beam_search"
    # 分组Beam搜索方法


class GenerationConfig(PushToHubMixin):
    # no-format
    r"""
    Class that holds a configuration for a generation task. A `generate` call supports the following generation
    methods: greedy search, contrastive search, multinomial sampling, beam search, beam sampling, group beam search,
    constrained beam search and assisted generation.
    """
    # 生成任务配置类,支持以上列出的各种生成方法

    # Defines special methods for hash, equality comparison, and representation of GenerationConfig objects.
    # 计算对象的哈希值,基于忽略元数据的 JSON 字符串表示
    def __hash__(self):
        return hash(self.to_json_string(ignore_metadata=True))

    # 判断两个 GenerationConfig 对象是否相等,忽略元数据进行比较
    def __eq__(self, other):
        # 如果 other 不是 GenerationConfig 类型,直接返回 False
        if not isinstance(other, GenerationConfig):
            return False
        
        # 分别获取去除元数据后的 JSON 字符串
        self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
        other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
        
        # 比较两个 JSON 字符串是否相等
        return self_without_metadata == other_without_metadata

    # 返回 GenerationConfig 对象的字符串表示,包括忽略元数据的 JSON 字符串
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string(ignore_metadata=True)}"
    def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = None) -> GenerationMode:
        """
        Returns the generation mode triggered by the [`GenerationConfig`] instance.

        Arg:
            assistant_model (`PreTrainedModel`, *optional*):
                The assistant model to be used for assisted generation. If set, the generation mode will be
                assisted generation.

        Returns:
            `GenerationMode`: The generation mode triggered by the instance.
        """
        # Determine generation mode based on various configuration parameters
        if self.constraints is not None or self.force_words_ids is not None:
            generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
        elif self.num_beams == 1:
            if self.do_sample is False:
                if (
                    self.top_k is not None
                    and self.top_k > 1
                    and self.penalty_alpha is not None
                    and self.penalty_alpha > 0
                ):
                    generation_mode = GenerationMode.CONTRASTIVE_SEARCH
                else:
                    generation_mode = GenerationMode.GREEDY_SEARCH
            else:
                generation_mode = GenerationMode.SAMPLE
        else:
            if self.num_beam_groups > 1:
                generation_mode = GenerationMode.GROUP_BEAM_SEARCH
            elif self.do_sample is True:
                generation_mode = GenerationMode.BEAM_SAMPLE
            else:
                generation_mode = GenerationMode.BEAM_SEARCH

        # Modify generation mode if assistant model is specified for assisted generation
        if assistant_model is not None or self.prompt_lookup_num_tokens is not None:
            if generation_mode in (GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE):
                generation_mode = GenerationMode.ASSISTED_GENERATION
            else:
                raise ValueError(
                    "You've set `assistant_model`, which triggers assisted generation. Currently, assisted generate "
                    "is only supported with Greedy Search and Sample."
                )
        # Return the determined generation mode
        return generation_mode
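
    # 使用示意(非库中代码,假设所用 transformers 版本与本文一致,提供该方法):
    #     from transformers import GenerationConfig
    #     GenerationConfig(num_beams=1, do_sample=False).get_generation_mode()  # GREEDY_SEARCH(贪婪搜索)
    #     GenerationConfig(num_beams=1, do_sample=True).get_generation_mode()   # SAMPLE(随机采样)
    #     GenerationConfig(num_beams=4, do_sample=False).get_generation_mode()  # BEAM_SEARCH(束搜索)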

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        config_file_name: Optional[Union[str, os.PathLike]] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Saves the current configuration to the specified directory.

        Args:
            save_directory (Union[str, os.PathLike]): Directory where the configuration should be saved.
            config_file_name (Optional[Union[str, os.PathLike]], *optional*):
                Name for the configuration file. If not provided, a default name will be used.
            push_to_hub (bool, *optional*):
                Whether to push the saved configuration to the model hub (if applicable).
            **kwargs:
                Additional keyword arguments for future expansion.
        """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name: Union[str, os.PathLike],
        config_file_name: Optional[Union[str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ):
        """
        Creates an instance of the class from a pretrained model.

        Args:
            pretrained_model_name (Union[str, os.PathLike]): Name or path of the pretrained model.
            config_file_name (Optional[Union[str, os.PathLike]], *optional*):
                Name for the configuration file. If not provided, a default name will be used.
            cache_dir (Optional[Union[str, os.PathLike]], *optional*):
                Directory to cache downloaded files (if applicable).
            force_download (bool, *optional*):
                Whether to force re-download of the model files, ignoring any cached versions.
            local_files_only (bool, *optional*):
                Whether to only consider local files as sources for the model, ignoring any remote repositories.
            token (Optional[Union[str, bool]], *optional*):
                Access token for private model repositories (if applicable).
            revision (str, *optional*):
                Revision or version of the model to load.
            **kwargs:
                Additional keyword arguments for future expansion.

        Returns:
            Instance of the class loaded from the pretrained model.
        """
    # 从给定的 JSON 文件中读取内容并将其解析为 Python 字典
    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
        """
        从一个 Python 字典参数实例化一个 GenerationConfig 对象。

        Args:
            config_dict (`Dict[str, Any]`):
                将用于实例化配置对象的字典。
            kwargs (`Dict[str, Any]`):
                用于初始化配置对象的额外参数。

        Returns:
            [`GenerationConfig`]: 从这些参数实例化的配置对象。
        """
        # 是否返回未使用的关键字参数,默认为 False
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # 移除内部遥测用的参数,以防止它们出现在 `return_unused_kwargs` 中
        kwargs.pop("_from_auto", None)
        kwargs.pop("_from_pipeline", None)
        # 如果 `_commit_hash` 在 kwargs 中且在 config_dict 中,则更新 `_commit_hash`
        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
            kwargs["_commit_hash"] = config_dict["_commit_hash"]

        # 下面的语句允许通过 kwargs 加载特定于模型的配置,并进行安全检查。
        # 参考:https://github.com/huggingface/transformers/pull/21269
        config = cls(**{**config_dict, **kwargs})
        # 更新配置,并返回未使用的关键字参数
        unused_kwargs = config.update(**kwargs)

        # 记录生成的配置信息
        logger.info(f"Generate config {config}")
        if return_unused_kwargs:
            return config, unused_kwargs
        else:
            return config

    # 将字典及其嵌套字典中的 `torch_dtype` 键转换为字符串形式,例如 `torch.float32` 转换为 `"float32"`
    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
        for value in d.values():
            if isinstance(value, dict):
                self.dict_torch_dtype_to_str(value)
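
    # 使用示意(非库中代码):该方法就地修改传入的字典,包括嵌套字典,例如
    #     d = {"torch_dtype": torch.float32, "sub": {"torch_dtype": torch.bfloat16}}
    #     GenerationConfig().dict_torch_dtype_to_str(d)
    #     # d 变为 {"torch_dtype": "float32", "sub": {"torch_dtype": "bfloat16"}}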
    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        # 将当前配置转换为字典形式
        config_dict = self.to_dict()

        # 获取默认配置的字典形式
        default_config_dict = GenerationConfig().to_dict()

        # 初始化一个空字典,用于存储与默认配置不同的配置项
        serializable_config_dict = {}

        # 只序列化与默认配置不同的值
        for key, value in config_dict.items():
            # 如果配置项不在默认配置中,或者是特定例外项,或者值不同,则加入序列化字典中
            if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
                serializable_config_dict[key] = value

        # 转换字典中的 torch 数据类型为字符串表示
        self.dict_torch_dtype_to_str(serializable_config_dict)
        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        # 深拷贝对象的 __dict__ 属性,得到一个副本
        output = copy.deepcopy(self.__dict__)

        # 在序列化时忽略的字段
        if "_commit_hash" in output:
            del output["_commit_hash"]
        if "_original_object_hash" in output:
            del output["_original_object_hash"]

        # 序列化时记录 Transformers 版本信息
        output["transformers_version"] = __version__

        # 转换字典中的 torch 数据类型为字符串表示
        self.dict_torch_dtype_to_str(output)
        return output
    def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                is serialized to JSON string.
            ignore_metadata (`bool`, *optional*, defaults to `False`):
                Whether to ignore the metadata fields present in the instance

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        # 根据 use_diff 参数决定是否只序列化配置实例与默认 GenerationConfig() 之间的差异
        if use_diff is True:
            config_dict = self.to_diff_dict()  # 调用实例方法获取配置实例与默认配置之间的差异字典
        else:
            config_dict = self.to_dict()  # 调用实例方法获取完整的配置实例字典

        # 如果 ignore_metadata 参数为 True,则移除配置字典中的元数据字段
        if ignore_metadata:
            for metadata_field in METADATA_FIELDS:
                config_dict.pop(metadata_field, None)

        # 定义一个函数,将字典中的键转换为字符串类型
        def convert_keys_to_string(obj):
            if isinstance(obj, dict):
                return {str(key): convert_keys_to_string(value) for key, value in obj.items()}
            elif isinstance(obj, list):
                return [convert_keys_to_string(item) for item in obj]
            else:
                return obj

        # 转换配置字典中所有键为字符串类型
        config_dict = convert_keys_to_string(config_dict)

        # 将转换后的配置字典转换为带缩进、按键排序的 JSON 格式字符串,并添加换行符
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
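
    # 使用示意(非库中代码,输出取决于当前版本的默认值,仅供参考):
    #     config = GenerationConfig(max_length=128, do_sample=True, top_k=10)
    #     print(config.to_json_string(ignore_metadata=True))
    #     # 只包含与默认配置不同的字段,例如:
    #     # {
    #     #   "do_sample": true,
    #     #   "max_length": 128,
    #     #   "top_k": 10
    #     # }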

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                is serialized to JSON file.
        """
        # 打开指定路径的 JSON 文件,并将实例转换为 JSON 字符串后写入文件
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    @classmethod
    def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
        """
        从一个预训练配置 (`PretrainedConfig`) 实例化一个生成配置 (`GenerationConfig`)。
        这个函数用于将可能包含生成参数的旧式预训练配置对象转换为独立的生成配置对象。

        Args:
            model_config (`PretrainedConfig`):
                将用于实例化生成配置的模型配置。

        Returns:
            [`GenerationConfig`]: 从这些参数实例化的配置对象。
        """
        # 将模型配置转换为字典
        config_dict = model_config.to_dict()
        # 移除特定的属性,这些属性不应该用于构建生成配置
        config_dict.pop("_from_model_config", None)
        # 通过字典创建生成配置对象,确保不返回未使用的关键字参数
        config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

        # 特殊情况:某些模型在解码器中设置了生成属性。如果生成配置中仍未设置这些属性,则使用解码器中的值。
        for decoder_name in ("decoder", "generator", "text_config"):
            if decoder_name in config_dict:
                default_generation_config = GenerationConfig()
                decoder_config = config_dict[decoder_name]
                # 检查生成配置中的每个属性,如果属性在解码器配置中存在且生成配置中未设置,则设置为解码器中的值
                for attr in config.to_dict().keys():
                    if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
                        setattr(config, attr, decoder_config[attr])

        # 计算对象的哈希值,用于检测实例是否已修改
        config._original_object_hash = hash(config)
        return config

    def update(self, **kwargs):
        """
        使用 `kwargs` 中的属性更新该类实例的属性,如果属性匹配现有属性,则返回所有未使用的 kwargs。

        Args:
            kwargs (`Dict[str, Any]`):
                尝试更新此类的属性的属性字典。

        Returns:
            `Dict[str, Any]`: 包含所有未用于更新实例的键值对的字典。
        """
        to_remove = []
        # 遍历传入的关键字参数
        for key, value in kwargs.items():
            # 如果类实例具有这个属性,则更新为传入的值,并记录已更新的属性名
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # 确保更新后的实例仍然有效
        self.validate()

        # 返回所有未使用的关键字参数,即未更新到类实例的参数
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs
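
# ------------------------------------------------------------------
# 使用示意(非库中代码):已存在的属性会被更新,未知的键原样返回;示例键名为虚构。
from transformers import GenerationConfig

config = GenerationConfig()
unused = config.update(num_beams=4, foo="bar")
print(config.num_beams)  # 4
print(unused)            # {'foo': 'bar'}
# ------------------------------------------------------------------
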

.\generation\flax_logits_process.py

# coding=utf-8
# 导入inspect模块,用于检查和获取源代码信息
import inspect

# 导入JAX库
import jax
import jax.lax as lax
import jax.numpy as jnp

# 从上级目录中导入工具函数
from ..utils import add_start_docstrings
# 从日志记录工具中导入日志记录器
from ..utils.logging import get_logger

# 获取当前模块的日志记录器
logger = get_logger(__name__)

# 定义文档字符串常量,描述了logits处理器的输入和返回值
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            输入序列标记的索引,形状为(batch_size, sequence_length)。

            可以使用[`PreTrainedTokenizer`]来获取索引。参见[`PreTrainedTokenizer.encode`]和
            [`PreTrainedTokenizer.__call__`]获取详情。

            [什么是输入ID?](../glossary#input-ids)
        scores (`jnp.ndarray` of shape `(batch_size, config.vocab_size)`):
            语言模型头的预测分数。当不使用beam搜索时,这些可以是每个词汇的logits;当使用beam搜索时,可以是
            每个词汇token的log softmax。
        kwargs (`Dict[str, Any]`, *optional*):
            特定于logits处理器的额外kwargs参数。

    Return:
        `jnp.ndarray` of shape `(batch_size, config.vocab_size)`: 处理后的预测分数。

"""

# 定义FlaxLogitsProcessor类,抽象基类,用于在生成过程中应用所有logits处理器
class FlaxLogitsProcessor:
    """用于生成过程中可以应用的所有logits处理器的抽象基类。"""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
        """处理logits的Flax方法。"""
        # 抛出未实现错误,提示该类为抽象类,只能通过继承该类的子类调用
        raise NotImplementedError(
            f"{self.__class__}是一个抽象类。只有继承了这个类的类才能被调用。"
        )

# 定义FlaxLogitsWarper类,抽象基类,用于在使用多项式采样的生成过程中应用所有logit变形器
class FlaxLogitsWarper:
    """用于使用多项式采样生成过程中可以应用的所有logit变形器的抽象基类。"""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
        """变形logits的Flax方法。"""
        # 抛出未实现错误,提示该类为抽象类,只能通过继承该类的子类调用
        raise NotImplementedError(
            f"{self.__class__}是一个抽象类。只有继承了这个类的类才能被调用。"
        )

# 定义FlaxLogitsProcessorList类,继承自list,用于创建一个[`FlaxLogitsProcessor`]或[`FlaxLogitsWarper`]列表,
# 并能够对输入的`scores`张量应用每一个处理器或变形器
class FlaxLogitsProcessorList(list):
    """
    此类可用于创建[`FlaxLogitsProcessor`]或[`FlaxLogitsWarper`]的列表,以随后处理`scores`输入张量。
    此类继承自列表,并添加了一个特定的*__call__*方法来应用每个[`FlaxLogitsProcessor`]或[`FlaxLogitsWarper`]到输入上。
    """

    # 对象方法:根据给定的输入和参数,依次调用列表中的每个 logits 处理器
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
        # 遍历每个处理器对象
        for processor in self:
            # 获取处理器的调用方法参数签名
            function_args = inspect.signature(processor.__call__).parameters
            # 如果参数个数大于3
            if len(function_args) > 3:
                # 检查是否所有所需的参数都在kwargs中
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    # 如果有缺失参数,抛出数值错误异常
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                # 调用处理器的方法,传入输入数据、得分、当前长度和其他参数
                scores = processor(input_ids, scores, cur_len, **kwargs)
            else:
                # 如果参数个数不大于3,直接调用处理器的方法,传入输入数据、得分和当前长度
                scores = processor(input_ids, scores, cur_len)
        # 返回处理后的得分
        return scores
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    r"""
    [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).

    Args:
        temperature (`float`):
            The value used to module the logits distribution.
    """

    def __init__(self, temperature: float):
        # 检查温度参数是否为正浮点数,如果不是则抛出异常
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # 将得分按温度值缩放,用于温度调节输出概率分布
        scores = scores / self.temperature
        return scores


class FlaxTopPLogitsWarper(FlaxLogitsWarper):
    """
    [`FlaxLogitsWarper`] that performs top-p, i.e. restricting to the smallest set of most probable tokens whose probabilities sum to `top_p` or higher.

    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 检查 top_p 是否为介于 0 和 1 之间的浮点数,否则抛出异常
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
        # 检查 min_tokens_to_keep 是否为正整数,否则抛出异常
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # 获取前 k 个最高得分和其对应的索引
        topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])

        # 创建一个与 scores 形状相同的数组,填充为 filter_value
        mask_scores = jnp.full_like(scores, self.filter_value)
        # 计算 softmax 后的累积概率
        cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
        # 创建用于掩码的布尔数组,仅保留累积概率小于 top_p 的部分
        score_mask = cumulative_probs < self.top_p

        # 将掩码右移一位,使第一个令累积概率超过 top_p 的 token 也被保留
        score_mask = jnp.roll(score_mask, 1)
        score_mask |= score_mask.at[:, 0].set(True)

        # 至少保留 min_tokens_to_keep 个 token
        score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)

        # 根据 score_mask 选择相应的得分值或者 filter_value
        topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
        # 按照 topk_indices 排序,获取排序后的最终得分
        next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]

        return next_scores
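
# ------------------------------------------------------------------
# 演示示例(非库中代码):用一个玩具 logits 张量直观展示 top-p 截断的效果;
# 假设可从 transformers.generation.flax_logits_process 导入该类,数值为虚构。
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxTopPLogitsWarper

warper = FlaxTopPLogitsWarper(top_p=0.9)
scores = jnp.array([[5.0, 4.0, 1.0, 0.5, 0.1]])  # batch=1,词表大小 5,概率质量集中在前两个 token
out = warper(input_ids=None, scores=scores, cur_len=1)
print(out)  # [[ 5.  4. -inf -inf -inf]]:累积概率达到 0.9 所需的前两个 token 被保留
# ------------------------------------------------------------------
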


class FlaxTopKLogitsWarper(FlaxLogitsWarper):
    r"""
    [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """
    # 定义一个类,用于执行 Top-K 筛选操作,保留概率最高的词汇标记。

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        # 初始化方法,设置Top-K值,并确保不小于最小保留标记数
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # 调用实例时,执行Top-K筛选操作

        # 获取输入的批次大小和词汇表大小
        batch_size, vocab_size = scores.shape

        # 初始化一个数组,用来存储被过滤后的分数值,默认为filter_value
        next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)

        # 确定实际的Top-K值,避免超过分数数组的长度
        topk = min(self.top_k, scores.shape[-1])

        # 使用JAX库中的top_k函数找到每个批次中前Top-K个分数及其对应的索引
        topk_scores, topk_indices = lax.top_k(scores, topk)

        # 计算扁平化后的索引偏移,以便在一维数组中正确设置Top-K分数
        shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
        topk_scores_flat = topk_scores.flatten()
        topk_indices_flat = topk_indices.flatten() + shift

        # 在next_scores_flat数组中设置Top-K分数值
        next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)

        # 将扁平化后的数组重新形状为(batch_size, vocab_size),得到最终的Top-K分数数组
        next_scores = next_scores_flat.reshape(batch_size, vocab_size)
        return next_scores
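
# ------------------------------------------------------------------
# 演示示例(非库中代码):top-k 截断只保留得分最高的 k 个位置;数值为虚构,
# 假设可从 transformers.generation.flax_logits_process 导入该类。
import jax.numpy as jnp
from transformers.generation.flax_logits_process import FlaxTopKLogitsWarper

warper = FlaxTopKLogitsWarper(top_k=2)
scores = jnp.array([[0.3, 2.0, -1.0, 5.0]])
out = warper(input_ids=None, scores=scores, cur_len=1)
print(out)  # [[-inf  2. -inf  5.]]:只有得分最高的 5.0 与 2.0 保留原值
# ------------------------------------------------------------------
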
class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces the specified token as the first generated token.

    Args:
        bos_token_id (`int`):
            The id of the token to force as the first generated token.
    """

    def __init__(self, bos_token_id: int):
        self.bos_token_id = bos_token_id  # 初始化函数,保存要强制作为第一个生成token的token id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        new_scores = jnp.full(scores.shape, -float("inf"))  # 创建一个形状与scores相同的全负无穷数组

        apply_penalty = 1 - jnp.bool_(cur_len - 1)  # 仅当 cur_len == 1(正要生成第一个新 token)时应用惩罚

        scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)
        # 根据apply_penalty条件,将scores中对应bos_token_id列的值设置为0,其它位置不变

        return scores


class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

    Args:
        max_length (`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (`int`):
            The id of the token to force as the last generated token when `max_length` is reached.
    """

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length  # 初始化函数,保存最大生成长度
        self.eos_token_id = eos_token_id  # 初始化函数,保存要强制作为末尾生成token的token id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        new_scores = jnp.full(scores.shape, -float("inf"))  # 创建一个形状与scores相同的全负无穷数组

        apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
        # 当 cur_len == max_length - 1(正要生成最后一个 token)时应用惩罚,强制生成 EOS

        scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)
        # 根据apply_penalty条件,将scores中对应eos_token_id列的值设置为0,其它位置不变

        return scores


class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

        self.min_length = min_length  # 初始化函数,保存最小生成长度
        self.eos_token_id = eos_token_id  # 初始化函数,保存要设置其概率为负无穷的token id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # create boolean flag to decide if min length penalty should be applied
        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
        # 根据当前生成长度是否小于min_length,决定是否应用惩罚

        scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
        # 根据apply_penalty条件,将scores中对应eos_token_id列的值设置为负无穷,其它位置不变

        return scores


class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] 用于在 `generate` 函数开始生成时抑制一组指定的 token,
    确保在生成的开头,由 `begin_suppress_tokens` 定义的 token 不会被抽样到。

    Args:
        begin_suppress_tokens (`List[int]`):
            不抽样的 token 列表。
        begin_index (`int`):
            开始抑制 token 的索引位置。
    """

    def __init__(self, begin_suppress_tokens, begin_index):
        # 将输入的 begin_suppress_tokens 转换为列表
        self.begin_suppress_tokens = list(begin_suppress_tokens)
        # 设置开始抑制 token 的索引位置
        self.begin_index = begin_index

    def __call__(self, input_ids, scores, cur_len: int):
        # 根据当前生成长度 `cur_len` 和开始抑制的索引 `begin_index` 计算是否应用惩罚
        apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)

        # 根据应用的惩罚,将指定的 `begin_suppress_tokens` 的分数设置为负无穷大
        scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)

        # 返回处理后的分数
        return scores
class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
    to be `-inf` so they are not sampled.

    Args:
        suppress_tokens (`list`):
            Tokens to not sample.
    """

    def __init__(self, suppress_tokens: list):
        # 初始化方法,接收一个要抑制的token列表
        self.suppress_tokens = list(suppress_tokens)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # 在scores张量的指定位置设置为负无穷,以便在采样时不被选中
        scores = scores.at[..., self.suppress_tokens].set(-float("inf"))

        return scores


class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
    token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
    to `-inf` so that they are sampled at their corresponding index.

    Args:
        force_token_map (`list`):
            Map giving token ids and indices where they will be forced to be sampled.
    """

    def __init__(self, force_token_map):
        # 将force_token_map转换为字典格式,并初始化一个强制token的数组以提高XLA的兼容性
        force_token_map = dict(force_token_map)
        force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
        for index, token in force_token_map.items():
            if token is not None:
                force_token_array = force_token_array.at[index].set(token)
        self.force_token_array = jnp.int32(force_token_array)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        def _force_token(generation_idx):
            # 根据generation_idx确定要强制采样的token,并更新scores张量
            batch_size = scores.shape[0]
            current_token = self.force_token_array[generation_idx]

            new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
            updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
            new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
            return new_scores

        # 使用lax.cond根据cur_len的值来决定是否进行token强制操作
        scores = lax.cond(
            cur_len >= self.force_token_array.shape[0],
            # 如果当前长度大于等于force_token_array的长度,则不进行强制操作
            lambda: scores,
            # 否则,根据force_token_array[cur_len]的值来判断是否强制采样特定token
            lambda: lax.cond(
                self.force_token_array[cur_len] >= 0,
                # 只有有效(非负)的token才会被强制采样
                lambda: _force_token(cur_len),
                # 否则不进行强制操作
                lambda: scores,
            ),
        )
        return scores
class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
    r"""
    Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log
    probs to `inf` so that they are sampled at their corresponding index.

    Args:
        generate_config (`GenerateConfig`):
            The generate config used to generate the output. The following parameters are required:
                eos_token_id (`int`, *optional*, defaults to 50257):
                    The id of the *end-of-sequence* token.
                no_timestamps_token_id (`int`, *optional*, defaults to 50363):
                    The id of the `"<|notimestamps|>"` token.
                max_initial_timestamp_index (`int`, *optional*, defaults to 1):
                    Used to set the maximum value of the initial timestamp. This is used to prevent the model from
                    predicting timestamps that are too far in the future.
    """

    def __init__(self, generate_config, model_config, decoder_input_length):
        # 初始化方法,设置对象的初始属性
        self.eos_token_id = generate_config.eos_token_id
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        # 设置时间戳开始的位置
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1

        # 设置开始索引,考虑解码器输入长度
        self.begin_index = decoder_input_length + 1

        # 如果是多语言模型,为语言标记和任务标记预留空间
        if generate_config.is_multilingual:
            self.begin_index += 2
        
        # 如果生成配置有最大初始时间戳索引属性,使用该值;否则使用模型词汇表大小
        if hasattr(generate_config, "max_initial_timestamp_index"):
            self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
        else:
            self.max_initial_timestamp_index = model_config.vocab_size
        
        # 如果最大初始时间戳索引为 None,则设为模型词汇表大小
        if self.max_initial_timestamp_index is None:
            self.max_initial_timestamp_index = model_config.vocab_size
    def __call__(self, input_ids, scores, cur_len):
        # 将包含 self.no_timestamps_token_id 的列设为负无穷,这由 without_timestamps 处理
        scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))

        def handle_pairs(input_ids_k, scores_k):
            # 判断前一个 token 是否为时间戳,如果是,则设置为 True,否则为 False
            last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
            last_was_timestamp = jnp.where(
                input_ids_k[cur_len - 1] >= self.timestamp_begin,
                True and last_was_timestamp,
                False,
            )

            # 判断倒数第二个 token 是否为时间戳,如果是,则设置为 True,否则为 False
            penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
            penultimate_was_timestamp = jnp.where(
                input_ids_k[cur_len - 2] >= self.timestamp_begin,
                True,
                penultimate_was_timestamp,
            )

            return jnp.where(
                last_was_timestamp,
                jnp.where(
                    penultimate_was_timestamp > 0,
                    scores_k.at[self.timestamp_begin :].set(-float("inf")),  # 如果倒数第二个是时间戳,则将时间戳之后的分数设为负无穷
                    scores_k.at[: self.eos_token_id].set(-float("inf")),  # 否则将句子结束符之前的分数设为负无穷
                ),
                scores_k,  # 如果前一个不是时间戳,则保持分数不变
            )

        # 对每对 (input_ids, scores) 应用 handle_pairs 函数
        scores = jax.vmap(handle_pairs)(input_ids, scores)

        # 判断是否应用最大初始时间戳策略
        apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
        apply_max_initial_timestamp = jnp.where(
            self.max_initial_timestamp_index is not None,
            True and apply_max_initial_timestamp,
            False,
        )

        # 计算最大允许的时间戳
        last_allowed = self.timestamp_begin + self.max_initial_timestamp_index

        # 如果应用最大初始时间戳策略,则将分数矩阵中大于最大允许时间戳之后的分数设为负无穷
        scores = jnp.where(
            apply_max_initial_timestamp,
            scores.at[:, last_allowed + 1 :].set(-float("inf")),
            scores,
        )

        # 如果时间戳的概率总和超过其它 token 的概率总和,则将时间戳之前的分数设为负无穷
        logprobs = jax.nn.log_softmax(scores, axis=-1)

        def handle_cumulative_probs(logprobs_k, scores_k):
            timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
            max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
            return jnp.where(
                timestamp_logprob > max_text_token_logprob,
                scores_k.at[: self.timestamp_begin].set(-float("inf")),  # 如果时间戳的概率总和高于其它 token,则将时间戳之前的分数设为负无穷
                scores_k,  # 否则保持分数不变
            )

        # 对每个 (logprobs, scores) 应用 handle_cumulative_probs 函数
        scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

        # 返回处理后的分数矩阵
        return scores

.\generation\flax_utils.py

# 导入必要的库和模块
import copy                    # 导入 copy 模块,用于复制对象
import inspect                 # 导入 inspect 模块,用于检查对象
import warnings                # 导入 warnings 模块,用于警告处理
from functools import partial  # 导入 partial 函数,用于创建部分应用的函数
from typing import Any, Dict, Optional, Union  # 导入类型提示相关模块

import flax                   # 导入 flax 框架
import jax                    # 导入 jax 框架
import jax.numpy as jnp       # 导入 jax 的 numpy 接口作为 jnp
import numpy as np            # 导入 numpy 库
from jax import lax           # 导入 jax 的 lax 模块

from ..models.auto import (   # 导入自定义模块中的多个自动模型映射
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, logging  # 导入自定义模块中的 ModelOutput 和 logging
from .configuration_utils import GenerationConfig  # 导入本地模块中的 GenerationConfig 类
from .flax_logits_process import (  # 导入本地模块中的多个 logits 处理器类
    FlaxForcedBOSTokenLogitsProcessor,
    FlaxForcedEOSTokenLogitsProcessor,
    FlaxForceTokensLogitsProcessor,
    FlaxLogitsProcessorList,
    FlaxMinLengthLogitsProcessor,
    FlaxSuppressTokensAtBeginLogitsProcessor,
    FlaxSuppressTokensLogitsProcessor,
    FlaxTemperatureLogitsWarper,
    FlaxTopKLogitsWarper,
    FlaxTopPLogitsWarper,
)

logger = logging.get_logger(__name__)  # 获取当前模块的 logger 实例


@flax.struct.dataclass
class FlaxGreedySearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None  # 类属性,存储生成的序列数据


@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using sampling.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
    """

    sequences: jnp.ndarray = None  # 类属性,存储生成的序列数据


@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax Base class for outputs of decoder-only generation models using beam search.


    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jnp.ndarray = None  # 类属性,存储生成的序列数据
    scores: jnp.ndarray = None     # 类属性,存储生成序列的分数(对数概率)


@flax.struct.dataclass
class GreedyState:
    """
    Dataclass to store state during greedy decoding.

    Args:
        cur_len (`jnp.ndarray`): Current lengths of sequences.
        sequences (`jnp.ndarray`): Generated sequences.
        running_token (`jnp.ndarray`): Running tokens for decoding.
        is_sent_finished (`jnp.ndarray`): Boolean array indicating finished sentences.
        model_kwargs (Dict[str, jnp.ndarray]): Additional model arguments.
    """

    cur_len: jnp.ndarray            # 当前序列长度
    sequences: jnp.ndarray          # 生成的序列
    running_token: jnp.ndarray      # 解码中的当前 token
    is_sent_finished: jnp.ndarray   # 表示句子是否结束的布尔数组
    model_kwargs: Dict[str, jnp.ndarray]  # 存储额外模型参数的字典


@flax.struct.dataclass
class SampleState:
    """
    Dataclass to store state during sampling.

    Args:
        cur_len (`jnp.ndarray`): Current lengths of sequences.
    """

    cur_len: jnp.ndarray  # 当前序列长度
    # 定义变量 sequences,类型为 jnp.ndarray,用于存储序列数据
    sequences: jnp.ndarray
    # 定义变量 running_token,类型为 jnp.ndarray,用于存储运行中的标记数据
    running_token: jnp.ndarray
    # 定义变量 is_sent_finished,类型为 jnp.ndarray,用于存储句子完成状态的数据
    is_sent_finished: jnp.ndarray
    # 定义变量 prng_key,类型为 jnp.ndarray,用于存储伪随机数生成器密钥的数据
    prng_key: jnp.ndarray
    # 定义变量 model_kwargs,类型为 Dict[str, jnp.ndarray],用于存储模型参数的字典,其中键为字符串,值为 jnp.ndarray 类型
    model_kwargs: Dict[str, jnp.ndarray]
@flax.struct.dataclass
class BeamSearchState:
    cur_len: jnp.ndarray  # 当前长度,作为一个 NumPy 数组
    running_sequences: jnp.ndarray  # 正在运行的序列,作为一个 NumPy 数组
    running_scores: jnp.ndarray  # 运行中的分数,作为一个 NumPy 数组
    sequences: jnp.ndarray  # 序列,作为一个 NumPy 数组
    scores: jnp.ndarray  # 分数,作为一个 NumPy 数组
    is_sent_finished: jnp.ndarray  # 标志句子是否完成的数组,作为一个 NumPy 数组
    model_kwargs: Dict[str, jnp.ndarray]  # 模型参数,字典形式,键为字符串,值为 NumPy 数组


class FlaxGenerationMixin:
    """
    包含自回归文本生成的所有函数的类,作为[`FlaxPreTrainedModel`]的混合类使用。

    该类公开[`~generation.FlaxGenerationMixin.generate`]方法,可用于:
            - 当`num_beams=1`且`do_sample=False`时通过调用[`~generation.FlaxGenerationMixin._greedy_search`]进行贪婪解码
            - 当`num_beams=1`且`do_sample=True`时通过调用[`~generation.FlaxGenerationMixin._sample`]进行多项式采样
            - 当`num_beams>1`且`do_sample=False`时通过调用[`~generation.FlaxGenerationMixin._beam_search`]进行束搜索解码

    无需直接调用上述任何方法。只需将自定义参数值传递给'generate'方法即可。有关解码策略的更多信息,请参阅[文本生成策略指南](../generation_strategies)。
    """

    def prepare_inputs_for_generation(self, *args, **kwargs):
        raise NotImplementedError(
            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
        )

    @staticmethod
    def _run_loop_in_debug(cond_fn, body_fn, init_state):
        """
        以非跟踪模式运行生成过程。仅用于调试目的。
        """
        state = init_state  # 初始化状态
        while cond_fn(state):  # 当条件函数为真时循环执行
            state = body_fn(state)  # 执行主体函数
        return state  # 返回最终状态

    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, model_kwargs):
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
        }
        model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs)
        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ) -> jnp.ndarray:
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            # 从模型参数中取出并移除 'decoder_input_ids';若其不为 None,则直接作为解码器输入返回
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
            if decoder_input_ids is not None:
                return decoder_input_ids  # 返回decoder_input_ids
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0)
    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        # 检索用于编码器-解码器模型的decoder_start_token_id
        # 如果需要,可以回退到bos_token_id
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.generation_config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
        # 如果decoder_start_token_id已经定义,则返回它
        if decoder_start_token_id is not None:
            return decoder_start_token_id
        # 否则,检查配置是否具有decoder_start_token_id,并且不为None
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "decoder_start_token_id")
            and self.config.decoder.decoder_start_token_id is not None
        ):
            return self.config.decoder.decoder_start_token_id
        # 如果以上条件不满足,检查是否定义了bos_token_id,并且不为None
        elif bos_token_id is not None:
            return bos_token_id
        # 否则,检查解码器子配置中是否定义了 bos_token_id;若所有来源都未定义,则抛出 ValueError
        elif (
            hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "bos_token_id")
            and self.config.decoder.bos_token_id is not None
        ):
            return self.config.decoder.bos_token_id
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )

    @staticmethod
    def _expand_to_num_beams(tensor, num_beams):
        # 将tensor扩展为num_beams数量的beam搜索结果
        return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
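
    # 使用示意(非库中代码):本质上是一次 broadcast_to,例如
    #     tensor 形状为 (batch_size=2, seq_len=3) 时,
    #     jnp.broadcast_to(tensor[:, None], (2, 4) + tensor.shape[1:]) 得到形状 (2, 4, 3),
    #     即每个样本被复制到 num_beams=4 条 beam 上。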

    def _adapt_logits_for_beam_search(self, logits):
        """
        This function can be overwritten in the specific modeling_flax_<model-name>.py classes to allow for custom beam
        search behavior. Note that the only model that overwrites this method is [`~transformes.FlaxMarianMTModel`].
        """
        # 默认情况下,直接返回logits,这个方法可以在具体的modeling_flax_<model-name>.py类中被覆盖,以允许自定义beam搜索行为。
        return logits
    def _validate_model_class(self):
        """
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        # 检查当前模型是否支持生成操作
        if not self.can_generate():
            # 定义支持生成操作的模型映射列表
            generate_compatible_mappings = [
                FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
                FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
                FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            ]
            # 收集所有兼容的模型类名
            generate_compatible_classes = set()
            for model_mapping in generate_compatible_mappings:
                # 获取当前模型配置对应的支持模型
                supported_models = model_mapping.get(type(self.config), default=None)
                if supported_models is not None:
                    generate_compatible_classes.add(supported_models.__name__)
            # 构建异常消息
            exception_message = (
                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
                "it doesn't have a language model head."
            )
            # 如果存在兼容的模型类,则添加建议使用的类名到异常消息中
            if generate_compatible_classes:
                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
            # 抛出类型错误异常,指示模型类不兼容生成操作
            raise TypeError(exception_message)

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # 初始化未使用的模型参数列表
        unused_model_args = []
        # 获取用于生成输入的参数名称集合
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # 如果 `kwargs` 或 `model_kwargs` 在模型参数中,扩展模型参数集合
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(self.__call__).parameters)
        # 检查传入的 `model_kwargs` 是否有未使用的参数
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        # 如果存在未使用的模型参数,抛出值错误异常
        if unused_model_args:
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )
    
    def generate(
        self,
        input_ids: jnp.ndarray,
        generation_config: Optional[GenerationConfig] = None,
        prng_key: Optional[jnp.ndarray] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        **kwargs,
    def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList:
        """
        返回一个 [`FlaxLogitsProcessorList`] 列表对象,其中包含所有用于多项式采样的相关 [`FlaxLogitsWarper`] 实例。
        """
        # 创建一个空的 FlaxLogitsProcessorList 对象,用于存储 Logits 处理器
        warpers = FlaxLogitsProcessorList()

        # 如果设置了温度且不等于 1.0,则添加温度调整器
        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            warpers.append(FlaxTemperatureLogitsWarper(generation_config.temperature))
        # 如果设置了 top_k 且不等于 0,则添加 top_k 调整器
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(FlaxTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
        # 如果设置了 top_p 且小于 1.0,则添加 top_p 调整器
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(FlaxTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))

        # 返回配置好的 warpers 列表对象
        return warpers

    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: int,
        logits_processor: Optional[FlaxLogitsProcessorList],
    ) -> FlaxLogitsProcessorList:
        """
        This method returns a [`FlaxLogitsProcessorList`] object containing all relevant
        [`FlaxLogitsProcessor`] instances used to modify the scores of the language model head.
        """
        processors = FlaxLogitsProcessorList()

        # Check if minimum length and end-of-sequence token ID are specified and valid
        if (
            generation_config.min_length is not None
            and generation_config.eos_token_id is not None
            and generation_config.min_length > -1
        ):
            # Append a processor to enforce minimum length and end token ID constraints
            processors.append(
                FlaxMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
            )
        
        # Check if forced beginning-of-sequence token ID is specified
        if generation_config.forced_bos_token_id is not None:
            # Append a processor to force the beginning-of-sequence token ID
            processors.append(FlaxForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
        
        # Check if forced end-of-sequence token ID is specified
        if generation_config.forced_eos_token_id is not None:
            # Append a processor to force the end-of-sequence token ID
            processors.append(
                FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
            )
        
        # Check if tokens to suppress are specified
        if generation_config.suppress_tokens is not None:
            # Append a processor to suppress specific tokens
            processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
        
        # Check if tokens to suppress at the beginning are specified
        if generation_config.begin_suppress_tokens is not None:
            begin_index = input_ids_seq_length
            
            # Adjust beginning index based on conditions
            begin_index = (
                begin_index
                if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                else begin_index + 1
            )
            
            # Adjust beginning index further based on forced decoder IDs
            if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
                begin_index += generation_config.forced_decoder_ids[-1][0]
            
            # Append a processor to suppress tokens at the beginning
            processors.append(
                FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
            )
        
        # Check if forced decoder IDs are specified
        if generation_config.forced_decoder_ids is not None:
            # Calculate adjusted IDs for forced tokens
            forced_decoder_ids = [
                [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
            ]
            
            # Append a processor to force tokens based on adjusted IDs
            processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids))
        
        # Merge the default processors list with any custom processors provided
        processors = self._merge_criteria_processor_list(processors, logits_processor)

        return processors
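
    # 简单示意(非 transformers 源码,数值为假设):begin_suppress_tokens 对应的 begin_index 计算。
    # 若 input_ids_seq_length = 1 且设置了 forced_bos_token_id,则 begin_index 先由 1 调整为 2;
    # 再假设 forced_decoder_ids = [[1, 50259], [2, 50359]],则加上最后一项的位置 2,最终 begin_index = 4,
    # 该值随后传给 FlaxSuppressTokensAtBeginLogitsProcessor,用于决定从生成的哪一步起应用抑制。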

    def _merge_criteria_processor_list(
        self,
        default_list: FlaxLogitsProcessorList,
        custom_list: FlaxLogitsProcessorList,
    ) -> FlaxLogitsProcessorList:
        """
        This method merges a default list of logits processors with a custom list of logits processors.
        It returns a combined [`FlaxLogitsProcessorList`] object.
        """
        # 如果自定义列表为空,则直接返回默认列表
        if len(custom_list) == 0:
            return default_list
        # 遍历默认列表中的每个元素
        for default in default_list:
            # 遍历自定义列表中的每个元素
            for custom in custom_list:
                # 如果自定义元素的类型与默认元素相同
                if type(custom) is type(default):
                    # 确定对象类型为"logits processor"
                    object_type = "logits processor"
                    # 抛出值错误,说明已经创建了相同类型的自定义对象
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        # 将自定义列表中的元素追加到默认列表中
        default_list.extend(custom_list)
        # 返回合并后的默认列表
        return default_list

    def _greedy_search(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    def _sample(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jnp.ndarray] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    def _beam_search(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[Union[bool, str]] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        num_return_sequences: Optional[int] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,

.\generation\logits_process.py

# 设置代码文件的编码格式为 UTF-8
# 版权声明,指明该代码的版权归 HuggingFace Inc. 团队所有
#
# 根据 Apache 许可证 2.0 版本,除非符合许可证的要求,否则不得使用此文件
# 可以在以下链接获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则本软件按"原样"分发,不附带任何形式的明示或暗示担保或条件
# 请查看许可证了解详细信息

# 导入所需的模块和函数
import inspect
import math
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

# 导入 numpy 和 torch 模块
import numpy as np
import torch

# 从相对路径导入 utils 模块中的 add_start_docstrings 函数
from ..utils import add_start_docstrings
# 从 logging 模块中导入 get_logger 函数
from ..utils.logging import get_logger

# 获取当前模块的 logger 对象
logger = get_logger(__name__)

# 定义一个原始文档字符串,用于记录 logits 处理函数的输入和返回说明
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            输入序列标记在词汇表中的索引。[什么是输入 ID?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            语言建模头的预测分数。当不使用 beam search 时,这些可以是每个词汇表的 logits;
            当使用 beam search 时,这些可以是每个词汇表标记的对数 softmax
        
    Return:
        `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: 处理后的预测分数。
"""

class LogitsProcessor:
    """所有生成过程中可以应用的 logits 处理器的抽象基类。"""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 抽象方法,需要被继承此类的类实现具体逻辑
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class LogitsWarper:
    """所有多项式采样生成过程中可以应用的 logits 转换器的抽象基类。"""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 抽象方法,需要被继承此类的类实现具体逻辑
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class LogitsProcessorList(list):
    """
    可用于创建一个 [`LogitsProcessor`] 或 [`LogitsWarper`] 列表,以便随后处理输入张量 `scores`。
    此类继承自列表,并添加了一个特定的 *__call__* 方法来对输入应用每个 [`LogitsProcessor`] 或 [`LogitsWarper`]。
    """
    # 定义一个特殊方法 `__call__`,使得对象可以像函数一样被调用
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            kwargs (`Dict[str, Any]`, *optional*):
                Additional kwargs that are specific to a logits processor.

        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
                The processed prediction scores.

        """
        # 遍历对象中所有的处理器
        for processor in self:
            # 获取处理器的 __call__ 方法的参数签名
            function_args = inspect.signature(processor.__call__).parameters
            # 如果处理器的 __call__ 方法参数个数大于2
            if len(function_args) > 2:
                # 检查所有除了前两个参数(self 和 input_ids)外的参数是否在 kwargs 中
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    # 如果有未传递的参数,则抛出 ValueError 异常
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                # 调用处理器的 __call__ 方法,传递 input_ids, scores 和 kwargs
                scores = processor(input_ids, scores, **kwargs)
            else:
                # 调用处理器的 __call__ 方法,传递 input_ids 和 scores
                scores = processor(input_ids, scores)

        # 返回处理后的预测分数
        return scores
# 定义一个新的 logits 处理器类,继承自 LogitsProcessor
class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. Note that, for decoder-only models
    like most LLMs, the length includes the prompt.

    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

    Examples:

    ```
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
    >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

    >>> inputs = tokenizer("A number:", return_tensors="pt")
    >>> gen_out = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one

    >>> # setting `min_length` to a value smaller than the uncontrolled output length has no impact
    >>> gen_out = model.generate(**inputs, min_length=3)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one

    >>> # setting a larger `min_length` will force the model to generate beyond its natural ending point, which is not
    >>> # necessarily incorrect
    >>> gen_out = model.generate(**inputs, min_length=10)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one thousand, nine hundred and ninety-four
    ```

    """

    # 初始化方法,接受最小长度和 EOS 标记 ID
    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        # 检查 min_length 必须为非负整数
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")

        # 如果 eos_token_id 是单个整数,则转换为列表形式
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        # 检查 eos_token_id 必须为正整数列表
        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
            logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        # 初始化对象的属性
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    # 调用方法,处理输入的 logits 和分数,并返回处理后的分数
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 获取当前输入的长度
        cur_len = input_ids.shape[-1]
        # 如果当前长度小于最小长度
        if cur_len < self.min_length:
            # 将所有 EOS 标记的分数设为负无穷
            for i in self.eos_token_id:
                scores[:, i] = -float("inf")
        # 返回处理后的分数
        return scores
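
# 简单用法示意(非 transformers 源码,数值为假设):当前长度 3 小于 min_length=5,
# 因此 EOS(这里假设其 id 为 2)的得分被置为 -inf
_demo_min_length = MinLengthLogitsProcessor(min_length=5, eos_token_id=2)
_demo_scores = _demo_min_length(torch.tensor([[0, 1, 1]]), torch.zeros(1, 4))
# 此时 _demo_scores[0, 2] == -inf,其余位置保持为 0;长度达到 5 后分数将原样返回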


# 定义另一个新的 logits 处理器类,继承自 LogitsProcessor
class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
    Contrarily to [`MinLengthLogitsProcessor`], this processor ignores the prompt.
    Args:
        prompt_length_to_skip (`int`):
            要跳过的输入标记长度。与 `generate` 一起使用时,不是有效的参数,因为它会自动分配输入长度。
        min_new_tokens (`int`):
            最小的 *新* 标记数量;当新生成的 token 数低于该值时,`eos_token_id` 的得分会被设为 `-float("Inf")`。
        eos_token_id (`Union[int, List[int]]`):
            *结束序列* 标记的 ID。可选择使用列表设置多个 *结束序列* 标记。

    Examples:

    ```
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")
    >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")

    >>> inputs = tokenizer(["A number:"], return_tensors="pt")
    >>> gen_out = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one

    >>> # 设置 `min_new_tokens` 将强制模型生成超出其自然结束点,这不一定是错误的
    >>> gen_out = model.generate(**inputs, min_new_tokens=2)
    >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
    A number: one thousand
    ```
    """

    def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: Union[int, List[int]]):
        # 验证并设置 `prompt_length_to_skip` 和 `min_new_tokens` 参数
        for arg_name, arg_value in [
            ("prompt_length_to_skip", prompt_length_to_skip),
            ("min_new_tokens", min_new_tokens),
        ]:
            if not isinstance(arg_value, int) or arg_value < 0:
                raise ValueError(f"`{arg_name}` 必须是正整数,但其值为 {arg_value}")

        # 验证并设置 `eos_token_id` 参数,确保其为正整数列表
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
            logger.warning(f"`eos_token_id` 必须是正整数列表,但其值为 {eos_token_id}")

        # 初始化对象的属性
        self.prompt_length_to_skip = prompt_length_to_skip
        self.min_new_tokens = min_new_tokens
        self.eos_token_id = eos_token_id

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 计算新生成标记的长度
        new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
        # 如果生成的新标记长度小于设定的最小值,将相应的 `eos_token_id` 的得分设为 `-float("inf")`
        if new_tokens_length < self.min_new_tokens:
            for i in self.eos_token_id:
                scores[:, i] = -float("inf")

        return scores
# TemperatureLogitsWarper 类,继承自 LogitsWarper
# 用于温度(指数缩放输出概率分布),有效地控制预测标记的随机性
# 常与 TopPLogitsWarper 和 TopKLogitsWarper 一起使用

class TemperatureLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
    that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
    [`TopKLogitsWarper`].

    <Tip>

    Make sure that `do_sample=True` is included in the `generate` arguments otherwise the temperature value won't have
    any effect.

    </Tip>

    Args:
        temperature (`float`):
            Strictly positive float value used to modulate the logits distribution. A value smaller than `1` decreases
            randomness (and vice versa), with `0` being equivalent to shifting all probability mass to the most likely
            token.

    Examples:

    ```
    >>> import torch
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    >>> set_seed(0)  # for reproducibility

    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> model.config.pad_token_id = model.config.eos_token_id
    >>> inputs = tokenizer(["Hugging Face Company is"], return_tensors="pt")

    >>> # With temperature=1.0, the default, we consistently get random outputs due to random sampling.
    >>> generate_kwargs = {"max_new_tokens": 10, "do_sample": True, "temperature": 1.0, "num_return_sequences": 2}
    >>> outputs = model.generate(**inputs, **generate_kwargs)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    ['Hugging Face Company is a joint venture between GEO Group, one of',
    'Hugging Face Company is not an exact science – but what we believe does']

    >>> # However, with temperature close to 0, it approximates greedy decoding strategies (invariant)
    >>> generate_kwargs["temperature"] = 0.0001
    >>> outputs = model.generate(**inputs, **generate_kwargs)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
    ['Hugging Face Company is a company that has been around for over 20 years',
    'Hugging Face Company is a company that has been around for over 20 years']
    ```
    """

    def __init__(self, temperature: float):
        # 检查温度参数是否为有效的浮点数且大于0
        if not isinstance(temperature, float) or not (temperature > 0):
            # 如果温度不是有效的正浮点数,抛出值错误异常
            except_msg = (
                f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
                "scores will be invalid."
            )
            # 如果温度为0,提醒用户可以设置 `do_sample=False` 来实现贪婪解码策略
            if isinstance(temperature, float) and temperature == 0.0:
                except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
            raise ValueError(except_msg)

        # 设置实例的温度属性
        self.temperature = temperature

    # 添加文档字符串,参考 LOGITS_PROCESSOR_INPUTS_DOCSTRING
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    # 定义类的特殊方法 __call__,使得对象可以像函数一样被调用
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 将分数 scores 除以温度 temperature,用于调整输出的分布
        scores = scores / self.temperature
        # 返回调整后的分数
        return scores
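
# 简单用法示意(非 transformers 源码,数值为假设):temperature=0.5 会把 logits 放大一倍,使分布更尖锐
_demo_temperature = TemperatureLogitsWarper(temperature=0.5)
print(_demo_temperature(None, torch.tensor([[1.0, 2.0, 3.0]])))  # tensor([[2., 4., 6.]])
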
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt.

    In the original [paper](https://arxiv.org/pdf/1909.05858.pdf), the authors suggest the use of a penalty of around
    1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
            tokens. Between 0.0 and 1.0 rewards previously generated tokens.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> # Initializing the model and tokenizer for it
    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
    >>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")

    >>> # This shows a normal generate without any specific parameters
    >>> summary_ids = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    I'm not going to be able to do that. I'm going to be able to do that

    >>> # This generates a penalty for repeated tokens
    >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
    I'm not going to be able to do that. I'll just have to go out and play
    ```
    """

    def __init__(self, penalty: float):
        # 检查 penalty 是否为正的浮点数,否则抛出错误
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 从 scores 中选择对应 input_ids 的分数
        score = torch.gather(scores, 1, input_ids)

        # score < 0 时乘以 penalty,score >= 0 时除以 penalty;当 penalty > 1 时都会降低已出现 token 的得分
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        # 将修正后的分数重新写入 scores 中对应的位置
        scores.scatter_(1, input_ids, score)
        return scores
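
# 简单用法示意(非 transformers 源码,数值为假设):penalty=2.0 时,已生成过的 token 0 和 1 的得分被削弱
_demo_repetition = RepetitionPenaltyLogitsProcessor(penalty=2.0)
print(_demo_repetition(torch.tensor([[0, 1]]), torch.tensor([[2.0, -2.0, 2.0]])))  # tensor([[ 1., -4.,  2.]])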


class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that works similarly to [`RepetitionPenaltyLogitsProcessor`], but with an *inverse* penalty
    that is applied to the tokens present in the prompt. In other words, a penalty above 1.0 increases the odds of
    selecting tokens that were present in the prompt.
    """

    def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor):
        # 检查 penalty 是否为 float 类型且大于 0,否则抛出数值错误异常
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        # 计算实际的惩罚值,即将 1 除以 penalty
        self.penalty = 1 / penalty
        # 将输入的 encoder_input_ids 赋值给实例变量
        self.encoder_input_ids = encoder_input_ids

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 从 scores 中按列索引提取与 encoder_input_ids 相对应的分数
        score = torch.gather(scores, 1, self.encoder_input_ids)

        # self.penalty 为 1/penalty;当 penalty > 1 时,负分数乘以它、非负分数除以它
        # 都会提高 prompt 中已出现 token 的得分(即提高其被选中的概率)
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        # 将处理后的 score 根据 encoder_input_ids 的索引位置更新到 scores 中
        scores.scatter_(1, self.encoder_input_ids, score)
        # 返回更新后的 scores
        return scores
class TopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. Often
    used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].

    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    >>> set_seed(0)
    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

    >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")

    >>> # With sampling, the output is unexpected -- sometimes too unexpected.
    >>> outputs = model.generate(**inputs, do_sample=True)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2

    >>> # With `top_p` sampling, the output gets restricted to high-probability tokens.
    >>> # Pro tip: In practice, LLMs use `top_p` in the 0.9-0.95 range.
    >>> outputs = model.generate(**inputs, do_sample=True, top_p=0.1)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
    ```
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 初始化 TopPLogitsWarper 对象,设置 top-p 概率截断参数
        top_p = float(top_p)
        # 检查 top_p 参数是否在有效范围 (0, 1) 内,否则引发 ValueError 异常
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
        # 检查 min_tokens_to_keep 参数是否为正整数,否则引发 ValueError 异常
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

        # 设置对象的属性
        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    # 添加文档字符串作为类的一部分,描述输入参数
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    # 定义一个调用函数,接受输入的token IDs和对应的分数,返回处理后的分数
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 对分数进行升序排序,并返回排序后的分数和索引
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        # 对排序后的分数进行 softmax 处理并计算累积概率
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # 移除概率最低、累积概率不超过 (1 - top_p) 的那部分 token,只保留累积概率达到 top_p 所需的高概率 token
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # 至少保留 min_tokens_to_keep 个token
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # 将排序后的移除指标张量按照排序后的索引分散到原始索引位置
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        # 使用 filter_value 替换需要移除的token对应的分数
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        # 返回处理后的分数张量
        return scores
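
# 简单用法示意(非 transformers 源码,数值为假设):top_p=0.5 时只保留累积概率覆盖 0.5 所需的头部 token
_demo_top_p = TopPLogitsWarper(top_p=0.5)
print(_demo_top_p(None, torch.tensor([[1.0, 2.0, 3.0, 4.0]])))  # tensor([[-inf, -inf, -inf, 4.]])
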
class TopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
    with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].

    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    >>> set_seed(0)
    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

    >>> inputs = tokenizer("A sequence: A, B, C, D", return_tensors="pt")

    >>> # With sampling, the output is unexpected -- sometimes too unexpected.
    >>> outputs = model.generate(**inputs, do_sample=True)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: A, B, C, D, G, H, I. A, M

    >>> # With `top_k` sampling, the output gets restricted the k most likely tokens.
    >>> # Pro tip: In practice, LLMs use `top_k` in the 5-50 range.
    >>> outputs = model.generate(**inputs, do_sample=True, top_k=2)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: A, B, C, D, E, F, G, H, I
    ```
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 检查并初始化 `top_k` 参数,确保其为正整数
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        # 将 `top_k` 设为不小于 `min_tokens_to_keep` 的值,设置过滤值 `filter_value`
        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 确保 `top_k` 不超过 `scores` 的最后一维大小,以避免越界
        top_k = min(self.top_k, scores.size(-1))  # Safety check
        # 移除概率小于 `top-k` 中最后一个概率值的所有 token
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
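
# 简单用法示意(非 transformers 源码,数值为假设):top_k=2 时只保留得分最高的两个 token
_demo_top_k = TopKLogitsWarper(top_k=2)
print(_demo_top_k(None, torch.tensor([[1.0, 3.0, 2.0, 0.5]])))  # tensor([[-inf, 3., 2., -inf]])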


class TypicalLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs typical decoding. Inspired on how humans use language, it prioritizes tokens whose
    log probability is close to the entropy of the token probability distribution. This means that the most likely
    tokens may be discarded in the process.

    See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
    """

    # 初始化函数,用于创建一个新的实例对象
    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 将输入参数 mass 转换为 float 类型
        mass = float(mass)
        # 检查 mass 参数是否在有效范围 (0, 1) 内,如果不是则引发 ValueError 异常
        if not (mass > 0 and mass < 1):
            raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
        # 检查 min_tokens_to_keep 是否为正整数,如果不是则引发 ValueError 异常
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

        # 设置对象的 filter_value 属性为传入的 filter_value 参数值
        self.filter_value = filter_value
        # 设置对象的 mass 属性为处理后的 mass 参数值
        self.mass = mass
        # 设置对象的 min_tokens_to_keep 属性为处理后的 min_tokens_to_keep 参数值
        self.min_tokens_to_keep = min_tokens_to_keep
    # 定义一个调用方法,接收输入的token ID张量和得分张量,并返回处理后的得分张量
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 计算熵(entropy)
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # 移位并排序
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # 根据累积概率阈值移除部分token
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # 使用指定的值过滤掉需要移除的token的得分
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
# 定义一个名为 EpsilonLogitsWarper 的类,继承自 LogitsWarper 类
class EpsilonLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
    largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
    Desmoothing](https://arxiv.org/abs/2210.15191) for more information.

    Args:
        epsilon (`float`):
            If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.

    Examples:
    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    >>> set_seed(0)
    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

    >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")

    >>> # With sampling, the output is unexpected -- sometimes too unexpected.
    >>> outputs = model.generate(**inputs, do_sample=True)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2

    >>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to
    >>> # Top P sampling, which restricts tokens based on their cumulative probability.
    >>> # Pro tip: The paper recomends using `epsilon_cutoff` values between 3e-4 and 9e-4
    >>> outputs = model.generate(**inputs, do_sample=True, epsilon_cutoff=0.1)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
    ```
    """

    # 初始化方法,设置 epsilon-sampling 的参数
    def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 将 epsilon 强制转换为 float 类型
        epsilon = float(epsilon)
        # 如果 epsilon 不在有效范围 (0, 1) 内,抛出异常
        if epsilon <= 0 or epsilon >= 1:
            raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")

        # 将 min_tokens_to_keep 强制转换为 int 类型
        min_tokens_to_keep = int(min_tokens_to_keep)
        # 如果 min_tokens_to_keep 不大于等于 1,抛出异常
        if min_tokens_to_keep < 1:
            raise ValueError(
                f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
            )

        # 初始化对象的属性
        self.epsilon = epsilon
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    # 添加 LogitsProcessor 的输入文档字符串
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    # 定义一个调用方法,接收输入的张量 input_ids 和分数张量 scores,并返回一个分数张量
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 使用 softmax 函数计算分数张量在最后一个维度上的概率分布
        probabilities = scores.softmax(dim=-1)
        # 创建一个布尔张量,指示哪些索引的概率低于阈值 self.epsilon
        indices_to_remove = probabilities < self.epsilon

        # 确保保留至少 self.min_tokens_to_keep 个最高概率的单词
        top_k = min(self.min_tokens_to_keep, scores.size(-1))  # 进行安全性检查,取最小值
        # 使用 torch.topk 函数获取最高分数的前 top_k 个分数,并与 indices_to_remove 合并
        indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])

        # 使用指定的 self.filter_value 替换 scores 张量中 indices_to_remove 为 True 的元素
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        # 返回处理后的分数张量
        return scores
class EtaLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
    cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
    the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
    min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
    samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
    Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
    must be set to `True` for this `LogitsWarper` to work.


    Args:
        epsilon (`float`):
            A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. The
            suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model.
        filter_value (`float`, *optional*, defaults to -inf):
            All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This
            parameter is useful when logits need to be modified for very low probability tokens that should be excluded
            from generation entirely.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities.
            For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation,
            even if all tokens have probabilities below the cutoff `eta`.

    Examples:
    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

    >>> set_seed(0)
    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

    >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt")

    >>> # With sampling, the output is unexpected -- sometimes too unexpected.
    >>> outputs = model.generate(**inputs, do_sample=True)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2

    >>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of
    >>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff).
    >>> # Pro tip: The paper recomends using `eta_cutoff` values between 3e-4 to 4e-3
    >>> outputs = model.generate(**inputs, do_sample=True, eta_cutoff=0.1)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9
    ```
    """
    def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 将 epsilon 转换为浮点数并进行验证
        epsilon = float(epsilon)
        # 检查 epsilon 的取值范围是否在 (0, 1) 之间,否则引发 ValueError 异常
        if epsilon <= 0 or epsilon >= 1:
            raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")

        # 将 min_tokens_to_keep 转换为整数并进行验证
        min_tokens_to_keep = int(min_tokens_to_keep)
        # 检查 min_tokens_to_keep 是否大于等于 1,否则引发 ValueError 异常
        if min_tokens_to_keep < 1:
            raise ValueError(
                f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
            )

        # 初始化对象的属性
        self.epsilon = torch.tensor(epsilon)
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 计算自适应阈值 eta
        probabilities = scores.softmax(dim=-1)  # 计算概率分布
        entropy = torch.distributions.Categorical(logits=scores).entropy()  # 计算熵
        eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]  # 计算 eta

        # 确定需要移除的索引
        indices_to_remove = probabilities < eta

        # 保留概率最高的 min_tokens_to_keep 个词
        top_k = min(self.min_tokens_to_keep, scores.size(-1))  # 安全检查,确保 top_k 不超过 scores 的最后一个维度大小
        indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])

        # 根据 indices_to_remove 进行掩码操作,用 filter_value 替换需要移除的位置的分数
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
# 定义一个函数 `_get_ngrams`,用于生成给定大小的 n-gram 并保存在字典中
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """
    Assume ngram_size=2 and prev_input_ids=tensor([[40, 2883, 2712, 4346]]). The output of generated ngrams look like
    this {(40,): [2883], (2883,): [2712], (2712,): [4346]}.

    Args:
        ngram_size (`int`):
            The number sequential tokens taken as a group which may only occur once before being banned.
        prev_input_ids (`torch.Tensor`):
           Generated token ids for the current hypothesis.
        num_hypos (`int`):
            The number of hypotheses for which n-grams need to be generated.

    Returns:
        generated_ngrams (`dict`):
            Dictionary of generated ngrams.
    """
    # 初始化一个空的字典列表,每个假设 (索引) 对应一个字典,数量为 num_hypos
    generated_ngrams = [{} for _ in range(num_hypos)]
    # 遍历每个假设
    for idx in range(num_hypos):
        # 将当前假设的生成的 token 转换为列表
        gen_tokens = prev_input_ids[idx].tolist()
        # 获取当前假设的生成 ngram 字典
        generated_ngram = generated_ngrams[idx]
        # 遍历当前假设生成的 token 列表,生成大小为 ngram_size 的 ngram
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            # 将生成的 ngram 加入到生成的 ngram 字典中
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams
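
# 简单用法示意(非 transformers 源码):复现上面 docstring 中的例子
_demo_ngrams = _get_ngrams(2, torch.tensor([[40, 2883, 2712, 4346]]), 1)
# _demo_ngrams == [{(40,): [2883], (2883,): [2712], (2712,): [4346]}]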


# 定义一个函数 `_get_generated_ngrams`,用于确定基于先前生成的 ngram 的当前假设的禁用 token
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """
    Determines the banned tokens for the current hypothesis based on previously generated n-grams.

    Args:
        banned_ngrams (`dict`):
            A dictionary containing previously generated n-grams for each hypothesis.
        prev_input_ids (`torch.Tensor`):
            Generated token ids for the current hypothesis.
        ngram_size (`int`):
            The number sequential tokens taken as a group which may only occur once before being banned.
        cur_len (`int`):
            The current length of the token sequences for which the n-grams are being checked.

    Returns:
        List of tokens that are banned.
    """
    # 计算当前需要检查的 ngram 的起始索引
    start_idx = cur_len + 1 - ngram_size
    # 获取当前假设生成的 ngram 的索引元组
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    # 返回先前生成的 ngrams 中与当前 ngram 索引匹配的禁用 tokens 列表
    return banned_ngrams.get(ngram_idx, [])


# 定义一个函数 `_calc_banned_ngram_tokens`,用于计算禁用的 ngram token
def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    # 如果当前生成的 token 数量小于 ngram_size,则返回空的禁用 tokens 列表
    if cur_len + 1 < ngram_size:
        return [[] for _ in range(num_hypos)]
    # 生成当前假设的 ngrams
    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    # 获取每个假设的禁用 tokens 列表
    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
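
# 简单用法示意(非 transformers 源码,数值为假设):序列 [40, 2883, 2712, 40, 2883] 中,
# 最近的 (n-1)-gram 为 (2883,),其后曾出现过 2712,因此 2712 被禁止
_demo_banned = _calc_banned_ngram_tokens(2, torch.tensor([[40, 2883, 2712, 40, 2883]]), 1, 5)
# _demo_banned == [[2712]]
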
class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    N-grams are groups of "n" consecutive words, characters, or tokens taken from a sequence of text. Given the
    sentence: "She runs fast", the bi-grams (n=2) would be ("she", "runs") and ("runs", "fast"). In text generation,
    avoiding repetitions of word sequences provides a more diverse output. This [`LogitsProcessor`] enforces no
    repetition of n-grams by setting the scores of banned tokens to negative infinity which eliminates those tokens
    from consideration when further processing the scores. Note that, for decoder-only models like most LLMs, the
    prompt is also considered to obtain the n-grams.
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    <Tip>

    Use n-gram penalties with care. For instance, penalizing 2-grams (bigrams) in an article about the city of New York
    might lead to undesirable outcomes where the city's name appears only once in the entire text.
    [Reference](https://huggingface.co/blog/how-to-generate)

    </Tip>

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
    >>> inputs = tokenizer(["Today I"], return_tensors="pt")

    >>> output = model.generate(**inputs)
    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
    Today I’m not sure if I’m going to be able to do it.

    >>> # Now let's add ngram size using `no_repeat_ngram_size`. This stops the repetitions ("I’m") in the output.
    >>> output = model.generate(**inputs, no_repeat_ngram_size=2)
    >>> print(tokenizer.decode(output[0], skip_special_tokens=True))
    Today I’m not sure if I can get a better understanding of the nature of this issue
    ```
    """

    def __init__(self, ngram_size: int):
        # 检查并初始化 ngram_size,确保其为正整数
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 获取当前 batch 的假设数量
        num_batch_hypotheses = scores.shape[0]
        # 获取当前输入序列的长度
        cur_len = input_ids.shape[-1]
        # 计算当前 batch 每个假设中不允许出现的 n-gram tokens
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
        # 将不允许出现的 token 的分数设为负无穷,以便在后续处理中排除这些 token
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores
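
# 简单用法示意(非 transformers 源码,数值为假设):输入中已出现过 bigram (5, 8),
# 当前最后一个 token 为 5,因此 token 8 的得分被置为 -inf
_demo_no_repeat = NoRepeatNGramLogitsProcessor(ngram_size=2)
_demo_no_repeat_scores = _demo_no_repeat(torch.tensor([[5, 8, 5]]), torch.zeros(1, 10))
# _demo_no_repeat_scores[0, 8] == -inf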


class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that works similarly to [`NoRepeatNGramLogitsProcessor`], but applied exclusively to prevent
    the repetition of n-grams present in the prompt (i.e. the encoder input ids).

    Args:
        encoder_ngram_size (`int`):
            Size of the n-grams that should not be repeated in the decoder.
        encoder_input_ids (`torch.LongTensor`):
            Tensor containing input IDs for the encoder.
    """

    def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor):
        # Check if encoder_ngram_size is a positive integer
        if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
            raise ValueError(
                f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
            )
        # Store the n-gram size
        self.ngram_size = encoder_ngram_size
        
        # Ensure encoder_input_ids is 2-dimensional
        if len(encoder_input_ids.shape) == 1:
            encoder_input_ids = encoder_input_ids.unsqueeze(0)
        
        # Store batch size
        self.batch_size = encoder_input_ids.shape[0]
        
        # Generate n-grams from the encoder input IDs
        self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Calculate number of hypotheses
        num_hypos = scores.shape[0]
        
        # Calculate number of beams per hypothesis
        num_beams = num_hypos // self.batch_size
        
        # Current length of input_ids
        cur_len = input_ids.shape[-1]
        
        # List of banned tokens for each hypothesis
        banned_batch_tokens = [
            _get_generated_ngrams(
                self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
            )
            for hypo_idx in range(num_hypos)
        ]
        
        # Apply -inf score to banned tokens in scores tensor
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")
        
        return scores
class SequenceBiasLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
    when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than
    one token, consider using beam methods (to gracefully work around partially completed sequences that have a
    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).

    <Tip>

    In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when
    initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The
    `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours
    come from `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

    </Tip>

    Args:
        sequence_bias (`Dict[Tuple[int], float]`):
            Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
            sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias
            will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be
            completed (in the token selection step after this processor is applied).

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt")

    >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Trump Jr

    >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently!
    >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)


    >>> def get_tokens_as_tuple(word):
    ...     return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])


    >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations
    >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald J. Donald,

    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)

    """

    def __init__(self, sequence_bias):
        """
        Initialize the SequenceBiasLogitsProcessor with a sequence bias dictionary.

        Args:
            sequence_bias (`Dict[Tuple[int], float]`): A dictionary mapping sequences of tokens to their bias values.
        """
        super().__init__()
        self.sequence_bias = sequence_bias

    def __call__(self, input_ids, scores):
        """
        Apply the sequence bias to the logits.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            scores (torch.Tensor): Logits (scores) for each token.

        Returns:
            torch.Tensor: Modified logits after applying sequence bias.
        """
        # Determine the sequence length
        seq_len = input_ids.size(1)
        # Get the last token's token IDs
        last_token_ids = input_ids[:, -1].tolist()

        # Check if the last token is in the sequence_bias dictionary
        if tuple(last_token_ids) in self.sequence_bias:
            # Apply bias to the last token's logits
            scores[:, -1] += self.sequence_bias[tuple(last_token_ids)]

        return scores
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Rumsfeld,

    >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations
    >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0}
    >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)
    >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])
    The full name of Donald is Donald Duck.
    ```
    """

    # Constructor: receives a dictionary mapping token sequences to their bias terms
    def __init__(self, sequence_bias: Dict[Tuple[int], float]):
        self.sequence_bias = sequence_bias  # store the sequence bias dictionary
        self._validate_arguments()  # validate the arguments with the internal helper

        # The variables below are populated on the first call (for retrocompatibility purposes, the vocabulary size
        # is inferred on the first usage, so it is not initialized here)
        self.length_1_bias = None  # bias tensor for length-1 sequences
        self.prepared_bias_variables = False  # flag indicating whether the bias variables have been prepared

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    # Receives the input_ids and the scores, and returns the scores with the sequence bias applied
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
        if not self.prepared_bias_variables:
            self._prepare_bias_variables(scores)

        # 2 - prepares an empty bias tensor to add to the scores
        bias = torch.zeros_like(scores)

        # 3 - include the bias from length-1 sequences
        bias += self.length_1_bias

        # 4 - include the bias from sequences longer than 1, after determining which biased sequences may be completed
        for sequence_ids, sequence_bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:  # the sequence has length 1, its bias was already applied
                continue
            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
                continue
            prefix_length = len(sequence_ids) - 1
            last_token = sequence_ids[-1]
            matching_rows = torch.eq(
                input_ids[:, -prefix_length:],
                torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
            ).prod(dim=1)
            bias[:, last_token] += torch.where(
                matching_rows.bool(),
                torch.tensor(sequence_bias, device=input_ids.device),
                torch.tensor(0.0, device=input_ids.device),
            )

        # 5 - apply the bias to the scores
        scores = scores + bias
        return scores
    # Prepares the bias variables, inferring the vocabulary size from the shape of the scores tensor
    def _prepare_bias_variables(self, scores: torch.FloatTensor):
        vocabulary_size = scores.shape[-1]

        # Check whether any biased token is out of the vocabulary range
        invalid_biases = []
        for sequence_ids in self.sequence_bias:
            for token_id in sequence_ids:
                if token_id >= vocabulary_size:
                    invalid_biases.append(token_id)
        if len(invalid_biases) > 0:
            raise ValueError(
                f"The model vocabulary size is {vocabulary_size}, but the following tokens were being biased: "
                f"{invalid_biases}"
            )

        # Precompute the bias tensor to be applied. Sequences of length 1 are kept separately, as they can be applied
        # with simpler logic.
        self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device)
        for sequence_ids, bias in self.sequence_bias.items():
            if len(sequence_ids) == 1:
                self.length_1_bias[sequence_ids[-1]] = bias

        # Mark the bias variables as prepared
        self.prepared_bias_variables = True

    # Validates the `sequence_bias` argument
    def _validate_arguments(self):
        sequence_bias = self.sequence_bias
        # `sequence_bias` has to be a non-empty dictionary
        if not isinstance(sequence_bias, dict) or len(sequence_bias) == 0:
            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {sequence_bias}.")
        # its keys have to be tuples
        if any(not isinstance(sequence_ids, tuple) for sequence_ids in sequence_bias.keys()):
            raise ValueError(f"`sequence_bias` has to be a dict with tuples as keys, but is {sequence_bias}.")
        # its keys have to be non-empty tuples of non-negative integer token ids
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in sequence_ids)
            or len(sequence_ids) == 0
            for sequence_ids in sequence_bias.keys()
        ):
            raise ValueError(
                f"Each key in `sequence_bias` has to be a non-empty tuple of positive integers, but is "
                f"{sequence_bias}."
            )
        # its values have to be floats
        if any(not isinstance(bias, float) for bias in sequence_bias.values()):
            raise ValueError(f"`sequence_bias` has to be a dict with floats as values, but is {sequence_bias}.")
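To see both code paths above (the precomputed length-1 bias and the prefix-matched multi-token bias) without running a full `generate` call, the processor can be applied to dummy logits directly. A minimal sketch, assuming the class is importable from `transformers.generation.logits_process`; the token ids and vocabulary size are made up for illustration:

```
import torch
from transformers.generation.logits_process import SequenceBiasLogitsProcessor

vocab_size = 10
# Bias token 7 unconditionally (length-1 key) and token 5 only right after the prefix (2, 3).
processor = SequenceBiasLogitsProcessor(sequence_bias={(7,): 2.0, (2, 3, 5): -5.0})

input_ids = torch.tensor([[1, 2, 3]])   # the last two tokens match the prefix (2, 3)
scores = torch.zeros((1, vocab_size))   # uniform dummy logits

out = processor(input_ids, scores)
print(out[0, 7])  # tensor(2.)  -> the length-1 bias is always applied
print(out[0, 5])  # tensor(-5.) -> the multi-token bias fires because the prefix matched
```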
"""
[`LogitsProcessor`] that enforces that specified sequences will never be selected.

<Tip>

In order to get the token ids of the words that should not appear in the generated text, make sure to set
`add_prefix_space=True` when initializing the tokenizer, and use `tokenizer(bad_words,
add_special_tokens=False).input_ids`. The `add_prefix_space` argument is only supported for some slow tokenizers,
as fast tokenizers' prefixing behaviours come from `pre tokenizers`. Read more
[here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).

</Tip>

Args:
    bad_words_ids (`List[List[int]]`):
        List of list of token ids that are not allowed to be generated.
    eos_token_id (`Union[int, List[int]]`):
        The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

Examples:


>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> inputs = tokenizer(["In a word, the cake is a"], return_tensors="pt")

>>> output_ids = model.generate(inputs["input_ids"], max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
>>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
In a word, the cake is a bit of a mess.

>>> # Now let's take the bad words out. Please note that the tokenizer is initialized differently
>>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=True)


>>> def get_tokens_as_list(word_list):
...     "Converts a sequence of words into a list of tokens"
...     tokens_list = []
...     for word in word_list:
...         tokenized_word = tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]
...         tokens_list.append(tokenized_word)
...     return tokens_list


>>> bad_words_ids = get_tokens_as_list(word_list=["mess"])
>>> output_ids = model.generate(
...     inputs["input_ids"], max_new_tokens=5, bad_words_ids=bad_words_ids, pad_token_id=tokenizer.eos_token_id
... )
>>> print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
In a word, the cake is a bit of a surprise.

"""
    # Constructor: `bad_words_ids` is a list of token-id lists that must not be generated, and `eos_token_id` is the
    # end-of-sequence token id (a single int or a list of ints)
    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
        # Store the banned word ids on the object
        self.bad_word_ids = bad_words_ids
        # Validate the arguments with the internal helper
        self._validate_arguments()

        # Filter EOS tokens out of bad_words_ids
        if eos_token_id is None:
            eos_token_id = []
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        # Keep only the sequences that are not a bare EOS token
        bad_words_ids = list(
            filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
        )

        # Forbidden sequences are implemented as a sequence bias of -inf
        sequence_bias = {tuple(sequence): float("-inf") for sequence in bad_words_ids}
        # Initialize the parent class with the sequence bias dictionary
        super().__init__(sequence_bias=sequence_bias)

    # Internal helper that validates the `bad_words_ids` argument
    def _validate_arguments(self):
        # Local alias for the stored attribute
        bad_words_ids = self.bad_word_ids
        # `bad_words_ids` has to be a non-empty list
        if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        # every element of `bad_words_ids` has to be a list
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        # every inner list has to contain non-negative integer token ids only
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )
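Since this class reduces every banned word to a `-inf` entry in the parent's `sequence_bias`, its effect is easy to check on dummy logits. A minimal sketch with made-up token ids, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import NoBadWordsLogitsProcessor

# Ban single token 4 and the two-token sequence [2, 6]; token 0 stands in for EOS (all ids made up).
processor = NoBadWordsLogitsProcessor(bad_words_ids=[[4], [2, 6]], eos_token_id=0)

input_ids = torch.tensor([[1, 2]])  # the last token is 2, so [2, 6] is about to be completed
scores = torch.zeros((1, 8))

out = processor(input_ids, scores)
print(out[0, 4])  # tensor(-inf) -> banned at every step
print(out[0, 6])  # tensor(-inf) -> banned because the prefix [2] matches
```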
class PrefixConstrainedLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
    generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.

    Args:
        prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`):
            This function constraints the beam search to allowed tokens only at each step. This function takes 2
            arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the
            next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID
            `batch_id`.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
    >>> tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-560m")

    >>> inputs = tokenizer("Alice and Bob", return_tensors="pt")

    >>> # By default, it continues generating according to the model's logits
    >>> outputs = model.generate(**inputs, max_new_tokens=5)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    Alice and Bob are friends

    >>> # We can constrain it with `prefix_allowed_tokens_fn` to force a certain behavior based on a prefix.
    >>> # For instance, we can force an entire entity to be generated when its beginning is detected.
    >>> entity =  tokenizer(" Bob Marley", return_tensors="pt").input_ids[0]  # 3 tokens
    >>> def prefix_allowed_tokens_fn(batch_id, input_ids):
    ...     '''
    ...     Attempts to generate 'Bob Marley' when 'Bob' is detected.
    ...     In this case, `batch_id` is not used, but you can set rules for each batch member.
    ...     '''
    ...     if input_ids[-1] == entity[0]:
    ...         return entity[1]
    ...     elif input_ids[-2] == entity[0] and input_ids[-1] == entity[1]:
    ...         return entity[2]
    ...     return list(range(tokenizer.vocab_size))  # If no match, allow all tokens

    >>> outputs = model.generate(**inputs, max_new_tokens=5, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)
    >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
    Alice and Bob Marley
    ```

    """

    def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int):
        # Constructor: `prefix_allowed_tokens_fn` restricts the allowed tokens at each step and `num_beams` is the number of beams
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    # Receives `input_ids` (torch.LongTensor) and `scores` (torch.FloatTensor) and returns the constrained scores
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Create a mask with the same shape as `scores`, filled with -inf
        mask = torch.full_like(scores, -math.inf)

        # Iterate over the input ids, grouped into `num_beams` beams per batch element
        for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
            # Iterate over each beam of the current batch element
            for beam_id, sent in enumerate(beam_sent):
                # Query `prefix_allowed_tokens_fn` for the tokens allowed after this prefix
                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
                # Raise a ValueError if the returned list is empty
                if len(prefix_allowed_tokens) == 0:
                    raise ValueError(
                        f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
                        f"This means that the constraint is unsatisfiable. Please check your implementation"
                        f"of `prefix_allowed_tokens_fn` "
                    )
                # Unmask the allowed tokens in the row corresponding to this batch element and beam
                mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0

        # Return the scores with the mask added
        return scores + mask
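The masking logic above can also be exercised on its own, without a model. A minimal sketch with a toy constraint function and made-up token ids, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import PrefixConstrainedLogitsProcessor

def allowed_tokens(batch_id, sent):
    # Toy constraint: after token 3, only tokens 4 or 5 may follow; otherwise anything goes.
    if sent[-1] == 3:
        return [4, 5]
    return list(range(8))

processor = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn=allowed_tokens, num_beams=1)

input_ids = torch.tensor([[1, 3]])
scores = torch.zeros((1, 8))
out = processor(input_ids, scores)
print(out[0, 4], out[0, 2])  # tensor(0.) tensor(-inf)
```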
# A [`LogitsProcessor`] subclass that implements diverse (grouped) beam search
class HammingDiversityLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces diverse beam search.

    Note that this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. See [Diverse Beam
    Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more
    details.

    Traditional beam search often generates very similar sequences across different beams.
    `HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
    beams in the same time step.

    Args:
        diversity_penalty (`float`):
            This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
            particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting
            this value can help strike a balance between diversity and natural likelihood.
        num_beams (`int`):
            Number of beams for beam search. 1 means no beam search.
        num_beam_groups (`int`):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    >>> import torch

    >>> # Initialize the model and tokenizer
    >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

    >>> # A long text about the solar system
    >>> text = (
    ...     "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
    ...     "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
    ...     "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
    ...     "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
    ...     "interstellar molecular cloud."
    ... )
    >>> inputs = tokenizer("summarize: " + text, return_tensors="pt")

    >>> # Generate diverse summary
    >>> outputs_diverse = model.generate(
    ...     **inputs,
    ...     num_beam_groups=2,
    ...     diversity_penalty=10.0,
    ...     max_length=100,
    ...     num_beams=4,
    ...     num_return_sequences=2,
    ... )
    >>> summaries_diverse = tokenizer.batch_decode(outputs_diverse, skip_special_tokens=True)

    >>> # Generate non-diverse summary
    >>> outputs_non_diverse = model.generate(
    ...     **inputs,
    ...     max_length=100,
    ...     num_beams=4,
    ...     num_return_sequences=2,
    ... )
    >>> summary_non_diverse = tokenizer.batch_decode(outputs_non_diverse, skip_special_tokens=True)
    ```
    """

    # Constructor: sets the diversity penalty, the number of beams and the number of beam groups
    def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
        # `diversity_penalty` has to be a float strictly greater than 0
        if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
            raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
        self._diversity_penalty = diversity_penalty  # store the diversity penalty

        # `num_beams` has to be an integer strictly greater than 1
        if not isinstance(num_beams, int) or num_beams < 2:
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        self._num_beams = num_beams  # store the number of beams

        # `num_beam_groups` has to be an integer strictly greater than 1 and no larger than `num_beams`
        if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
            raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
        if num_beam_groups > num_beams:
            raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
        self._num_sub_beams = num_beams // num_beam_groups  # number of sub-beams per beam group

    # Called at every step of group beam search to penalise tokens already picked by previous groups
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        current_tokens: torch.LongTensor,
        beam_group_idx: int,
    ) -> torch.FloatTensor:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            current_tokens (`torch.LongTensor` of shape `(batch_size)`):
                Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
                beam groups in the current generation step.
            beam_group_idx (`int`):
                The index of the beam group currently being processed.

        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
                The processed prediction scores.
        """
        # hamming diversity: penalise using same token in current group which was used in previous groups at
        # the same time step
        batch_size = current_tokens.shape[0] // self._num_beams  # number of batch elements
        group_start_idx = beam_group_idx * self._num_sub_beams  # first beam index of the current group
        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)  # last beam index (exclusive), capped at the total number of beams
        group_size = group_end_idx - group_start_idx  # number of beams in the current group
        vocab_size = scores.shape[-1]  # vocabulary size

        if group_start_idx == 0:
            return scores  # the first group is not penalised, return the scores unchanged

        for batch_idx in range(batch_size):
            # predicted tokens of last time step of previous groups
            previous_group_tokens = current_tokens[
                batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
            ]  # tokens selected by the previous groups at the current time step

            token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
            # frequency of each token among the previous groups, moved to the same device as `scores`

            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
            # subtract the diversity penalty, scaled by the token frequency, from the current group's scores

        return scores
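The penalty applied above is simply `diversity_penalty` times how often each token was already chosen by the earlier groups at this step. A minimal numeric sketch with made-up sizes, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import HammingDiversityLogitsProcessor

# 4 beams split into 2 groups of 2, one batch element (all numbers made up).
processor = HammingDiversityLogitsProcessor(diversity_penalty=1.5, num_beams=4, num_beam_groups=2)

input_ids = torch.zeros((4, 3), dtype=torch.long)  # unused by the penalty itself
scores = torch.zeros((2, 6))                       # scores of the 2 beams in the current group
current_tokens = torch.tensor([5, 5, 0, 0])        # the first group already picked token 5 twice

# Processing the second group (beam_group_idx=1): token 5 is penalised by 1.5 * 2
out = processor(input_ids, scores, current_tokens=current_tokens, beam_group_idx=1)
print(out[:, 5])  # tensor([-3., -3.])
```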
class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces the specified token as the first generated token. Used with encoder-decoder
    models.

    Args:
        bos_token_id (`int`):
            The id of the token to force as the first generated token.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    >>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

    >>> inputs = tokenizer("Translate from English to German: I love cats.", return_tensors="pt")

    >>> # By default, it continues generating according to the model's logits
    >>> outputs = model.generate(**inputs, max_new_tokens=10)
    >>> print(tokenizer.batch_decode(outputs)[0])
    <pad> Ich liebe Kitty.</s>

    >>> # We can use `forced_bos_token_id` to force the start of generation with an encoder-decoder model
    >>> # (including forcing it to end straight away with an EOS token)
    >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_bos_token_id=tokenizer.eos_token_id)
    >>> print(tokenizer.batch_decode(outputs)[0])
    <pad></s>
    ```
    """

    def __init__(self, bos_token_id: int):
        # Constructor: stores the id of the token to force as the first generated token
        self.bos_token_id = bos_token_id

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Current length of the generated sequence
        cur_len = input_ids.shape[-1]
        # A length of 1 means we are about to generate the first token
        if cur_len == 1:
            # Number of candidate tokens in the logits
            num_tokens = scores.shape[1]
            # Set the logits of every token other than the forced BOS token to -inf so they cannot be generated
            scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
            # Set the forced BOS token's logit to 0 so it is the only remaining candidate
            scores[:, self.bos_token_id] = 0
        return scores
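A quick way to see the forcing behaviour above is to call the processor on random logits right after the first decoder token. A minimal sketch with made-up ids, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import ForcedBOSTokenLogitsProcessor

processor = ForcedBOSTokenLogitsProcessor(bos_token_id=2)

input_ids = torch.tensor([[0]])  # exactly one (decoder start) token has been produced
scores = torch.randn(1, 5)
out = processor(input_ids, scores)
print(out)  # only column 2 is finite (0.0); every other column is -inf
```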


class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

    Args:
        max_length (`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (`Union[int, List[int]]`):
            The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a
            list to set multiple *end-of-sequence* tokens.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

    >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")

    >>> # By default, it continues generating according to the model's logits
    >>> outputs = model.generate(**inputs, max_new_tokens=10)
    >>> print(tokenizer.batch_decode(outputs)[0])
    A sequence: 1, 2, 3, 4, 5, 6, 7, 8

    >>> # `forced_eos_token_id` ensures the generation ends with an EOS token
    >>> outputs = model.generate(**inputs, max_new_tokens=10, forced_eos_token_id=tokenizer.eos_token_id)
    >>> print(tokenizer.batch_decode(outputs)[0])
    ```
    """

    def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
        # Constructor: stores the maximum length and the id (or ids) of the token(s) to force at the end
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Current length of the generated sequence
        cur_len = input_ids.shape[-1]
        # When the token being generated is the last one that fits in `max_length`, force it to be an EOS token
        if cur_len == self.max_length - 1:
            if isinstance(self.eos_token_id, int):
                # Single EOS token id: set the logits of every other token to -inf
                scores[:, [i for i in range(scores.shape[1]) if i != self.eos_token_id]] = -float("inf")
                # Set the EOS logit to 0 so it is the only remaining candidate
                scores[:, self.eos_token_id] = 0
            else:
                # Multiple EOS token ids: set the logits of every non-EOS token to -inf
                for eos_id in self.eos_token_id:
                    scores[:, [i for i in range(scores.shape[1]) if i != eos_id]] = -float("inf")
                # Set the logits of all EOS tokens to 0 so that any of them can be generated
                for eos_id in self.eos_token_id:
                    scores[:, eos_id] = 0
        return scores
class InfNanRemoveLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using
    the logits processor should only be used if necessary since it can slow down the generation method.

    This logits processor has no `generate` example, as there shouldn't be a correct combination of flags that warrants
    its use.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # set all nan values to 0.0
        scores[scores != scores] = 0.0  # replace every NaN with 0.0

        # set all +/-inf values to max/min possible value
        scores[scores == float("inf")] = torch.finfo(scores.dtype).max  # +inf -> largest finite value of the dtype
        scores[scores == float("-inf")] = torch.finfo(scores.dtype).min  # -inf -> smallest finite value of the dtype

        return scores
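The clean-up above is easy to verify on a hand-built tensor. A minimal sketch, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import InfNanRemoveLogitsProcessor

processor = InfNanRemoveLogitsProcessor()
scores = torch.tensor([[1.0, float("nan"), float("inf"), float("-inf")]])
out = processor(torch.tensor([[0]]), scores)
print(out)  # roughly tensor([[1.0000e+00, 0.0000e+00, 3.4028e+38, -3.4028e+38]]) for float32
```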
    """
    # Constructor of the exponential decay length penalty processor: computes the start point and decay factor of the
    # length adjustment.

    def __init__(
        self,
        exponential_decay_length_penalty: Tuple[int, float],  # tuple with the start length and the decay factor
        eos_token_id: Union[int, List[int]],  # id of the end-of-sequence token, a single int or a list of ints
        input_ids_seq_length: int,  # length of the input sequence
    ):
        # Compute the start of the regulation, taking the input sequence length into account
        self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
        # Store the decay factor
        self.regulation_factor = exponential_decay_length_penalty[1]
        # Convert a single EOS token id into a list
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        # Store the EOS token id(s)
        self.eos_token_id = eos_token_id

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Current length of the generated sequence
        cur_len = input_ids.shape[-1]
        # Only act once the current length exceeds the regulation start
        if cur_len > self.regulation_start:
            # Apply the boost to every EOS token id
            for i in self.eos_token_id:
                # Number of tokens generated past the regulation start
                penalty_idx = cur_len - self.regulation_start
                # To support negative logits, the penalty is computed on the absolute value and added to the raw score
                scores[:, i] = scores[:, i] + torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1)
        # Return the adjusted scores
        return scores
    """
class LogitNormalization(LogitsProcessor, LogitsWarper):
    r"""
    [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
    the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
    this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
    the scores are normalized when comparing the hypotheses.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
    >>> import torch

    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

    >>> inputs = tokenizer("A sequence: 1, 2, 3", return_tensors="pt")

    >>> # By default, the scores are not normalized -- the sum of their exponentials is NOT a normalized probability
    >>> # distribution, summing to 1
    >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
    >>> print(torch.sum(torch.exp(outputs.scores[-1])))
    tensor(816.3250)

    >>> # Normalizing them may have a positive impact on beam methods, or when using the scores on your application
    >>> outputs = model.generate(**inputs, renormalize_logits=True, return_dict_in_generate=True, output_scores=True)
    >>> print(torch.sum(torch.exp(outputs.scores[-1])))
    tensor(1.0000)
    ```
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    # Normalizes the scores by applying log-softmax over the last dimension
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Apply log-softmax so that the scores form a normalized log-probability distribution
        scores = scores.log_softmax(dim=-1)
        # Return the normalized scores
        return scores


class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
    r"""
    [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
    generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
    not generated at the beginning. Originally created for
    [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

    Examples:

    ```
    >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")

    >>> # Whisper has `begin_suppress_tokens` set by default (= `[220, 50256]`). 50256 is the EOS token, so this means
    >>> # it can't generate an EOS token in the first iteration, but it can in the others.
    >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
    >>> print(outputs.scores[1][0, 50256])  # 1 (and not 0) is the first freely generated token
    tensor(-inf)
    >>> print(outputs.scores[-1][0, 50256])  # in other places we can see some probability mass for EOS
    tensor(29.9010)

    >>> # If we disable `begin_suppress_tokens`, we can generate EOS in the first iteration.
    >>> outputs = model.generate(
    ...     **inputs, return_dict_in_generate=True, output_scores=True, begin_suppress_tokens=None
    ... )
    >>> print(outputs.scores[1][0, 50256])
    tensor(11.2027)
    ```

    """
    
    # Constructor: receives `begin_suppress_tokens` (the tokens to suppress at the beginning) and `begin_index`
    def __init__(self, begin_suppress_tokens, begin_index):
        # Store the suppressed tokens as a list
        self.begin_suppress_tokens = list(begin_suppress_tokens)
        # Store the index of the first freely generated token
        self.begin_index = begin_index

    # Updates `self.begin_index`
    def set_begin_index(self, begin_index):
        self.begin_index = begin_index

    # Adds the standard documentation describing the processor's inputs and outputs
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Only act when the current length equals `begin_index`
        if input_ids.shape[1] == self.begin_index:
            # Set the scores of the suppressed tokens to -inf for every row
            scores[:, self.begin_suppress_tokens] = -float("inf")

        # Return the modified scores
        return scores
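The suppression above only fires at the exact step where the sequence length equals `begin_index`. A minimal sketch with made-up ids, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import SuppressTokensAtBeginLogitsProcessor

processor = SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens=[1, 3], begin_index=2)

scores = torch.zeros(1, 6)
out = processor(torch.tensor([[7, 8]]), scores.clone())     # length == begin_index: tokens 1 and 3 are masked
print(out[0, 1], out[0, 3])                                 # tensor(-inf) tensor(-inf)

out = processor(torch.tensor([[7, 8, 9]]), scores.clone())  # past the begin index: nothing is suppressed
print(out[0, 1])                                            # tensor(0.)
```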
class SuppressTokensLogitsProcessor(LogitsProcessor):
    r"""
    This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so
    that they are not generated. Originally created for
    [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

    Examples:

    ```
    >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
    >>> from datasets import load_dataset

    >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")

    >>> # Whisper has a long list of suppressed tokens. For instance, in this case, the token 1 is suppressed by default.
    >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
    >>> print(outputs.scores[1][0, 1])  # 1 (and not 0) is the first freely generated token
    tensor(-inf)

    >>> # If we disable `suppress_tokens`, we can generate it.
    >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, suppress_tokens=None)
    >>> print(outputs.scores[1][0, 1])
    tensor(5.7738)
    ```
    """

    def __init__(self, suppress_tokens):
        # Constructor: receives the list of tokens to suppress
        self.suppress_tokens = list(suppress_tokens)

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Set the scores of the suppressed tokens to -inf so they are never sampled
        scores[:, self.suppress_tokens] = -float("inf")
        # Return the processed scores
        return scores
    # check that every position except index 50362 is -inf
    all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
    True

    >>> # print the score at index 50362 and confirm it is 0
    >>> print(outputs.scores[0][0, 50362])
    tensor(0.)

    >>> # If we disable `forced_decoder_ids`, we stop seeing that effect
    >>> outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, forced_decoder_ids=None)
    >>> # check again that every position except index 50362 is -inf
    >>> print(
    ...     all(outputs.scores[0][0, i] == float("-inf") for i in range(processor.tokenizer.vocab_size) if i != 50362)
    ... )
    False
    >>> # print the score at index 50362 again; it is now 19.3140
    >>> print(outputs.scores[0][0, 50362])
    tensor(19.3140)
    ```

    """

    def __init__(self, force_token_map: List[List[int]], _has_warned: Optional[bool] = False):
        # Build the forced-token map (generation index -> forced token id) and optionally warn about the deprecation
        self.force_token_map = dict(force_token_map)
        if not _has_warned:
            # Warn that this processor is deprecated and will be removed in v4.40
            warnings.warn(
                "This `ForceTokensLogitsProcessor` has been deprecated and will be removed in v4.40. Should you need to provide prompt ids for generation, specify `input_ids` to the generate method for decoder-only models, or `decoder_input_ids` for encoder-decoder models.",
                FutureWarning,
            )

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Process the logits for the current generation step
        generation_idx = input_ids.shape[-1]  # index of the token that is about to be generated
        current_token = self.force_token_map.get(generation_idx, None)  # token forced at this index, if any
        if current_token is not None:
            # When a token is forced, set every score to -inf and the forced token's score to 0
            scores[:, :] = -float("inf")
            scores[:, current_token] = 0
        return scores
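The map lookup above is keyed on the current sequence length, so the effect is easy to check on dummy logits. A minimal sketch with made-up ids, assuming the class is importable from `transformers.generation.logits_process` (passing `_has_warned=True` only to silence the deprecation warning):

```
import torch
from transformers.generation.logits_process import ForceTokensLogitsProcessor

# Force token 9 to be generated at position 2 (ids made up).
processor = ForceTokensLogitsProcessor([[2, 9]], _has_warned=True)

scores = torch.zeros(1, 12)
out = processor(torch.tensor([[5, 6]]), scores)  # current length is 2 -> token 9 is forced
print(out[0, 9], out[0, 0])                      # tensor(0.) tensor(-inf)
```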
class WhisperTimeStampLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input
    tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure
    that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is
    done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted
    probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those
    non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other
    potential tokens.


    See [the paper](https://arxiv.org/abs/2212.04356) for more information.

    Args:
        generate_config (`GenerateConfig`):
            The generate config used to generate the output. The following parameters are required:
                eos_token_id (`int`, *optional*, defaults to 50257):
                    The id of the *end-of-sequence* token.
                no_timestamps_token_id (`int`, *optional*, defaults to 50363):
                    The id of the `"<|notimestamps|>"` token.
                max_initial_timestamp_index (`int`, *optional*, defaults to 1):
                    Used to set the maximum value of the initial timestamp. This is used to prevent the model from
                    predicting timestamps that are too far in the future.
        begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
        _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.

    Examples:
    ``` python
    >>> import torch
    >>> from transformers import AutoProcessor, WhisperForConditionalGeneration, GenerationConfig
    >>> from datasets import load_dataset

    >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
    >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt")
    >>> input_features = inputs.input_features

    >>> #Displaying timestamps
    >>> generated_ids = model.generate(inputs=input_features, return_timestamps=True)
    >>> transcription = processor.batch_decode(generated_ids, decode_with_timestamps=True)[0]
    >>> print("Transcription:", transcription)
    Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|>


    >>> #No timestamps & change EOS:
    ```
    """
    # Constructor: receives the generate config, an optional begin index and a flag for timestamp detection
    def __init__(
        self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
    ):  # support for the kwargs
        # id of the special token that disables timestamps
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        # id of the first timestamp token
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1
        # id of the token that ends generation, taken from the EOS or, failing that, the BOS token of the config
        self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

        # variable used for testing: controls whether timestamps are detected from the log-probabilities
        self._detect_timestamp_from_logprob = (
            _detect_timestamp_from_logprob
            if _detect_timestamp_from_logprob is not None
            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
        )

        # compute the begin index, accounting for the number of forced decoder ids
        num_forced_ids = (
            len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
        )
        self.begin_index = begin_index or (num_forced_ids + 1)

        # maximum index of the initial timestamp, taken from the generate config (defaults to None)
        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
        # TODO(Patrick): make sure official models have max_initial_timestamp_index set to 50
        # self.max_initial_timestamp_index = 50

    # Updates the begin index
    def set_begin_index(self, begin_index):
        self.begin_index = begin_index

    # Adds the standard documentation describing the processor's inputs
    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # suppress <|notimestamps|> which is handled by without_timestamps
        scores[:, self.no_timestamps_token_id] = -float("inf")

        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
        for k in range(input_ids.shape[0]):
            sampled_tokens = input_ids[k, self.begin_index :]
            seq = list(sampled_tokens.tolist())

            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:
                    # has to be non-timestamp
                    scores[k, self.timestamp_begin :] = -float("inf")
                else:
                    # cannot be normal text tokens
                    scores[k, : self.eos_token_id] = -float("inf")

            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
            if timestamps.numel() > 0:
                # `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    # Avoid to emit <|0.00|> again
                    timestamp_last = timestamps[-1] + 1

                scores[k, self.timestamp_begin : timestamp_last] = -float("inf")

        # apply the `max_initial_timestamp` option
        if input_ids.shape[1] == self.begin_index:
            scores[:, : self.timestamp_begin] = -float("inf")

            if self.max_initial_timestamp_index is not None:
                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
                scores[:, last_allowed + 1 :] = -float("inf")

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = torch.nn.functional.log_softmax(scores.float(), dim=-1)
        for k in range(input_ids.shape[0]):
            timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
                scores[k, : self.timestamp_begin] = -float("inf")

        return scores
class WhisperNoSpeechDetection(LogitsProcessor):
    r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""

    def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
        self.no_speech_token = no_speech_token
        # offset of the <start-of-transcription> token in the original implementation, equal to the position index of
        # the first generated token
        self.start_of_trans_offset = begin_index

        # `self.begin_index` is a running value that is updated on the fly
        self.begin_index = begin_index
        self._no_speech_prob = [0.0]
        self.is_scores_logprobs = scores_is_logprobs

        # attributes that are overwritten dynamically
        self.model = None
        self.inputs = None

    def set_model(self, model):
        self.model = model

    def set_inputs(self, inputs):
        # prepare the inputs for generation and merge them with the original inputs
        self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
        self.inputs["input_features"] = self.inputs.pop("inputs")

    @property
    def no_speech_prob(self):
        return self._no_speech_prob

    def set_begin_index(self, begin_index):
        self.begin_index = begin_index

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.shape[1] == self.begin_index:
            if self.start_of_trans_offset > 1:
                with torch.no_grad():
                    logits = self.model(**self.inputs).logits

                no_speech_index = self.begin_index - self.start_of_trans_offset
                no_speech_scores = logits[:, no_speech_index]
            else:
                no_speech_scores = scores

            if self.is_scores_logprobs:
                probs = no_speech_scores.exp()
            else:
                probs = no_speech_scores.float().softmax(dim=-1)

            self._no_speech_prob = probs[:, self.no_speech_token]

        return scores


class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
    where the first half correspond to the conditional logits (predicted from the input prompt) and the second half
    correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
    weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`.

    See [the paper](https://arxiv.org/abs/2306.05284) for more information.

    <Tip warning={true}>

    This logits processor is exclusively compatible with
    [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen)

    </Tip>
    """

    def __init__(self, guidance_scale):
        # Constructor: `guidance_scale` sets the strength of classifier free guidance (CFG). CFG is enabled by setting
        # `guidance_scale > 1`. Higher values encourage the model to generate samples more closely linked to the input
        # prompt, usually at the expense of poorer quality.
        if guidance_scale > 1:
            # Valid guidance scale: store it on the instance
            self.guidance_scale = guidance_scale
        else:
            # Otherwise raise a ValueError explaining that the processor requires a guidance scale greater than 1
            raise ValueError(
                "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
                f"{guidance_scale}."
            )

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Simple check to make sure the logits (conditional and unconditional) have a batch size compatible with the
        # input ids (conditional only)
        if scores.shape[0] != 2 * input_ids.shape[0]:
            # The logits batch has to be exactly twice the input ids batch
            raise ValueError(
                f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
                f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
                f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
            )
        # Batch size of the unguided (conditional-only) half
        unguided_bsz = scores.shape[0] // 2
        # Split the scores into conditional and unconditional logits
        cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
        # Apply the guidance scale to push the scores towards the conditional logits
        scores = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
        # Return the processed scores
        return scores
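The weighted average above can be checked by hand with a two-row score tensor (conditional row first, unconditional row second). A minimal sketch with made-up logits, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import ClassifierFreeGuidanceLogitsProcessor

processor = ClassifierFreeGuidanceLogitsProcessor(guidance_scale=3.0)

input_ids = torch.tensor([[1, 2]])         # conditional batch of size 1
cond = torch.tensor([[2.0, 0.0]])          # conditional logits
uncond = torch.tensor([[1.0, 0.0]])        # unconditional logits
scores = torch.cat([cond, uncond], dim=0)  # first half conditional, second half unconditional

out = processor(input_ids, scores)
print(out)  # tensor([[4., 0.]]) -> uncond + (cond - uncond) * 3
```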
class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing alternated generation between the two codebooks of Bark.

    <Tip warning={true}>
    
    This logits processor is exclusively compatible with
    [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark)'s fine submodel. See the model documentation
    for examples.
    
    </Tip>

    Args:
        input_start_len (`int`):
            The length of the initial input sequence.
        semantic_vocab_size (`int`):
            Vocabulary size of the semantic part, i.e number of tokens associated to the semantic vocabulary.
        codebook_size (`int`):
            Number of tokens associated to the codebook.
    """

    def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size: int):
        if not isinstance(input_start_len, int) or input_start_len < 0:
            raise ValueError(f"`input_starting_length` has to be a non-negative integer, but is {input_start_len}")

        # Store the initial input length, the semantic vocabulary size and the codebook size
        self.input_start_len = input_start_len
        self.semantic_vocab_size = semantic_vocab_size
        self.codebook_size = codebook_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Current length of the input sequence
        curr_len = input_ids.shape[-1]

        # The parity of the number of generated tokens decides the codebook: even offsets use the first codebook,
        # odd offsets use the second one
        is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0

        if is_first_codebook:
            # First codebook step: mask the semantic tokens and the second codebook, leaving only the first codebook
            scores[:, : self.semantic_vocab_size] = -float("inf")
            scores[:, self.semantic_vocab_size + self.codebook_size :] = -float("inf")
        else:
            # Second codebook step: mask the semantic tokens and the first codebook, leaving only the second codebook
            scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")

        # Return the processed scores
        return scores
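The alternation above can be seen on a toy vocabulary that stacks a few semantic tokens and two small codebooks. A minimal sketch with made-up sizes, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import AlternatingCodebooksLogitsProcessor

# Toy layout: 4 semantic tokens (ids 0-3), then two codebooks of 3 tokens each (ids 4-6 and 7-9).
processor = AlternatingCodebooksLogitsProcessor(input_start_len=2, semantic_vocab_size=4, codebook_size=3)

scores = torch.zeros(1, 10)

# Current length 2 -> (2 - 2) % 2 == 0 -> only the first codebook (ids 4-6) stays available.
out = processor(torch.tensor([[0, 1]]), scores.clone())
print(out[0])  # ids 0-3 and 7-9 are -inf, ids 4-6 are 0

# Current length 3 -> only the second codebook (ids 7-9) stays available.
out = processor(torch.tensor([[0, 1, 4]]), scores.clone())
print(out[0])  # ids 0-6 are -inf, ids 7-9 are 0
```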


class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
    r"""
    Logits processor for Classifier-Free Guidance (CFG). The processors computes a weighted average across scores
    from prompt conditional and prompt unconditional (or negative) logits, parameterized by the `guidance_scale`.
    The unconditional scores are computed internally by prompting `model` with the `unconditional_ids` branch.

    See [the paper](https://arxiv.org/abs/2306.17806) for more information.

    Args:
        guidance_scale (`float`):
            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`.
            Higher guidance scale encourages the model to generate samples that are more closely linked to the input
            prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while
            making the negative prompt provided (if any) act as a positive prompt.
        model (`PreTrainedModel`):
            The model computing the unconditional scores. Supposedly the same as the one computing the conditional
            scores. Both models must use the same tokenizer.
        unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
            the last token of the prompt.
        unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Attention mask for unconditional_ids.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to cache key/values during the negative prompt forward pass.

    Examples:

    ```
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    'Today, a dragon flew over Paris, France, killing at least 50 people and injuring more than 100'

    >>> # with a negative prompt
    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    'Today, a dragon flew over Paris, France, killing at least 130 people. French media reported that'

    >>> # with a positive prompt
    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"])
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    "Today, a dragon flew over Paris, France, and I'm very happy to be here. I"
    ```
    """

    def __init__(
        self,
        guidance_scale,
        model,
        unconditional_ids=None,
        unconditional_attention_mask=None,
        use_cache=True,
    ):
        self.guidance_scale = guidance_scale
        self.model = model
        self.unconditional_context = {
            "input_ids": unconditional_ids,
            "attention_mask": unconditional_attention_mask,
            "use_cache": use_cache,
            "past_key_values": None,
            "first_pass": True,
        }



    def get_unconditional_logits(self, input_ids):
        if self.unconditional_context["first_pass"]:
            if self.unconditional_context["input_ids"] is None:
                self.unconditional_context["input_ids"] = input_ids[:, -1:]
            if self.unconditional_context["attention_mask"] is None:
                self.unconditional_context["attention_mask"] = torch.ones_like(
                    self.unconditional_context["input_ids"], dtype=torch.long
                )
            input_ids = self.unconditional_context["input_ids"]
            attention_mask = self.unconditional_context["attention_mask"]
            self.unconditional_context["first_pass"] = False
        else:
            attention_mask = torch.cat(
                [
                    self.unconditional_context["attention_mask"],
                    torch.ones_like(input_ids[:, -1:], dtype=torch.long),
                ],
                dim=1,
            )
            if not self.unconditional_context["use_cache"]:
                input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
            else:
                input_ids = input_ids[:, -1:]
            self.unconditional_context["input_ids"] = input_ids
            self.unconditional_context["attention_mask"] = attention_mask



        # Run the model on the unconditional branch, passing the input ids, attention mask, cache flag and past key/values
        out = self.model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=self.unconditional_context["use_cache"],
            past_key_values=self.unconditional_context["past_key_values"],
        )
        self.unconditional_context["past_key_values"] = out.get("past_key_values", None)

        return out.logits


    def __call__(self, input_ids, scores):
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        if self.guidance_scale == 1:
            return scores

        logits = self.get_unconditional_logits(input_ids)

        # log-softmax over the unconditional logits at the last position
        unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
        # blend the conditional and unconditional log-probabilities according to the guidance scale
        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
        return out
class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
    r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.

    <Tip warning={true}>

    This logits processor is exclusively compatible with
    [Bark](https://huggingface.co/docs/transformers/en/model_doc/bark). See the model documentation for examples.

    </Tip>

    Args:
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
        min_eos_p (`float`, *optional*):
            Minimum end of speech threshold.
    """

    def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float):
        # Convert eos_token_id to a list if it's provided as an integer
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id = eos_token_id
        # Validate min_eos_p is a positive float if provided
        if min_eos_p is not None and min_eos_p <= 0:
            raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
        self.min_eos_p = min_eos_p

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Check if min_eos_p is set
        if self.min_eos_p:
            # Compute softmax probabilities across the last dimension of scores tensor
            probs = torch.nn.functional.softmax(scores.float(), dim=-1)
            # Initialize a tensor with -inf values except for the eos_token_id
            early_stop_scores = torch.ones_like(scores) * -float("inf")
            early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
            
            # Determine if any EOS token's probability exceeds min_eos_p
            do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
            do_early_stop = torch.any(do_early_stop, dim=1, keepdim=True)
            # Conditionally replace scores with early_stop_scores where needed
            scores = torch.where(do_early_stop, early_stop_scores, scores)

        return scores
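The early-stopping rule above replaces a row's scores with an EOS-only distribution whenever the softmax probability of any EOS token exceeds `min_eos_p`. A minimal sketch with made-up ids, assuming the class is importable from `transformers.generation.logits_process`:

```
import torch
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

processor = BarkEosPrioritizerLogitsProcessor(eos_token_id=3, min_eos_p=0.2)

# Row 0: EOS probability is high -> every non-EOS logit is replaced by -inf.
# Row 1: EOS probability stays low -> the scores pass through unchanged.
scores = torch.tensor([[0.0, 0.0, 0.0, 5.0],
                       [5.0, 0.0, 0.0, 0.0]])
input_ids = torch.zeros((2, 1), dtype=torch.long)

out = processor(input_ids, scores)
print(out[0])  # tensor([-inf, -inf, -inf, 5.])
print(out[1])  # tensor([5., 0., 0., 0.])
```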