Transformers-源码解析-五-

Transformers 源码解析(五)

.\generation\stopping_criteria.py

# 导入时间模块,用于处理时间相关功能
import time
# 导入警告模块,用于发出警告信息
import warnings
# 导入抽象基类模块,用于定义抽象类
from abc import ABC
# 导入深拷贝函数,用于创建对象的深层副本
from copy import deepcopy
# 导入类型提示模块,用于指定参数和返回值的类型
from typing import Optional

# 导入PyTorch库
import torch

# 从本地utils模块中导入指定函数和类
from ..utils import add_start_docstrings, logging

# 从logging模块中获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 停止条件的文档字符串,使用原始字符串表示(r"..."),包含参数和返回值的描述
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax. If this stopping criteria depends on the `scores` input,
            make sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional stopping criteria specific kwargs.

    Return:
        `torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
            for a particular row, `True` indicates we should continue.

"""


class StoppingCriteria(ABC):
    """Abstract base class for all stopping criteria that can be applied during generation.

    If your stopping criteria depends on the `scores` input, make sure you pass `return_dict_in_generate=True,
    output_scores=True` to `generate`.
    """

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # 抽象方法,子类需实现该方法来定义停止生成的具体逻辑
        raise NotImplementedError("StoppingCriteria needs to be subclassed")


class MaxLengthCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
    in mind for decoder-only type of transformers, this will include the initial prompted tokens.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of tokens.
        max_position_embeddings (`int`, *optional*):
            The maximum model length, as defined by the model's `config.max_position_embeddings` attribute.
    """

    def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
        # 初始化最大长度和最大位置嵌入
        self.max_length = max_length
        self.max_position_embeddings = max_position_embeddings

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    # 定义一个调用函数,用于生成文本序列的逻辑
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # 获取当前输入序列的长度
        cur_len = input_ids.shape[-1]
        # 检查当前序列长度是否已经达到或超过最大生成长度
        is_done = cur_len >= self.max_length
        # 如果模型限制了最大位置嵌入数量且当前长度未达到生成上限,并且当前长度已经超过最大位置嵌入数量,则发出警告
        if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
            logger.warning_once(
                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                "exceptions, performance degradation, or nothing at all."
            )
        # 返回一个布尔张量,表示每个输入序列是否已完成生成
        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)
# 继承自 `StoppingCriteria` 类的子类 `MaxNewTokensCriteria`,用于在生成的标记数超过 `max_new_tokens` 时停止生成。
class MaxNewTokensCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the generated number of tokens exceeds `max_new_tokens`. Keep in
    mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is very
    close to `MaxLengthCriteria` but ignores the number of initial tokens.

    Args:
        start_length (`int`):
            The number of initial tokens.
        max_new_tokens (`int`):
            The maximum number of tokens to generate.
    """

    # 初始化方法,发出警告信息表明该类已被弃用,建议使用 `MaxLengthCriteria` 替代
    def __init__(self, start_length: int, max_new_tokens: int):
        warnings.warn(
            "The class `MaxNewTokensCriteria` is deprecated. "
            f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
            "with `max_length = start_length + max_new_tokens` instead.",
            FutureWarning,
        )
        # 初始化属性,记录初始标记数和允许生成的最大标记数
        self.start_length = start_length
        self.max_new_tokens = max_new_tokens
        self.max_length = start_length + max_new_tokens

    # 调用对象时的方法,检查是否达到生成的最大标记数
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # 判断输入标记的长度是否大于等于设定的最大长度
        is_done = input_ids.shape[-1] >= self.max_length
        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


# 继承自 `StoppingCriteria` 类的子类 `MaxTimeCriteria`,用于在生成时间超过 `max_time` 秒时停止生成。
class MaxTimeCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
    time will start being counted when you initialize this function. You can override this by passing an
    `initial_time`.

    Args:
        max_time (`float`):
            The maximum allowed time in seconds for the generation.
        initial_time (`float`, *optional*, defaults to `time.time()`):
            The start of the generation allowed time.
    """

    # 初始化方法,记录最大允许生成时间和开始计时的时间戳(默认为当前时间)
    def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
        self.max_time = max_time
        self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

    # 调用对象时的方法,检查是否超过了允许的生成时间
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # 计算当前时间与初始时间戳之间的差值,判断是否超过了最大允许时间
        is_done = time.time() - self.initial_timestamp > self.max_time
        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


# 继承自列表的子类 `StoppingCriteriaList`,用于存储多个停止生成的条件,并在任何一个条件满足时停止生成。
class StoppingCriteriaList(list):
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # 初始化一个全为 False 的 torch.BoolTensor,表示生成未完成
        is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
        # 遍历存储的所有停止条件,如果任何一个条件返回 True,则更新 is_done 为 True
        for criteria in self:
            is_done = is_done | criteria(input_ids, scores, **kwargs)
        return is_done
    # 定义一个方法 `max_length`,返回类型是可选的整数(可能为None)
    def max_length(self) -> Optional[int]:
        # 遍历当前对象实例中的每一个停止条件
        for stopping_criterium in self:
            # 如果当前停止条件是 `MaxLengthCriteria` 类型的实例
            if isinstance(stopping_criterium, MaxLengthCriteria):
                # 返回 `MaxLengthCriteria` 实例中定义的最大长度
                return stopping_criterium.max_length
            # 如果当前停止条件是 `MaxNewTokensCriteria` 类型的实例
            elif isinstance(stopping_criterium, MaxNewTokensCriteria):
                # 返回 `MaxNewTokensCriteria` 实例中定义的最大长度
                return stopping_criterium.max_length
        # 如果没有找到符合条件的停止条件,返回 None
        return None
# 定义一个函数,用于验证停止条件列表是否符合规范,并返回更新后的停止条件列表对象
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
    # 获取停止条件列表中的最大长度
    stopping_max_length = stopping_criteria.max_length
    # 深度复制原始的停止条件列表对象,以免修改原始数据
    new_stopping_criteria = deepcopy(stopping_criteria)
    
    # 如果停止条件列表中的最大长度存在,并且与传入的 max_length 参数不相等
    if stopping_max_length is not None and stopping_max_length != max_length:
        # 发出警告,指出设置的停止条件最大长度与传入参数的最大长度不一致
        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
    # 如果停止条件列表中的最大长度不存在
    elif stopping_max_length is None:
        # 向新的停止条件列表中添加一个新的最大长度停止条件对象
        new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
    
    # 返回更新后的停止条件列表对象
    return new_stopping_criteria

.\generation\streamers.py

# 从队列模块导入队列类
from queue import Queue
# 导入类型检查工具,用于类型提示
from typing import TYPE_CHECKING, Optional

# 如果 TYPE_CHECKING 为真,则从 ..models.auto 模块导入 AutoTokenizer 类
if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer

# 基础流生成器的基类,用于所有生成器流类的继承
class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        # 抛出未实现错误,子类需要实现该方法
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        # 抛出未实现错误,子类需要实现该方法
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenized used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        # 初始化方法,接收一个自动标记器实例和可选参数
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # 用于流处理的变量
        self.token_cache = []  # 初始化空的标记缓存列表
        self.print_len = 0  # 初始化打印长度为 0
        self.next_tokens_are_prompt = True  # 初始化下一个标记为提示状态
    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        # 检查输入值的维度和批处理大小是否符合要求
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # 如果设置跳过提示且下一个标记是提示,则跳过处理
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # 将新标记添加到缓存并进行解码
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # 如果文本以换行符结尾,则刷新缓存
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # 如果最后一个标记是CJK字符,则打印这些字符
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # 否则,打印直到最后一个空格字符(简单的启发式方法,避免打印不完整的单词)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        # 调用处理最终文本的回调函数
        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # 如果缓存中还有剩余内容,则刷新缓存
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # 设置下一个标记为提示
        self.next_tokens_are_prompt = True
        # 调用处理最终文本的回调函数,并标志流结束
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        # 将新文本输出到标准输出,如果流结束则打印换行符
        print(text, flush=True, end="" if not stream_end else None)
    # 检查给定的代码点(CP)是否是CJK字符的代码点
    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # 这里定义的“中文字符”是指CJK统一表意字符(Unicode块)中的任何字符:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # 需要注意,尽管名称中包含CJK统一表意字符,但并非所有日文和韩文字符都包含在内。
        # 现代韩文Hangul字母使用了不同的Unicode块,日文的*假名和片假名也是如此。
        # 这些字母用于书写以空格分隔的词语,因此不被特别对待,会像其他语言一样处理。
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)            # CJK统一表意字符(4E00-9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)        # CJK统一表意字符扩展A(3400-4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)      # CJK统一表意字符扩展B(20000-2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)      # CJK统一表意字符扩展C(2A700-2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)      # CJK统一表意字符扩展D(2B740-2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)      # CJK兼容扩展(2B820-2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)        # CJK兼容象形文字(F900-FAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)      # CJK兼容表意文字补充(2F800-2FA1F)
        ):  # 如果CP位于任何上述范围内,则返回True,表示是中文字符
            return True

        # 如果不在以上范围内,则返回False,表示不是中文字符
        return False
class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
    Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        >>> from threading import Thread

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok)

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
    """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        # 调用父类的初始化方法,传递 tokenizer 和 decode_kwargs
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        # 创建一个队列来存储生成的文本
        self.text_queue = Queue()
        # 初始化停止信号为 None
        self.stop_signal = None
        # 设置超时时间
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        # 将新生成的文本放入队列中,如果流结束,则也放入停止信号
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        # 返回迭代器自身
        return self

    def __next__(self):
        # 从队列中获取值,如果是停止信号则抛出 StopIteration 异常,否则返回值
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value

.\generation\tf_logits_process.py

# 导入模块inspect用于检查对象,并从typing导入List和Tuple
import inspect
from typing import List, Tuple

# 导入NumPy和TensorFlow库
import numpy as np
import tensorflow as tf

# 从上级目录的tf_utils模块导入stable_softmax函数
from ..tf_utils import stable_softmax
# 从上级目录的utils模块导入add_start_docstrings函数
from ..utils import add_start_docstrings
# 从utils.logging模块导入get_logger函数
from ..utils.logging import get_logger

# 使用get_logger函数获取当前模块的日志记录器
logger = get_logger(__name__)

# TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING是一个原始字符串,描述了TFLogitsProcessor和TFLogitsWarper类中__call__方法的参数和返回值
TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
            search or log softmax for each vocabulary token when using beam search.
        cur_len (`int`):
            The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
            is the maximum length generate can produce, and we need to know which of its tokens are valid.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional logits processor specific kwargs.

    Return:
        `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""

# TFLogitsProcessor类定义了一个抽象基类,用于在生成过程中应用的所有logit处理器
class TFLogitsProcessor:
    """Abstract base class for all logit processors that can be applied during generation."""

    # 使用add_start_docstrings装饰器,添加了TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING作为文档字符串
    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        """TF method for processing logits."""
        # 抛出未实现错误,提示该类是抽象类,只能由继承它的类调用
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

# TFLogitsWarper类定义了一个抽象基类,用于在生成过程中使用多项式抽样时应用的所有logit包装器
class TFLogitsWarper:
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    # 使用add_start_docstrings装饰器,添加了TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING作为文档字符串
    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        """TF method for warping logits."""
        # 抛出未实现错误,提示该类是抽象类,只能由继承它的类调用
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )
# 定义一个继承自列表的类 `TFLogitsProcessorList`,用于存储一组 `TFLogitsProcessor` 对象,以便后续处理输入张量 `scores`。
# 该类添加了特定的 `__call__` 方法,用于对每个 `TFLogitsProcessor` 对象应用处理。
class TFLogitsProcessorList(list):
    
    # 使用装饰器 `add_start_docstrings` 应用输入参数的文档字符串 `TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING`
    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
        # 遍历列表中的每个处理器 `processor`
        for processor in self:
            # 检索处理器 `processor` 的调用方法的参数列表
            function_args = inspect.signature(processor.__call__).parameters
            # 如果参数个数超过 3
            if len(function_args) > 3:
                # 检查是否传递了所有必需的参数到 `processor` 的调用方法
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                # 调用 `processor` 的方法,并更新 `scores`
                scores = processor(input_ids, scores, cur_len, **kwargs)
            else:
                # 否则,调用 `processor` 的方法,并更新 `scores`
                scores = processor(input_ids, scores, cur_len)
        # 返回处理后的 `scores`
        return scores


# 定义一个继承自 `TFLogitsWarper` 的类 `TFTemperatureLogitsWarper`
# 用于温度调节(指数缩放输出概率分布)的 `TFLogitsWarper`
class TFTemperatureLogitsWarper(TFLogitsWarper):
    
    # 初始化方法,接受一个 `temperature` 参数作为温度值
    def __init__(self, temperature: float):
        # 如果 `temperature` 不是 `float` 类型或者不是严格正数,则抛出 `ValueError`
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
        
        # 将 `temperature` 赋值给实例变量 `self.temperature`
        self.temperature = temperature
    
    # 调用方法,接受 `input_ids`、`scores`、`cur_len` 参数,返回处理后的 `scores`
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 将 `scores` 按 `self.temperature` 进行缩放处理
        scores = scores / self.temperature
        # 返回处理后的 `scores`
        return scores


# 定义一个继承自 `TFLogitsWarper` 的类 `TFTopKLogitsWarper`
# 用于进行 top-k 操作的 `TFLogitsWarper`,即保留概率最高的 `top_k` 个元素
class TFTopKLogitsWarper(TFLogitsWarper):
    
    # 初始化方法,接受 `top_k`、`filter_value`(可选,默认为 `-inf`)、`min_tokens_to_keep`(可选,默认为 `1`)三个参数
    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 如果 `top_k` 不是正整数,则抛出 `ValueError`
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
        
        # 将 `top_k` 和 `min_tokens_to_keep` 中的最大值赋值给实例变量 `self.top_k`
        self.top_k = max(top_k, min_tokens_to_keep)
        # 将 `filter_value` 赋值给实例变量 `self.filter_value`
        self.filter_value = filter_value
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 计算实际需要考虑的top_k值,确保不超过scores张量的最后一个维度的大小
        top_k = min(self.top_k, scores.shape[-1])  # Safety check
        
        # 创建一个布尔遮罩,标记所有概率小于top-k中最后一个概率的token
        indices_to_remove = scores < tf.math.top_k(scores, k=top_k)[0][..., -1:]
        
        # 根据遮罩,将需要移除的token对应的分数替换为过滤值self.filter_value
        next_scores = tf.where(indices_to_remove, self.filter_value, scores)
        
        # 返回更新后的分数张量
        return next_scores
    # `TFLogitsWarper`的子类,执行top-p截断,即限制保留加起来小于等于prob_cut_off的前几个最有可能的token。

    Args:
        top_p (`float`):
            如果设置为小于1的值,则只保留概率相加达到`top_p`或更高的最有可能的token用于生成。
        filter_value (`float`, *optional*, 默认为负无穷):
            所有被过滤的值将被设置为这个浮点数值。
        min_tokens_to_keep (`int`, *optional*, 默认为1):
            不能被过滤掉的最小token数目。
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        # 检查top_p是否为浮点数且在0到1之间
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p`必须是一个大于0且小于1的浮点数,当前值为{top_p}")
        # 检查min_tokens_to_keep是否为正整数
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(f"`min_tokens_to_keep`必须是一个正整数,当前值为{min_tokens_to_keep}")

        # 初始化实例变量
        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 获取前k个最高分数和对应的索引
        topk_scores, topk_indices = tf.math.top_k(scores, scores.shape[-1])

        # 创建与scores相同形状的填充值为filter_value的张量
        mask_scores = tf.fill(scores.shape, self.filter_value)
        # 计算topk_scores的稳定softmax,并累积概率
        cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
        # 创建一个布尔掩码,标记累积概率小于top_p的位置
        score_mask = cumulative_probs < self.top_p

        # 将第一个false替换为true,确保包含大于top_p的token
        score_mask = tf.concat((tf.ones([score_mask.shape[0], 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)

        # 确保保留至少min_tokens_to_keep个token
        score_mask = tf.concat(
            (
                tf.ones([score_mask.shape[0], self.min_tokens_to_keep], dtype=tf.bool),
                score_mask[:, self.min_tokens_to_keep:],
            ),
            axis=-1,
        )

        # 根据掩码将不符合条件的值设为filter_value
        topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)

        # 恢复topk排序的顺序:将原始索引位置重新分散到张量中
        scatter_rows = tf.tile(tf.expand_dims(tf.range(topk_indices.shape[0]), axis=-1), [1, topk_indices.shape[-1]])
        scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
        next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=topk_next_scores.shape)

        return next_scores
    # 定义一个 TFLogitsProcessor 类,用于处理 logits(预测得分),实现通过设置 EOS 概率为 0 来强制最小长度。

    Args:
        min_length (`int`):
            最小长度,低于此长度时,`eos_token_id` 的得分被设置为 `-float("Inf")`。
        eos_token_id (`int`):
            *end-of-sequence*(EOS)标记的 id。
    """

    def __init__(self, min_length: int, eos_token_id: int):
        # 检查并设置 `min_length` 参数,必须为正整数
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` 必须是正整数,但其值为 {min_length}")

        # 检查并设置 `eos_token_id` 参数,必须为正整数
        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` 必须是正整数,但其值为 {eos_token_id}")

        # 初始化对象的属性
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
        # 创建一个掩码,标记出 scores 中等于 eos_token_id 的位置
        eos_token_id_mask = tf.range(scores.shape[-1]) == self.eos_token_id
        # 使用 tf.where 函数将 eos_token_id 的位置对应的 scores 设置为 -inf
        scores = tf.where(eos_token_id_mask, float("-inf"), scores)
        return scores

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 如果当前长度 cur_len 小于 min_length,则应用 eos token 掩码
        scores = tf.cond(
            tf.less(cur_len, self.min_length),
            lambda: self._apply_eos_token_mask(scores),
            lambda: tf.identity(scores),
        )
        return scores
class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: float):
        # 检查 penalty 参数是否为正浮点数,若不是则抛出 ValueError 异常
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
        # 我们希望在 `input_ids` 的位置上填充惩罚值。由于 XLA 不能处理运行时未知的形状,
        # 不能使用 `tf.unique`。因此,当给定行中的同一标记出现多次时,可能会有冗余更新。

        # 收集要应用的惩罚值
        logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
        logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
        logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)

        # 分散惩罚值
        token_penalties = tf.ones(logits.shape)
        batch_size = input_ids.shape[0]
        seq_len = tf.shape(input_ids)[1]  # 序列长度具有动态大小,因此使用动态形状
        indexable_prev_input_ids = tf.concat(
            (
                tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
                tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
            ),
            axis=1,
        )
        token_penalties = tf.tensor_scatter_nd_update(
            token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
        )
        return token_penalties

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 创建分数惩罚
        score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)

        # 将分数乘以相应的惩罚值
        scores = tf.math.multiply(scores, score_penalties)

        return scores


class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
    """
    [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.
    """
    Args:
        bad_words_ids (`List[List[int]]`):
            不允许生成的令牌 ID 列表的列表。为了获取不应出现在生成文本中的词汇的令牌,请确保在初始化分词器时设置 `add_prefix_space=True`,并使用 `tokenizer(bad_words, add_special_tokens=False).input_ids` 来获取这些词汇的令牌 ID 列表。对于某些较慢的分词器,`add_prefix_space` 参数是支持的,因为快速分词器的前缀行为来自于 `pre tokenizers`。详细信息请参阅 [这里](https://huggingface.co/docs/tokenizers/api/pre-tokenizers)。
        eos_token_id (`int`):
            *end-of-sequence*(EOS)令牌的 ID。
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
        # 检查 `bad_words_ids` 是否为列表且非空
        if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` 必须是非空列表,当前为 {bad_words_ids}。")
        # 检查 `bad_words_ids` 中的每个元素是否为列表
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` 必须是列表的列表,当前为 {bad_words_ids}。")
        # 检查 `bad_words_ids` 中的每个元素是否为正整数列表
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"`bad_words_ids` 中的每个列表必须是正整数列表,当前为 {bad_words_ids}。"
            )

        # 存储关于不允许的词汇的信息,使用三个张量:
        # 1. 一个矩形张量,包含禁止序列(用 `-1` 填充),用于完整数据比较
        self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
        # 2. 一个张量,包含每个禁止序列的未填充长度,用于快速长度比较
        bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
        # 检查禁止词汇序列的长度是否为零
        if any(word_len == 0 for word_len in bad_word_seqs_len):
            raise ValueError(f"禁止词汇序列 {bad_words_ids} 不能包含空列表")
        self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
        # 3. 一个张量,包含每个序列的最后一个令牌,便于访问可能被禁止的令牌
        self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])
    def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
        def _tokens_match(bad_word_seq_number):
            def _len_one():
                # 如果坏序列只有一个标记,则始终屏蔽它
                return tf.cond(
                    tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
                    lambda: tf.ones((), dtype=tf.bool),
                    _len_greater_than_cur_len,
                )

            def _len_greater_than_cur_len():
                # 否则,如果坏序列比当前长度长,它们永远不会匹配
                return tf.cond(
                    tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number], tf.shape(row_input_ids)[0]),
                    lambda: tf.zeros((), dtype=tf.bool),
                    _match_found,
                )

            def _match_found():
                # 最后,执行实际的比较。只有在之前的比较没有结果时才能调用(否则会导致索引异常)
                compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
                return tf.cond(
                    tf.math.reduce_all(
                        tf.math.equal(
                            row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
                        )
                    ),
                    lambda: tf.ones((), dtype=tf.bool),
                    lambda: tf.zeros((), dtype=tf.bool),
                )

            match = _len_one()
            return match

        # 将当前行与所有坏词序列进行比较,获取匹配的掩码
        match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
        row_banned_tokens = self.seq_forbidden_tokens[match_mask]
        return row_banned_tokens
        # 定义一个调用函数,接受输入的 `input_ids`(Tensor 类型)、分数 `scores`(Tensor 类型)、当前长度 `cur_len`(整数类型),返回更新后的分数 `scores`(Tensor 类型)
        # 我们希望在分数级别上屏蔽一些被禁止的令牌。由于被禁止的令牌取决于前一个 `input_ids`,它们可能对每一行具有不同的长度,甚至对某些行来说可能为空。
        # 为了保持简单并与 XLA 兼容,我们以逐行的方式进行操作。
        # TODO(Joao):这个函数可能会因为 `cur_len` 的增加而触发 XLA 重追踪。如果这成为频繁的瓶颈,请修复它。(将 `cur_len` 设为一个张量?)
        def _get_row_updated_score(row_inputs: Tuple[tf.Tensor]) -> tf.Tensor:
            # 获取当前行的输入 `row_input_ids` 和分数 `row_score`
            row_input_ids, row_score = row_inputs
            # 计算当前行被禁止的坏令牌列表,基于 `row_input_ids` 的前 `cur_len` 部分
            banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
            # 创建一个布尔类型的张量,表示被禁止的令牌的位置,其形状与 `row_score` 相同
            banned_tokens_mask = tf.scatter_nd(
                indices=tf.expand_dims(banned_tokens, axis=-1),
                updates=tf.ones_like(banned_tokens, dtype=tf.bool),
                shape=row_score.shape,
            )
            # 使用 `-inf` 替换被禁止令牌的位置上的分数,保持其它位置不变
            row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
            return row_score
        
        # 对每一行调用 `_get_row_updated_score` 函数,更新分数 `scores`,并返回更新后的 `scores`
        scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=tf.float32)
        return scores
class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        # 初始化方法,验证并设置 ngram_size 参数
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
        # 计算禁止的 ngram tokens,用于防止 ngram 重复
        # 从 fairseq 中复制用于在 beam search 中实现 no_repeat_ngram
        if cur_len + 1 < self.ngram_size:
            # 如果当前长度加 1 小于 ngram_size,返回空列表表示没有禁止的 token
            return [[] for _ in range(num_hypos)]
        generated_ngrams = [{} for _ in range(num_hypos)]
        prev_input_ids = input_ids[:, :cur_len]
        for idx in range(num_hypos):
            gen_tokens = prev_input_ids[idx].numpy().tolist()
            generated_ngram = generated_ngrams[idx]
            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

        def _get_generated_ngrams(hypo_idx):
            # 在解码下一个 token 前,防止解码已经出现的 ngrams
            start_idx = cur_len + 1 - self.ngram_size
            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
            return generated_ngrams[hypo_idx].get(ngram_idx, [])

        # 返回禁止的 tokens 列表
        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]

        return banned_tokens

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 调用对象时的处理方法,用于处理 logits
        # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
        # https://github.com/huggingface/transformers/pull/16974
        if not tf.executing_eagerly():
            raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")

        batch_size, vocab_size = scores.shape
        # 计算禁止的 ngram tokens
        banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)

        # 创建禁止 tokens 的布尔掩码
        banned_tokens_indices_mask = []
        for banned_tokens_slice in banned_tokens:
            banned_tokens_indices_mask.append(
                [True if token in banned_tokens_slice else False for token in range(vocab_size)]
            )

        # 将禁止的 tokens 对应位置的 logits 设置为负无穷
        scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)

        return scores


class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
    r"""
    # 初始化函数,接受强制作为第一个生成标记的标记 ID
    def __init__(self, bos_token_id: int):
        # 如果 bos_token_id 小于 0,则引发值错误异常
        if bos_token_id < 0:
            raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
        # 将传入的 bos_token_id 分配给实例变量
        self.bos_token_id = bos_token_id

    # 调用函数,处理输入的 token IDs 和对应的分数,根据当前生成的长度 cur_len 进行调整
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 如果当前生成的长度为 1
        if cur_len == 1:
            # 获取批处理大小和标记数
            batch_size, num_tokens = scores.shape
            # 将 bos_token_id 列的分数设为 0
            scores = tf.zeros((batch_size, 1))
            # 如果 bos_token_id 大于 0,将除了第 bos_token_id 列外的分数设置为负无穷
            if self.bos_token_id > 0:
                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.bos_token_id)), scores), axis=-1)
            # 如果 bos_token_id 小于 (num_tokens - 1),将除了第 bos_token_id 列外的分数设置为负无穷
            if self.bos_token_id < (num_tokens - 1):
                scores = tf.concat(
                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.bos_token_id))),
                    axis=-1,
                )
        # 返回调整后的分数张量
        return scores
# 定义一个继承自 `TFLogitsProcessor` 的类,用于在达到 `max_length` 时强制指定的 token 成为生成序列的最后一个 token。
class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

    Args:
        max_length (`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (`int`):
            The id of the token to force as the last generated token when `max_length` is reached.
    """

    # 初始化方法,设置 `max_length` 和 `eos_token_id`
    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        # 如果 `eos_token_id` 小于 0,则抛出错误
        if eos_token_id < 0:
            raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
        self.eos_token_id = eos_token_id

    # 调用方法,根据当前生成的长度 `cur_len` 对 `scores` 进行处理
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 当当前长度 `cur_len` 等于 `max_length - 1` 时
        if cur_len == self.max_length - 1:
            batch_size, num_tokens = scores.shape
            # 将 `scores` 在 `eos_token_id` 列上的值设为 0
            scores = tf.zeros((batch_size, 1))
            # 在除了 `eos_token_id` 外的其他位置上的值设为负无穷
            if self.eos_token_id > 0:
                scores = tf.concat((tf.broadcast_to(-float("inf"), (batch_size, self.eos_token_id)), scores), axis=-1)
            if self.eos_token_id < (num_tokens - 1):
                scores = tf.concat(
                    (scores, tf.broadcast_to(-float("inf"), (batch_size, (num_tokens - 1) - self.eos_token_id))),
                    axis=-1,
                )
        return scores


# 定义一个继承自 `TFLogitsProcessor` 的类,用于在生成序列开始时抑制一组 token 的生成。
class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
    generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not
    sampled at the begining of the generation.
    """

    # 初始化方法,设置 `begin_suppress_tokens` 和 `begin_index`
    def __init__(self, begin_suppress_tokens, begin_index):
        self.begin_suppress_tokens = list(begin_suppress_tokens)
        self.begin_index = begin_index

    # 调用方法,根据当前生成的长度 `cur_len` 对 `scores` 进行处理
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 当当前长度 `cur_len` 等于 `begin_index` 时
        scores = tf.cond(
            tf.equal(cur_len, self.begin_index),
            # 使用 `tf.tensor_scatter_nd_update` 将 `scores` 中指定位置的值更新为负无穷
            lambda: tf.tensor_scatter_nd_update(
                scores,
                indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
                updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
            ),
            lambda: scores,  # 如果条件不满足,返回原始的 `scores`
        )
        return scores


# 定义一个继承自 `TFLogitsProcessor` 的类,用于抑制一组 token 的生成。
class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
    r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
    are not sampled."""

    # 初始化方法,设置 `suppress_tokens`
    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)
    # 定义一个方法 __call__,该方法接受三个参数:input_ids 是 tf.Tensor 类型,scores 是 tf.Tensor 类型,cur_len 是整数类型
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 使用 tf.tensor_scatter_nd_update 函数更新 scores 张量
        scores = tf.tensor_scatter_nd_update(
            # 更新的目标张量是 scores
            scores,
            # 更新操作的索引是一个列表推导式,生成所有 (i, token) 对的索引
            indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
            # 更新操作的值是一个列表推导式,生成所有需要更新的 -inf 值
            updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
        )
        # 返回更新后的 scores 张量
        return scores
class TFForceTokensLogitsProcessor(TFLogitsProcessor):
    r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
    indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
    `-inf` so that they are sampled at their corresponding index."""

    def __init__(self, force_token_map: List[List[int]]):
        # 将输入的强制 token 映射列表转换为字典形式,格式为 {index: token}
        force_token_map = dict(force_token_map)
        
        # 创建一个数组 force_token_array,其长度为 force_token_map 中最大的索引加一,
        # 初始化所有元素为 -1,用于表示未被强制的 token
        force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
        
        # 遍历 force_token_map,将指定索引位置的 token 值存入 force_token_array
        for index, token in force_token_map.items():
            if token is not None:
                force_token_array[index] = token
        
        # 将 force_token_array 转换为 TensorFlow 张量,并存储在实例变量 self.force_token_array 中
        self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
        # 定义内部函数 _force_token,用于处理强制 token 的逻辑
        def _force_token(generation_idx):
            batch_size = scores.shape[0]
            current_token = self.force_token_array[generation_idx]

            # 创建一个新的得分张量 new_scores,初始化为 -inf
            new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
            
            # 创建索引张量 indices,用于更新 new_scores 中的特定位置为 0
            indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
            updates = tf.zeros((batch_size,), dtype=scores.dtype)
            new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
            
            return new_scores
        
        # 根据当前序列长度 cur_len 和 force_token_array 的长度,决定是否对 scores 进行处理
        scores = tf.cond(
            tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
            # 如果当前长度大于等于 force_token_array 的长度,不进行处理,直接返回 scores
            lambda: tf.identity(scores),
            # 否则,根据 force_token_array 中对应位置的值决定是否强制 token
            lambda: tf.cond(
                tf.greater_equal(self.force_token_array[cur_len], 0),
                # 如果 force_token_array[cur_len] 大于等于 0,调用 _force_token 强制 token
                lambda: _force_token(cur_len),
                # 否则,不进行处理,直接返回 scores
                lambda: scores,
            ),
        )
        
        return scores

.\generation\tf_utils.py

# 导入所需的模块和库
import copy  # 导入 copy 模块,用于复制对象
import inspect  # 导入 inspect 模块,用于获取对象信息
import warnings  # 导入 warnings 模块,用于处理警告
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Any, Dict, Optional, Tuple, Union  # 导入类型提示模块

import numpy as np  # 导入 NumPy 库,并使用别名 np
import tensorflow as tf  # 导入 TensorFlow 库,并使用别名 tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice  # 导入特定函数

# 从相对路径中导入模型输出类
from ..modeling_tf_outputs import TFCausalLMOutputWithPast, TFSeq2SeqLMOutput
# 从相对路径中导入自动模型映射字典
from ..models.auto import (
    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
)
# 从相对路径中导入 TensorFlow 工具函数和稳定 softmax 函数
from ..tf_utils import shape_list, stable_softmax
# 从相对路径中导入模型输出类和日志记录函数
from ..utils import ModelOutput, logging
# 从相对路径中导入生成配置类
from .configuration_utils import GenerationConfig
# 从相对路径中导入 TensorFlow 日志处理相关模块
from .tf_logits_process import (
    TFForcedBOSTokenLogitsProcessor,
    TFForcedEOSTokenLogitsProcessor,
    TFForceTokensLogitsProcessor,
    TFLogitsProcessorList,
    TFMinLengthLogitsProcessor,
    TFNoBadWordsLogitsProcessor,
    TFNoRepeatNGramLogitsProcessor,
    TFRepetitionPenaltyLogitsProcessor,
    TFSuppressTokensAtBeginLogitsProcessor,
    TFSuppressTokensLogitsProcessor,
    TFTemperatureLogitsWarper,
    TFTopKLogitsWarper,
    TFTopPLogitsWarper,
)

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义 TFGreedySearchDecoderOnlyOutput 类,继承自 ModelOutput 基类
@dataclass
class TFGreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.
    """
    pass  # 类定义结束
    # 参数列表:
    # sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
    #     生成的序列。第二维 (sequence_length) 要么等于 `max_length`,要么在所有批次由于 `eos_token_id` 而提前结束时要短。
    # scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
    #     语言建模头部的处理预测分数(SoftMax 之前的每个词汇标记的分数)在每个生成步骤。元组的 `tf.Tensor`,最多包含 `max_new_tokens` 个元素(每个生成的标记一个元素),每个张量的形状为 `(batch_size, config.vocab_size)`。
    # attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
    #     元组(每个生成的标记一个元素)的元组(解码器每一层一个元素)的 `tf.Tensor`,形状为 `(batch_size, num_heads, generated_length, sequence_length)`。
    # hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
    #     元组(每个生成的标记一个元素)的元组(解码器每一层一个元素)的 `tf.Tensor`,形状为 `(batch_size, generated_length, hidden_size)`。

    sequences: tf.Tensor = None  # 初始化 sequences 变量为 None
    scores: Optional[Tuple[tf.Tensor]] = None  # 初始化 scores 变量为 None,类型为可选的元组
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None  # 初始化 attentions 变量为 None,类型为可选的元组的元组
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None  # 初始化 hidden_states 变量为 None,类型为可选的元组的元组
@dataclass
class TFGreedySearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
    """

    sequences: tf.Tensor = None
    # 生成的序列,形状为(batch_size, sequence_length),第二个维度(sequence_length)要么等于max_length,要么因为eos_token_id提前结束而较短
    scores: Optional[Tuple[tf.Tensor]] = None
    # 可选项,当传入output_scores=True或config.output_scores=True时返回,是语言建模头部的处理过的预测分数(SoftMax之前每个词汇标记的分数),每个生成步骤可能有多达max_new_tokens个元素,每个张量形状为(batch_size, config.vocab_size)
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    # 可选项,当传入output_attentions=True或config.output_attentions=True时返回,元组的每个元素对应解码器每层的注意力张量,形状为(batch_size, num_heads, sequence_length, sequence_length)
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    # 可选项,当传入output_hidden_states=True或config.output_hidden_states=True时返回,元组的每个元素对应嵌入层和每层输出的隐藏状态张量,形状为(batch_size, sequence_length, hidden_size)
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    # 可选项,当传入output_attentions=True或config.output_attentions=True时返回,元组的每个元素对应每个生成的标记,每层解码器的注意力张量元组,形状为(batch_size, num_heads, generated_length, sequence_length)
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    # 可选项,当传入output_attentions=True或config.output_attentions=True时返回,元组的每个元素对应每个生成的标记,每层解码器的交叉注意力张量元组,形状为(batch_size, num_heads, generated_length, sequence_length)
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
    # 可选项,当传入output_hidden_states=True或config.output_hidden_states=True时返回,元组的每个元素对应每个生成的标记,每层解码器的隐藏状态张量元组,形状为(batch_size, generated_length, hidden_size)
    # 定义一个可选的变量,用于存储编码器的隐藏状态(Tensor 的元组)。初始值为 None。
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    
    # 定义一个可选的变量,用于存储解码器注意力权重(Tensor 元组的元组)。初始值为 None。
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    
    # 定义一个可选的变量,用于存储交叉注意力权重(Tensor 元组的元组)。初始值为 None。
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    
    # 定义一个可选的变量,用于存储解码器的隐藏状态(Tensor 元组的元组)。初始值为 None。
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFSampleDecoderOnlyOutput(ModelOutput):
    """
    Decoder-only生成模型使用采样生成的输出的基类。

    Args:
        sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            生成的序列。第二个维度(sequence_length)要么等于`max_length`,要么比`eos_token_id`提前结束。
        scores (`tuple(tf.Tensor)` *optional*, 当传入`output_scores=True`或者`config.output_scores=True`时返回):
            语言建模头部的处理过的预测分数(SoftMax之前的每个词汇标记的分数)在每个生成步骤中。
            元组中包含最多`max_new_tokens`个元素(每个生成的词汇标记一个元素),每个张量的形状为`(batch_size*num_return_sequences, config.vocab_size)`。
        attentions (`tuple(tuple(tf.Tensor))`, *optional*, 当传入`output_attentions=True`或者`config.output_attentions=True`时返回):
            每个生成的词汇标记的元组(每个生成的词汇标记一个元素),其中包含解码器每一层的注意力张量。
            注意力张量的形状为`(num_return_sequences*batch_size, num_heads, generated_length, sequence_length)`。
        hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, 当传入`output_hidden_states=True`或者`config.output_hidden_states=True`时返回):
            每个生成的词汇标记的元组(每个生成的词汇标记一个元素),其中包含解码器每一层的隐藏状态张量。
            隐藏状态张量的形状为`(num_return_sequences*batch_size, generated_length, hidden_size)`。
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFSampleEncoderDecoderOutput(ModelOutput):
    """
    Encoder-decoder生成模型使用采样生成的输出的基类。可以通过encoder_attentions和encoder_hidden_states属性(分别通过decoder_attentions和decoder_hidden_states属性)访问解码器(分别是编码器)的隐藏状态和注意力权重。

    """
    # 定义函数的参数和它们的类型注解,这些参数用于接收生成的序列、预测分数、编码器注意力、编码器隐藏状态、
    # 解码器注意力、交叉注意力以及解码器隐藏状态。这些参数都是可选的,根据函数调用时传递的参数决定是否使用。
    Args:
        sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            生成的序列。第二维(sequence_length)要么等于 `max_length`,要么因为 `eos_token_id` 导致所有批次提前结束而更短。
        scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            语言建模头部处理后的预测分数(在SoftMax之前的每个词汇标记的分数),每一代步骤有一个元组,包含最多 `max_new_tokens` 个元素,
            每个张量的形状为 `(batch_size*num_return_sequences, config.vocab_size)`。
        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            编码器注意力的元组(每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, num_heads, sequence_length, sequence_length)`。
        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            编码器隐藏状态的元组(每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, sequence_length, hidden_size)`。
        decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            解码器注意力的元组(每个生成的令牌一个元组,每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, num_heads, generated_length, sequence_length)`。
        cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            交叉注意力的元组(每个生成的令牌一个元组,每个解码器层一个张量),形状为 `(batch_size, num_heads, generated_length, sequence_length)`。
        decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            解码器隐藏状态的元组(每个生成的令牌一个元组,每个解码器层一个张量),形状为 `(batch_size*num_return_sequences, generated_length, hidden_size)`。
    
    # 初始化所有参数为 None,表示这些参数在函数调用时可以不传递,或者传递为 None,函数会根据需要进行处理。
    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
# 使用 dataclass 装饰器定义 TFBeamSearchDecoderOnlyOutput 类,表示仅解码器使用 beam search 生成模型的输出。
@dataclass
class TFBeamSearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using beam search.
    解码器仅使用 beam search 生成模型的输出的基类。

    Args:
        sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
            生成的序列。第二维度(sequence_length)要么等于 `max_length`,要么由于 `eos_token_id` 导致所有批次提前结束而更短。
        sequences_scores (`tf.Tensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Final beam scores of the generated `sequences`.
            生成的 `sequences` 的最终 beam 分数。
        scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log
            softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
            beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
            每一代生成步骤中每个词汇标记的处理过的 beam 分数。包括每个词汇标记的 log softmax 分数和该 beam 中先前生成的标记的 log softmax 的总和。
        beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Beam indices of generated token id at each generation step. `tf.Tensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
            每个生成步骤生成的标记 id 的 beam 索引。形状为 `(batch_size*num_return_sequences, sequence_length)` 的 `tf.Tensor`。
        attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
            每个生成的标记的元组(每个解码器层的一个元素)的元组(每个生成的标记的元素)的注意力张量。形状为 `(batch_size*num_beams, num_heads, generated_length, sequence_length)` 的 `tf.Tensor`。
        hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `tf.Tensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
            每个生成的标记的元组(每个解码器层的一个元素)的元组(每个生成的标记的元素)的隐藏状态张量。形状为 `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)` 的 `tf.Tensor`。
    """

    sequences: tf.Tensor = None  # 生成的序列
    sequences_scores: Optional[tf.Tensor] = None  # 生成序列的最终 beam 分数,可选
    scores: Optional[Tuple[tf.Tensor]] = None  # 每个生成步骤中每个词汇标记的处理过的 beam 分数,可选
    beam_indices: Optional[tf.Tensor] = None  # 每个生成步骤生成的标记 id 的 beam 索引,可选
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None  # 每个生成的标记的注意力张量,可选
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None  # 每个生成的标记的隐藏状态张量,可选


# 使用 dataclass 装饰器定义 TFBeamSearchEncoderDecoderOutput 类,表示编码器-解码器使用 beam search 生成模型的输出。
@dataclass
class TFBeamSearchEncoderDecoderOutput(ModelOutput):
    """
    Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
    of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
    编码器-解码器使用 beam search 生成模型的输出的基类。可以通过 encoder_attentions 和 encoder_hidden_states 访问解码器(或编码器)的隐藏状态和注意力权重。

    """
    # 定义一个包含多个属性的数据类,用于存储序列、分数、索引以及各种注意力和隐藏状态信息
    
    sequences: tf.Tensor = None
    # 序列数据,类型为 TensorFlow 的张量,初始值为 None
    sequences_scores: Optional[tf.Tensor] = None
    # 序列的分数数据,类型为可选的 TensorFlow 张量,初始值为 None
    scores: Optional[Tuple[tf.Tensor]] = None
    # 分数数据,类型为可选的 TensorFlow 张量元组,初始值为 None
    beam_indices: Optional[tf.Tensor] = None
    # Beam 搜索的索引数据,类型为可选的 TensorFlow 张量,初始值为 None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    # 编码器注意力数据,类型为可选的 TensorFlow 张量元组,初始值为 None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    # 编码器隐藏状态数据,类型为可选的 TensorFlow 张量元组,初始值为 None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    # 解码器注意力数据,类型为可选的嵌套元组的 TensorFlow 张量元组,初始值为 None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    # 交叉注意力数据,类型为可选的嵌套元组的 TensorFlow 张量元组,初始值为 None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
    # 解码器隐藏状态数据,类型为可选的嵌套元组的 TensorFlow 张量元组,初始值为 None
@dataclass
class TFBeamSampleDecoderOnlyOutput(ModelOutput):
    """
    Decoder-only生成模型使用Beam采样的输出的基类。

    Args:
        sequences (`tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            生成的序列。第二维(sequence_length)要么等于`max_length`,要么因为`eos_token_id`导致所有批次提前结束而更短。
        sequences_scores (`tf.Tensor` of shape `(batch_size * num_return_sequence)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            生成的`sequences`的最终beam分数。
        scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            每个生成步骤中每个词汇标记的处理beam分数。每个元素为`tf.Tensor`的元组,最多有`max_new_tokens`个元素(每个生成的标记一个元素),每个张量的形状为`(batch_size*num_beams*num_return_sequences, config.vocab_size)`。
        beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            每个生成步骤生成的标记ID的beam索引。形状为`(batch_size*num_return_sequences, sequence_length)`的`tf.Tensor`。
        attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            每个生成的标记的注意力权重。元组(每个生成标记一个元素),元组(每个解码器层一个元素),`tf.Tensor`的元组,形状为`(batch_size*num_beams, num_heads, generated_length, sequence_length)`。
        hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            解码器每层的隐藏状态。元组(每个生成标记一个元素),元组(每个解码器层一个元素),`tf.Tensor`的元组,形状为`(batch_size*num_beams, generated_length, hidden_size)`。
    """

    sequences: tf.Tensor = None
    sequences_scores: Optional[tf.Tensor] = None
    scores: Optional[Tuple[tf.Tensor]] = None
    beam_indices: Optional[tf.Tensor] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFBeamSampleEncoderDecoderOutput(ModelOutput):
    """
    Encoder-decoder生成模型使用Beam采样的输出的基类。可以通过encoder_attentions和encoder_hidden_states属性访问解码器(或者通过decoder_attentions和decoder_hidden_states属性访问编码器)的隐藏状态和注意力权重。

    """
    # 定义一个变量 sequences,类型为 tf.Tensor,初始值为 None
    sequences: tf.Tensor = None
    
    # 定义一个变量 sequences_scores,类型为 Optional[tf.Tensor],初始值为 None
    sequences_scores: Optional[tf.Tensor] = None
    
    # 定义一个变量 scores,类型为 Optional[Tuple[tf.Tensor]],初始值为 None
    scores: Optional[Tuple[tf.Tensor]] = None
    
    # 定义一个变量 beam_indices,类型为 Optional[tf.Tensor],初始值为 None
    beam_indices: Optional[tf.Tensor] = None
    
    # 定义一个变量 encoder_attentions,类型为 Optional[Tuple[tf.Tensor]],初始值为 None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    
    # 定义一个变量 encoder_hidden_states,类型为 Optional[Tuple[tf.Tensor]],初始值为 None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    
    # 定义一个变量 decoder_attentions,类型为 Optional[Tuple[Tuple[tf.Tensor]]],初始值为 None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    
    # 定义一个变量 cross_attentions,类型为 Optional[Tuple[Tuple[tf.Tensor]]],初始值为 None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    
    # 定义一个变量 decoder_hidden_states,类型为 Optional[Tuple[Tuple[tf.Tensor]]],初始值为 None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@dataclass
class TFContrastiveSearchDecoderOnlyOutput(ModelOutput):
    """
    Decoder-only generation model output class for contrastive search.

    Args:
        sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(tf.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each
            generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `tf.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `tf.Tensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None


@dataclass
class TFContrastiveSearchEncoderDecoderOutput(ModelOutput):
    """
    Encoder-decoder generation model output class for contrastive search.

    Base class for outputs of encoder-decoder generation models using contrastive search. Hidden states and attention
    weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
    encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
    """
    """
    Args:
        sequences (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            生成的序列。第二个维度 (sequence_length) 可能等于 `max_length`,或者如果所有批次由于 `eos_token_id` 而提前结束,则会更短。
        scores (`tuple(tf.Tensor)` *optional*, 当 `output_scores=True` 传递或 `config.output_scores=True` 时返回):
            语言建模头部处理后的预测分数(SoftMax 前每个词汇标记的分数),每个生成步骤一个元组元素,元素数最多为 `max_new_tokens`,每个张量的形状为 `(batch_size, config.vocab_size)`。
        encoder_attentions (`tuple(tf.Tensor)`, *optional*, 当 `output_attentions=True` 传递或 `config.output_attentions=True` 时返回):
            解码器每一层的注意力权重张量的元组,形状为 `(batch_size, num_heads, sequence_length, sequence_length)`。
        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, 当 `output_hidden_states=True` 传递或 `config.output_hidden_states=True` 时返回):
            解码器每一层的隐藏状态张量的元组,形状为 `(batch_size, sequence_length, hidden_size)`,包含从嵌入层开始的所有层的输出。
        decoder_attentions (`tuple(tuple(tf.Tensor))`, *optional*, 当 `output_attentions=True` 传递或 `config.output_attentions=True` 时返回):
            每个生成的标记一个元组元素,其中每个元素是解码器每一层的注意力权重张量的元组,形状为 `(batch_size, num_heads, generated_length, sequence_length)`。
        cross_attentions (`tuple(tuple(tf.Tensor))`, *optional*, 当 `output_attentions=True` 传递或 `config.output_attentions=True` 时返回):
            每个生成的标记一个元组元素,其中每个元素是解码器每一层与编码器的交叉注意力权重张量的元组,形状为 `(batch_size, num_heads, generated_length, sequence_length)`。
        decoder_hidden_states (`tuple(tuple(tf.Tensor))`, *optional*, 当 `output_hidden_states=True` 传递或 `config.output_hidden_states=True` 时返回):
            每个生成的标记一个元组元素,其中每个元素是解码器每一层的隐藏状态张量的元组,形状为 `(batch_size, generated_length, hidden_size)`。
    """

    sequences: tf.Tensor = None
    scores: Optional[Tuple[tf.Tensor]] = None
    encoder_attentions: Optional[Tuple[tf.Tensor]] = None
    encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    cross_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
    decoder_hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
# 定义类型别名,表示不同的生成器输出类型
TFGreedySearchOutput = Union[TFGreedySearchEncoderDecoderOutput, TFGreedySearchDecoderOnlyOutput]
TFSampleOutput = Union[TFSampleEncoderDecoderOutput, TFSampleDecoderOnlyOutput]
TFBeamSearchOutput = Union[TFBeamSearchEncoderDecoderOutput, TFBeamSearchDecoderOnlyOutput]
TFBeamSampleOutput = Union[TFBeamSampleEncoderDecoderOutput, TFBeamSampleDecoderOnlyOutput]
TFContrastiveSearchOutput = Union[TFContrastiveSearchEncoderDecoderOutput, TFContrastiveSearchDecoderOnlyOutput]
# 定义一个类型别名,表示所有生成器可能的输出类型
TFGenerateOutput = Union[
    TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, TFContrastiveSearchOutput
]

class TFGenerationMixin:
    """
    包含支持生成的所有函数的类,用作[`TFPreTrainedModel`]中的混合类。

    该类公开[`~generation.TFGenerationMixin.generate`],可以用于:
        - 当`num_beams=1`且`do_sample=False`时通过调用[`~generation.TFGenerationMixin.greedy_search`]进行*贪婪解码*
        - 当`penalty_alpha>0`且`top_k>1`时通过调用[`~generation.TFGenerationMixin.contrastive_search`]进行*对比搜索*
        - 当`num_beams=1`且`do_sample=True`时通过调用[`~generation.TFGenerationMixin.sample`]进行*多项式采样*
        - 当`num_beams>1`时通过调用[`~generation.TFGenerationMixin.beam_search`]进行*束搜索解码*

    不需要直接调用上述任何方法。而是将自定义参数值传递给 'generate' 方法。有关解码策略的更多信息,请参阅[text generation strategies guide](../generation_strategies)。
    """

    _seed_generator = None

    @property
    def seed_generator(self):
        # 警告:`seed_generator`已弃用,并将在未来版本中移除。
        warnings.warn("`seed_generator` is deprecated and will be removed in a future version.", UserWarning)
        if self._seed_generator is None:
            # 如果尚未初始化种子生成器,则从不确定状态创建一个随机生成器
            self._seed_generator = tf.random.Generator.from_non_deterministic_state()
        return self._seed_generator

    # 表示该类支持 XLA 生成
    supports_xla_generation = True

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # 如果模型类想要使用 `generate` 方法,需要定义 `prepare_inputs_for_generation` 方法
        raise NotImplementedError(
            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
        )

    def compute_transition_scores(
        self,
        sequences: tf.Tensor,
        scores: Tuple[tf.Tensor],
        beam_indices: Optional[tf.Tensor] = None,
        normalize_logits: bool = False,
    def _validate_model_class(self):
        """
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        # 检查当前模型类是否可以生成文本
        if not self.can_generate():
            # 定义兼容生成操作的模型映射列表
            generate_compatible_mappings = [
                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
                TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
                TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
                TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
            ]
            generate_compatible_classes = set()
            # 遍历每个模型映射,检查当前模型类是否在其支持的模型中
            for model_mapping in generate_compatible_mappings:
                supported_models = model_mapping.get(type(self.config), default=None)
                if supported_models is not None:
                    generate_compatible_classes.add(supported_models.__name__)
            # 构建异常消息,指示当前模型类不支持生成操作,并推荐可用的兼容模型类
            exception_message = (
                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
                "it doesn't have a language model head."
            )
            if generate_compatible_classes:
                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
            # 抛出类型错误异常,包含详细的错误消息
            raise TypeError(exception_message)

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # 如果是编码-解码模型,排除在调用任何模型函数之前已处理的参数
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)

        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # 检查是否 prepare_inputs_for_generation 方法接受了 `kwargs` 或 `model_kwargs` 参数,以便处理可选的前向传递输入
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(self.call).parameters)
        # 检查每个传入的 model_kwargs 是否在模型参数中有对应的接收者
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

        if unused_model_args:
            # 抛出数值错误异常,指示有未使用的 model_kwargs 参数
            raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )
        ) -> tf.Tensor:
        # 检查输入是否为 input_ids 类型且是二维的,并且数据类型为 tf.int32 或 tf.int64
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
        # 检查输入中是否存在 pad_token_id,并且在 inputs 中至少有一个元素等于 pad_token_id
        is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
        # 检查 pad_token_id 是否不等于 eos_token_id(如果 eos_token_id 为 None,则始终为 True)
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)

        # 如果输入是 input_ids 且存在 pad_token_id 且 pad_token_id 不等于 eos_token_id,则生成 attention_mask
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return tf.cast(tf.math.not_equal(inputs, pad_token_id), dtype=tf.int32)
        else:
            # 否则返回一个全为 1 的 tensor,形状为 inputs.shape[:2]
            return tf.ones(inputs.shape[:2], dtype=tf.int32)

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: tf.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        # 1. 获取编码器并存储编码器输出
        encoder = self.get_encoder()

        # 2. 从 model_kwargs 中准备编码器参数和编码器关键字参数
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        # 从 model_kwargs 中筛选出不以 irrelevant_prefix 开头的参数作为编码器参数
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        # 检查编码器的调用签名,将符合签名的参数留下来
        encoder_signature = set(inspect.signature(encoder.call).parameters)
        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }

        # 3. 视觉模型不使用 `attention_mask`
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor
        # 如果 model_input_name 不是 self.main_input_name,在 Keras 中必须始终传递第一个输入
        if model_input_name != self.main_input_name:
            encoder_kwargs[self.main_input_name] = None
        # 调用编码器并将编码器输出存储在 model_kwargs 中的 "encoder_outputs" 键下
        encoder_outputs = encoder(**encoder_kwargs)
        model_kwargs["encoder_outputs"] = encoder_outputs

        return model_kwargs

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: Dict[str, tf.Tensor],
        decoder_start_token_id: int = None,
        bos_token_id: int = None,
    ) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
        # 1. 检查用户是否手动定义了 `decoder_input_ids`。为了方便输入命名,如果编码器没有将其用作主输入,也允许用户通过 `input_ids` 参数传递。
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
        # 如果 `input_ids` 在 `model_kwargs` 中,并且 `model_input_name` 不是 "input_ids",则也将其用作 `decoder_input_ids`
        elif "input_ids" in model_kwargs and model_input_name != "input_ids":
            decoder_input_ids = model_kwargs.pop("input_ids")
        else:
            # 否则,将 `decoder_input_ids` 设为 None
            decoder_input_ids = None

        # 2. 编码器-解码器模型期望 `decoder_input_ids` 以特殊标记开始。确保它符合这个要求。
        # 获取解码器的起始标记 ID
        decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
        # 用 `decoder_start_token_id` 创建起始的 `decoder_input_ids` 张量
        decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id

        # 如果没有用户输入 -> 使用 `decoder_start_token_id` 作为 `decoder_input_ids`
        if decoder_input_ids is None:
            decoder_input_ids = decoder_input_ids_start
        # 如果有用户输入但不以 `decoder_start_token_id` 开始 -> 在开头添加 `decoder_start_token_id`(并调整 `decoder_attention_mask` 如果提供了)
        elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id):
            decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1)
            if "decoder_attention_mask" in model_kwargs:
                # 调整 `decoder_attention_mask`,在开头增加一个标记
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                decoder_attention_mask = tf.concat(
                    (tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
                    axis=-1,
                )
                model_kwargs["decoder_attention_mask"] = decoder_attention_mask

        # 返回处理后的 `decoder_input_ids` 和可能修改过的 `model_kwargs`
        return decoder_input_ids, model_kwargs

    def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
        # 检索编码器-解码器模型的解码器起始标记 ID
        # 如果需要,回退到 `bos_token_id`
        decoder_start_token_id = (
            decoder_start_token_id
            if decoder_start_token_id is not None
            else self.generation_config.decoder_start_token_id
        )
        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id

        # 如果 `decoder_start_token_id` 已定义,则返回它
        if decoder_start_token_id is not None:
            return decoder_start_token_id
        # 否则,如果 `bos_token_id` 已定义,则返回它
        elif bos_token_id is not None:
            return bos_token_id
        # 如果两者都未定义,则引发 ValueError 异常
        raise ValueError(
            "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
        )
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[tf.Tensor] = None,
        expand_in_new_axis: bool = False,
        **model_kwargs,
    ) -> Tuple[tf.Tensor, Dict[str, Any]]:
        """
        Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...] or [batch_size, expand_size, ...],
        depending on `expand_in_new_axis`. Beam-based approaches expect this function to be used with
        `expand_in_new_axis=True`
        """
        
        def _expand_tensor(tensor: tf.Tensor):
            # 根据 `expand_in_new_axis` 参数选择不同的扩展方式
            if expand_in_new_axis:
                shape = shape_list(tensor)
                return tf.broadcast_to(tensor[:, None], (shape[0], expand_size) + tuple(shape[1:]))
            else:
                return tf.repeat(tensor, expand_size, axis=0)

        def _expand_dict_for_generation(dict_to_expand):
            # 遍历字典中的每个值,如果是 Tensor 类型且非空,则调用 `_expand_tensor` 函数扩展
            for key in dict_to_expand:
                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], tf.Tensor):
                    dict_to_expand[key] = _expand_tensor(dict_to_expand[key])
            return dict_to_expand

        if input_ids is not None:
            # 如果 `input_ids` 不为空,则调用 `_expand_tensor` 函数扩展 `input_ids`
            input_ids = _expand_tensor(input_ids)

        # 调用 `_expand_dict_for_generation` 函数扩展 `model_kwargs`
        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            # 如果是编码-解码模型,确保 `encoder_outputs` 在 `model_kwargs` 中定义
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            # 调用 `_expand_dict_for_generation` 函数扩展 `encoder_outputs`
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        # 返回扩展后的 `input_ids` 和 `model_kwargs`
        return input_ids, model_kwargs

    def _prepare_model_inputs(
        self,
        inputs: Optional[tf.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
    ):
        """
        Prepares inputs for the model, optionally including a beginning-of-sequence token ID (`bos_token_id`).
        """
        # 此函数未提供实现,仅作为方法声明,用于准备模型的输入

    def _maybe_initialize_input_ids_for_generation(
        self,
        inputs: Optional[tf.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
    ):
        """
        Initializes `input_ids` for generation, optionally including a beginning-of-sequence token ID (`bos_token_id`).
        """
        # 此函数未提供实现,仅作为方法声明,用于为生成任务初始化 `input_ids`
    ) -> tf.Tensor:
        """Initializes input ids for generation, if necessary."""
        # 如果已经提供了输入,则直接返回输入
        if inputs is not None:
            return inputs

        # 获取模型参数中的编码器输出
        encoder_outputs = model_kwargs.get("encoder_outputs")
        # 如果是编码-解码模型并且有编码器输出,则创建一个全为-100的虚拟输入,以确保不会被用于编码
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            shape = encoder_outputs.last_hidden_state.shape[:-1]
            return tf.ones(shape, dtype=tf.int32) * -100

        # 如果未提供输入且未定义bos_token_id,则抛出异常
        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")

        # 如果在 `model_kwargs` 中有张量,则可以从中推断批量大小。这对于软提示或基于解码器的多模态实现很有帮助。
        batch_size = 1
        for value in model_kwargs.values():
            if isinstance(value, tf.Tensor):
                batch_size = value.shape[0]
                break
        # 创建一个形状为(batch_size, 1)的全为bos_token_id的张量,作为初始化的输入
        return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id

    @staticmethod
    def _extract_past_from_model_output(outputs: ModelOutput):
        """Extracts past key values from model outputs."""
        past_key_values = None
        # 根据不同的输出结构,提取过去的键值
        if "past_key_values" in outputs:
            past_key_values = outputs.past_key_values
        elif "mems" in outputs:
            past_key_values = outputs.mems
        elif "past_buckets_states" in outputs:
            past_key_values = outputs.past_buckets_states
        return past_key_values

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
    ) -> Dict[str, Any]:
        """Updates model keyword arguments for generation."""
        # 更新模型参数中的过去键值
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(outputs)

        # 更新注意力掩码
        if not is_encoder_decoder:
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                # 将一个形状为原注意力掩码+1列的张量与原注意力掩码拼接,用于后续生成过程中的扩展
                model_kwargs["attention_mask"] = tf.concat(
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
                )

        return model_kwargs

    def _update_model_kwargs_for_xla_generation(
        self,
        model_outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        cur_len: int,
        max_length: int,
        batch_size: int,
        is_encoder_decoder: bool = False,
        batch_axis: int = 0,
    ):
        """Updates model keyword arguments for XLA generation."""
        # 省略部分代码,未作注释要求的部分
        pass

    def _get_logits_warper(
        self,
        generation_config: GenerationConfig,
        # 省略部分代码,未作注释要求的部分
        ):
        """Gets the logits warper for generation."""
        pass
        ) -> TFLogitsProcessorList:
        """
        This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
        instances used for multinomial sampling.
        """

        # instantiate warpers list
        warpers = TFLogitsProcessorList()

        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
        # better score (i.e. keep len(generation_config.eos_token_id) + 1)
        if generation_config.num_beams > 1:
            if isinstance(generation_config.eos_token_id, list):
                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
            else:
                min_tokens_to_keep = 2
        else:
            min_tokens_to_keep = 1

        # Check if temperature warping is enabled and add warper accordingly
        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
        
        # Check if top-k warping is enabled and add warper accordingly
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
        
        # Check if top-p warping is enabled and add warper accordingly
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
        
        # Return the list of warpers containing all configured logits processors
        return warpers
    ) -> TFLogitsProcessorList:
        """
        This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
        instances used to modify the scores of the language model head.
        """
        # 创建一个空的处理器列表
        processors = TFLogitsProcessorList()

        # 如果设定了重复惩罚并且不等于默认值 1.0,则添加重复惩罚处理器
        if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
            processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
        
        # 如果设定了禁止重复 n-gram 大小,并且大于 0,则添加禁止重复 n-gram 处理器
        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
            processors.append(TFNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
        
        # 如果设定了要避免的词汇 ID 列表,则添加避免坏词汇处理器
        if generation_config.bad_words_ids is not None:
            processors.append(
                TFNoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
            )
        
        # 如果设定了最小生成长度、结束符号 ID,并且最小长度大于 0,则添加最小长度处理器
        if (
            generation_config.min_length is not None
            and generation_config.eos_token_id is not None
            and generation_config.min_length > 0
        ):
            processors.append(TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
        
        # 如果设定了强制起始符号 ID,则添加强制起始符号处理器
        if generation_config.forced_bos_token_id is not None:
            processors.append(TFForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
        
        # 如果设定了强制结束符号 ID,则添加强制结束符号处理器
        if generation_config.forced_eos_token_id is not None:
            processors.append(
                TFForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
            )
        
        # 如果设定了要抑制的 token 列表,则添加抑制 token 处理器
        if generation_config.suppress_tokens is not None:
            processors.append(TFSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
        
        # 如果设定了要在开头抑制的 token 列表,则添加在开头抑制 token 处理器
        if generation_config.begin_suppress_tokens is not None:
            begin_index = input_ids_seq_length
            begin_index = (
                begin_index
                if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                else begin_index + 1
            )
            if generation_config.forced_decoder_ids is not None:
                begin_index += generation_config.forced_decoder_ids[-1][
                    0
                ]  # generation starts after the last token that is forced
            processors.append(
                TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
            )
        
        # 如果设定了强制生成的 token 列表,则添加强制生成 token 处理器
        if generation_config.forced_decoder_ids is not None:
            processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))

        # 合并默认处理器列表和自定义处理器列表,并返回最终的处理器列表
        processors = self._merge_criteria_processor_list(processors, logits_processor)
        return processors
    # 定义一个方法,接受一个自定义的 TFLogitsProcessorList 参数列表,并返回一个 TFLogitsProcessorList 对象
    def __init__(self, custom_list: List[TFLogitsProcessor] = []) -> TFLogitsProcessorList:
        # 如果 custom_list 是空的,则返回默认的 default_list
        if len(custom_list) == 0:
            return default_list
        # 遍历 default_list 中的每个元素
        for default in default_list:
            # 遍历 custom_list 中的每个元素
            for custom in custom_list:
                # 如果 custom 和 default 的类型相同
                if type(custom) is type(default):
                    # 设置对象类型为 "logits processor"
                    object_type = "logits processor"
                    # 抛出值错误异常,提醒用户 custom 对象已经存在
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `generate`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `generate` instead of using a custom {object_type}."
                    )
        # 将 custom_list 中的元素扩展到 default_list 中
        default_list.extend(custom_list)
        # 返回扩展后的 default_list
        return default_list

    # 定义一个贪婪搜索方法,接受多个参数和关键字参数
    def greedy_search(
        self,
        input_ids: tf.Tensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        logits_processor: Optional[TFLogitsProcessorList] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ):
        # 方法实现省略

    # 定义一个采样方法,接受多个参数和关键字参数
    def sample(
        self,
        input_ids: tf.Tensor,
        logits_processor: Optional[TFLogitsProcessorList] = None,
        logits_warper: Optional[TFLogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        seed: Optional[Tuple[int, int]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        **model_kwargs,
    ):
        # 方法实现省略

    @staticmethod
    def _gather_beams(nested, beam_indices, batch_axis=0):
        """Gathers the beam slices indexed by beam_indices into new beam array."""

        def gather_fn(tensor):
            # 如果 batch_axis 大于 0,则将 batch_axis 之前的所有维度移到最后,以便得到形状为 (batch, beam_id, ...) 的张量
            if batch_axis > 0:
                perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0)
                tensor = tf.transpose(tensor, perm=perm)

            # 在 axis=1 上使用 beam_indices 进行 gather 操作,得到聚集后的张量
            gathered_tensor = tf.gather(params=tensor, indices=beam_indices, axis=1, batch_dims=1)
            
            # 如果 batch_axis 大于 0,则将张量恢复到原始的维度顺序
            if batch_axis > 0:
                perm = tf.concat((tf.range(tf.rank(tensor))[batch_axis:], tf.range(batch_axis)), axis=0)
                perm = tf.math.invert_permutation(perm)
                gathered_tensor = tf.transpose(gathered_tensor, perm=perm)

            return gathered_tensor

        # 对 nested 结构中的每个张量应用 gather_fn 函数,并返回新的结构
        return tf.nest.map_structure(gather_fn, nested)
# 将给定的值按照批次索引散布到张量中
def scatter_values_on_batch_indices(values, batch_indices):
    # 获取批次索引张量的形状
    shape = shape_list(batch_indices)
    # 扩展批次维度以匹配形状
    broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
    # 将批次索引转换为对应的索引对
    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
    # 根据索引对将值散布到目标形状中
    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)


def sample_without_replacement(logits, num_samples):
    """
    不重复的分类采样当前尚未实现,现在使用Gumbel-Max技巧代替,请参见
    https://github.com/tensorflow/tensorflow/issues/9260 获取更多信息
    """
    z = -tf.math.log(-tf.math.log(tf.random.uniform(shape_list(logits), 0, 1)))
    _, indices = tf.nn.top_k(logits + z, num_samples)
    return indices


def _ranking_fast(
    context_hidden: tf.Tensor,
    next_hidden: tf.Tensor,
    next_top_k_probs: tf.Tensor,
    alpha: float,
    beam_width: int,
) -> tf.Tensor:
    """
    根据文献《神经文本生成的对比框架》中描述的退化惩罚(与先前标记的余弦相似度)对top_k候选进行重新排序。
    返回批次中每行最佳候选的索引。
    """
    # 对上下文隐藏层进行归一化处理
    norm_context_hidden = context_hidden / tf.norm(context_hidden, axis=2, keepdims=True)
    # 对下一个隐藏层进行归一化处理
    norm_next_hidden = next_hidden / tf.norm(next_hidden, axis=2, keepdims=True)
    # 计算余弦相似度矩阵
    cosine_matrix = tf.squeeze(tf.linalg.matmul(norm_context_hidden, norm_next_hidden, transpose_b=True), axis=-1)
    # 计算最大余弦相似度的退化惩罚
    degeneration_penalty = tf.reduce_max(cosine_matrix, axis=-1)
    # 重塑下一个top_k概率
    next_top_k_probs = tf.reshape(next_top_k_probs, shape=[-1])
    # 计算对比分数,包括概率和退化惩罚
    contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
    contrastive_score = tf.reshape(contrastive_score, shape=[-1, beam_width])
    # 选择每行中最高对比分数的索引
    selected_idx = tf.argmax(contrastive_score, axis=1)
    return selected_idx

.\generation\utils.py

# coding=utf-8
# 版权声明和许可信息,指定了本文件使用的Apache License, Version 2.0许可
# 此处为代码导入所需的标准库、第三方库及自定义模块

import copy  # 导入copy模块,用于对象的浅复制和深复制操作
import inspect  # 导入inspect模块,用于获取对象信息
import warnings  # 导入warnings模块,用于警告处理
from dataclasses import dataclass  # 从dataclasses模块导入dataclass装饰器,用于简化数据类的定义
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union  # 导入类型提示相关的类和函数

import torch  # 导入PyTorch库
import torch.distributed as dist  # 导入PyTorch分布式训练相关模块
from torch import nn  # 从torch模块中导入nn模块,用于神经网络构建

from ..cache_utils import Cache, DynamicCache, StaticCache  # 导入缓存相关的自定义模块
from ..integrations.deepspeed import is_deepspeed_zero3_enabled  # 导入深度学习加速相关模块
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput  # 导入模型输出相关类
from ..models.auto import (  # 导入自动模型加载相关映射
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging  # 导入工具类和函数
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint  # 导入束搜索相关约束类
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer  # 导入束搜索相关评分器类
from .candidate_generator import (  # 导入候选生成器相关函数和类
    AssistedCandidateGenerator,
    CandidateGenerator,
    PromptLookupCandidateGenerator,
    _crop_past_key_values,
    _prepare_attention_mask,
    _prepare_token_type_ids,
)
from .configuration_utils import GenerationConfig, GenerationMode  # 导入生成配置和模式相关类
from .logits_process import (  # 导入logits处理相关类
    EncoderNoRepeatNGramLogitsProcessor,
    EncoderRepetitionPenaltyLogitsProcessor,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    ForceTokensLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SequenceBiasLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from .stopping_criteria import (  # 导入停止条件相关类
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

if TYPE_CHECKING:
    # 从相对路径导入模块中的PreTrainedModel类,用于模型预训练
    # 从相对路径导入streamers模块中的BaseStreamer类,用作基础流处理器
    from ..modeling_utils import PreTrainedModel
    from .streamers import BaseStreamer
# 获取名为__name__的模块的日志记录器对象
logger = logging.get_logger(__name__)

# 如果加速可用,导入加速相关的钩子函数和模块扩展函数
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

# 静态缓存类型映射,将字符串"static"映射到StaticCache类
NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache,
}

# 数据类,用于生成仅解码器输出的模型结果,继承自ModelOutput类
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.
    """
    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
            Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
            tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.
    # 声明一个可选的变量 hidden_states,其类型是一个元组,包含一个元组,该元组中包含一个 torch.FloatTensor 类型的值
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    
    # 声明一个可选的变量 past_key_values,其类型是一个元组,包含一个元组,该元组中包含一个元组,该元组中包含一个 torch.FloatTensor 类型的值
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
# 用于生成编码器-解码器模型的输出,非使用 Beam 方法时的情况
@dataclass
class GenerateEncoderDecoderOutput(ModelOutput):
    """
    编码器-解码器生成模型的输出,当不使用 Beam 方法时。

    """

    sequences: torch.LongTensor = None  # 生成的序列(token ID)
    scores: Optional[Tuple[torch.FloatTensor]] = None  # 每个生成序列的分数
    logits: Optional[Tuple[torch.FloatTensor]] = None  # 每个生成序列的 logits
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None  # 编码器注意力权重
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # 编码器隐藏状态
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 解码器注意力权重
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 交叉注意力权重
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 解码器隐藏状态
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None  # 额外的过去键值(针对 Transformer 模型)

# 用于生成仅解码器模型的输出,使用 Beam 方法时的情况
@dataclass
class GenerateBeamDecoderOnlyOutput(ModelOutput):
    """
    解码器生成模型的输出,仅在使用 Beam 方法时。

    """

    sequences: torch.LongTensor = None  # 生成的序列(token ID)
    sequences_scores: Optional[torch.FloatTensor] = None  # 生成序列的分数
    scores: Optional[Tuple[torch.FloatTensor]] = None  # 每个生成序列的分数
    logits: Optional[Tuple[torch.FloatTensor]] = None  # 每个生成序列的 logits
    beam_indices: Optional[torch.LongTensor] = None  # Beam 搜索时使用的索引
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 注意力权重
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 隐藏状态
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None  # 额外的过去键值(针对 Transformer 模型)

# 用于生成编码器-解码器模型的输出,使用 Beam 方法时的情况
@dataclass
class GenerateBeamEncoderDecoderOutput(ModelOutput):
    """
    编码器-解码器生成模型的输出,使用 Beam 方法时。

    """

    sequences: torch.LongTensor = None  # 生成的序列(token ID)
    sequences_scores: Optional[torch.FloatTensor] = None  # 生成序列的分数
    scores: Optional[Tuple[torch.FloatTensor]] = None  # 每个生成序列的分数
    logits: Optional[Tuple[torch.FloatTensor]] = None  # 每个生成序列的 logits
    beam_indices: Optional[torch.LongTensor] = None  # Beam 搜索时使用的索引
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None  # 编码器注意力权重
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # 编码器隐藏状态
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 解码器注意力权重
    cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 交叉注意力权重
    decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None  # 解码器隐藏状态
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None  # 额外的过去键值(针对 Transformer 模型)

# 以下是为了向后兼容而保留的等效类
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput  # 贪婪搜索解码器模型的输出
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput  # 对比搜索解码器模型的输出
SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput  # 示例解码器模型的输出

ContrastiveSearchEncoderDecoderOutput = GenerateEncoderDecoderOutput  # 对比搜索编码器-解码器模型的输出
GreedySearchEncoderDecoderOutput = GenerateEncoderDecoderOutput  # 贪婪搜索编码器-解码器模型的输出
SampleEncoderDecoderOutput = GenerateEncoderDecoderOutput  # 示例编码器-解码器模型的输出

BeamSearchDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput  # Beam 搜索解码器模型的输出
BeamSampleDecoderOnlyOutput = GenerateBeamDecoderOnlyOutput  # Beam 示例解码器模型的输出

BeamSearchEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput  # Beam 搜索编码器-解码器模型的输出
BeamSampleEncoderDecoderOutput = GenerateBeamEncoderDecoderOutput  # Beam 示例编码器-解码器模型的输出

GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]  # 贪婪搜索的输出类型
# Typing shortcuts for specific types of model outputs
SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]

# Typing shortcut for non-beam text generation output
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
# Typing shortcut for beam search text generation output
GenerateBeamOutput = Union[GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput]
# Typing shortcut for any text generation output
GenerateOutput = Union[GenerateNonBeamOutput, GenerateBeamOutput]


class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
          `do_sample=False`
        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
          `top_k>1`
        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
          `do_sample=True`
        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
          `do_sample=False`
        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
          and `do_sample=True`
        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
          and `num_beam_groups>1`
        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`
        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
            `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`

    You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
    learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
    """

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # Raise an error if this method is not implemented in the subclass
        raise NotImplementedError(
            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
        )

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ):
        # Internal method for preparing model inputs for text generation
        ...

    def _maybe_initialize_input_ids_for_generation(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ):
        # Internal method to initialize input IDs for text generation if necessary
        ...
    ) -> torch.LongTensor:
        """Initializes input ids for generation, if necessary."""
        # 如果已经提供了输入,则直接返回输入
        if inputs is not None:
            return inputs

        # 获取模型关键字参数中的 encoder_outputs
        encoder_outputs = model_kwargs.get("encoder_outputs")
        # 如果模型是编码-解码模型且 encoder_outputs 不为空
        if self.config.is_encoder_decoder and encoder_outputs is not None:
            # 创建一个与 encoder_outputs 最后一层隐藏状态相同形状的输入 id 张量,填充值为 -100
            shape = encoder_outputs.last_hidden_state.size()[:-1]
            return torch.ones(shape, dtype=torch.long, device=self.device) * -100

        # 如果未提供 input_ids 且未定义 bos_token_id,则引发错误
        if bos_token_id is None:
            raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")

        # 如果 model_kwargs 中有某些张量,则可以从中推断出批量大小
        batch_size = 1
        for value in model_kwargs.values():
            if isinstance(value, torch.Tensor):
                batch_size = value.shape[0]
                break

        # 如果 model_kwargs 中包含 "inputs_embeds" 键
        if "inputs_embeds" in model_kwargs:
            # 返回一个形状为 (batch_size, 0) 的全 1 张量,dtype 为 torch.long
            return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
        # 否则返回一个形状为 (batch_size, 1) 的全 bos_token_id 值的张量,dtype 为 torch.long
        return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        # 检查输入是否为 input_ids 且已被填充,只有这种情况下才定义 attention_mask
        is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
        is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)

        # 如果输入是 input_ids 且已填充,并且填充标记不等于 eos_token_id,则返回 attention_mask
        if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
            return inputs.ne(pad_token_id).long()
        else:
            # 否则返回一个形状与 inputs 的前两维相同的全 1 张量,dtype 为 torch.long
            return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
        # 1. 获取编码器
        encoder = self.get_encoder()

        # 2. 兼容加速大模型推断:确保编码器在与输入相同的设备上输出结果
        if hasattr(self, "hf_device_map"):
            # 如果编码器有 `_hf_hook` 属性,设置其 `io_same_device` 为 True
            if hasattr(encoder, "_hf_hook"):
                encoder._hf_hook.io_same_device = True
            # 否则,向编码器添加一个 AlignDevicesHook,设置 `io_same_device` 为 True
            else:
                add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

        # 3. 从模型参数中准备编码器的参数和关键字参数
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        # 从 `model_kwargs` 中选择与编码器相关的参数和值
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix)
        }
        # 检查编码器的输入签名,确定是否支持 `kwargs` 或 `model_kwargs`
        encoder_signature = set(inspect.signature(encoder.forward).parameters)
        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
        if not encoder_accepts_wildcard:
            # 如果编码器不支持通配符参数,仅选择编码器签名中存在的参数和值
            encoder_kwargs = {
                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
            }

        # 4. 确保编码器返回 `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
        encoder_kwargs["return_dict"] = True
        encoder_kwargs[model_input_name] = inputs_tensor

        # 调用编码器并将结果保存在 `model_kwargs` 的 `encoder_outputs` 键中
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

        return model_kwargs

    # 准备用于生成的解码器输入 ID
    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: Dict[str, torch.Tensor],
        decoder_start_token_id: Union[int, List[int]] = None,
        bos_token_id: int = None,
        device: torch.device = None,
    ) -> Dict[str, torch.Tensor]:
        ...

    # 获取解码器起始标记 ID
    def _get_decoder_start_token_id(
        self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
    ) -> int:
        ...

    # 扩展用于生成的输入
    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> Dict[str, torch.Tensor]:
        ...
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
        # 定义函数签名,指定返回类型为元组,包含一个长整型张量和一个任意类型字典

        def _expand_dict_for_generation(dict_to_expand):
            # 为生成过程扩展字典中的张量
            for key in dict_to_expand:
                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
            # 如果输入 ID 不为空,则按照指定的扩展大小在指定维度上重复扩展

        model_kwargs = _expand_dict_for_generation(model_kwargs)
        # 扩展模型参数字典中的张量

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
            # 如果是编码器-解码器模型,确保编码器输出在模型参数中被定义,并进行扩展

        return input_ids, model_kwargs
        # 返回扩展后的输入 ID 和模型参数字典

    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
        past_key_values = None
        if "past_key_values" in outputs:
            past_key_values = outputs.past_key_values
        elif "mems" in outputs:
            past_key_values = outputs.mems
        elif "past_buckets_states" in outputs:
            past_key_values = outputs.past_buckets_states
        # 从模型输出中提取过去的键-值对

        # Bloom fix: standardizes the cache format when requested
        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
            batch_size = outputs.logits.shape[0]
            past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
            # 在请求时,如果需要,标准化缓存格式

        return past_key_values
        # 返回提取的过去键-值对

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
        # 更新用于生成的模型参数字典
    ) -> Dict[str, Any]:
        # 更新 model_kwargs 中的 past_key_values,从模型输出中提取过去的键值
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        # 如果 outputs 有 state 属性,则更新 model_kwargs 中的 state
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # 更新 token_type_ids,使用最后一个值进行扩展
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        # 如果不是 encoder-decoder 架构
        if not is_encoder_decoder:
            # 更新 attention_mask
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
        else:
            # 更新 decoder_attention_mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )

        # 如果 model_kwargs 中存在 cache_position 并且不为 None,则更新 cache_position
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

        # 返回更新后的 model_kwargs
        return model_kwargs

    # 抛出未实现错误,提示在当前类的模块中实现 _reorder_cache 函数以启用 beam search
    def _reorder_cache(self, past_key_values, beam_idx):
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )

    # 返回用于辅助生成的候选生成器
    def _get_candidate_generator(
        self,
        generation_config: GenerationConfig,
        input_ids: torch.LongTensor,
        inputs_tensor: torch.Tensor,
        assistant_model: "PreTrainedModel",
        logits_processor: LogitsProcessorList,
        model_kwargs: Dict,
    ) -> CandidateGenerator:
        """
        Returns the candidate generator to be used in `assisted_generation`
        """
        # 如果指定了 prompt_lookup_num_tokens,则使用 PromptLookupCandidateGenerator
        if generation_config.prompt_lookup_num_tokens is not None:
            candidate_generator = PromptLookupCandidateGenerator(
                num_output_tokens=generation_config.prompt_lookup_num_tokens,
                max_matching_ngram_size=generation_config.max_matching_ngram_size,
            )
        else:
            # 否则使用 AssistedCandidateGenerator
            candidate_generator = AssistedCandidateGenerator(
                input_ids=input_ids,
                assistant_model=assistant_model,
                generation_config=generation_config,
                logits_processor=logits_processor,
                model_kwargs=model_kwargs,
                inputs_tensor=inputs_tensor,
            )
        return candidate_generator
    def _get_logits_warper(
        self,
        generation_config: GenerationConfig,
    ) -> LogitsProcessorList:
        """
        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
        used for multinomial sampling.
        """

        # instantiate warpers list
        warpers = LogitsProcessorList()

        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
        # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
        if generation_config.num_beams > 1:
            if isinstance(generation_config.eos_token_id, list):
                min_tokens_to_keep = len(generation_config.eos_token_id) + 1
            else:
                min_tokens_to_keep = 2
        else:
            min_tokens_to_keep = 1

        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
        # all samplers can be found in `generation_utils_samplers.py`
        
        # Apply temperature warping if temperature is defined and not equal to 1.0
        if generation_config.temperature is not None and generation_config.temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(generation_config.temperature))
        
        # Apply top-k warping if top-k is defined and not equal to 0
        if generation_config.top_k is not None and generation_config.top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
        
        # Apply top-p warping if top-p is defined and less than 1.0
        if generation_config.top_p is not None and generation_config.top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
        
        # Apply typical-p warping if typical-p is defined and less than 1.0
        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
            warpers.append(
                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
            )
        
        # Apply epsilon cutoff warping if epsilon cutoff is defined and within (0, 1)
        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
            warpers.append(
                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
            )
        
        # Apply eta cutoff warping if eta cutoff is defined and within (0, 1)
        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
            warpers.append(
                EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
            )
        
        # `LogitNormalization` should always be the last logit processor, when present
        # Apply logit normalization if renormalize_logits flag is True
        if generation_config.renormalize_logits is True:
            warpers.append(LogitNormalization())
        
        # Return the list of warpers containing all relevant LogitsWarper instances
        return warpers
    # 获取 logits 处理器函数,根据给定的配置和参数
    def _get_logits_processor(
        self,
        generation_config: GenerationConfig,  # 生成配置对象
        input_ids_seq_length: int,  # 输入的序列长度
        encoder_input_ids: torch.LongTensor,  # 编码器输入的张量
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],  # 可以使用的前缀令牌函数
        logits_processor: Optional[LogitsProcessorList],  # logits 处理器的可选列表
        model_kwargs: Optional[Dict[str, Any]] = None,  # 模型参数的可选字典,默认为空
        negative_prompt_ids: Optional[torch.Tensor] = None,  # 负面提示的可选张量,默认为空
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,  # 负面提示的注意力掩码,可选,默认为空
    ):
        # 定义 stopping_criteria 对象并初始化为空列表
        criteria = StoppingCriteriaList()
        
        # 如果生成配置中指定了最大长度
        if generation_config.max_length is not None:
            # 从模型配置中获取最大位置嵌入数
            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
            # 向 criteria 中添加最大长度的停止条件
            criteria.append(
                MaxLengthCriteria(
                    max_length=generation_config.max_length,
                    max_position_embeddings=max_position_embeddings,
                )
            )
        
        # 如果生成配置中指定了最大时间
        if generation_config.max_time is not None:
            # 向 criteria 中添加最大时间的停止条件
            criteria.append(MaxTimeCriteria(max_time=generation_config.max_time))
        
        # 将自定义的 stopping_criteria 合并到 criteria 中
        criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
        
        # 返回最终的 criteria 列表
        return criteria

    # 合并默认列表和自定义列表的 logits 处理器或停止条件
    def _merge_criteria_processor_list(
        self,
        default_list: Union[LogitsProcessorList, StoppingCriteriaList],  # 默认的处理器或停止条件列表
        custom_list: Union[LogitsProcessorList, StoppingCriteriaList],  # 自定义的处理器或停止条件列表
    ) -> Union[LogitsProcessorList, StoppingCriteriaList]:  # 返回合并后的处理器或停止条件列表
        # 如果自定义列表为空,直接返回默认列表
        if len(custom_list) == 0:
            return default_list
        
        # 遍历默认列表
        for default in default_list:
            # 遍历自定义列表
            for custom in custom_list:
                # 如果自定义的对象类型和默认的对象类型相同
                if type(custom) is type(default):
                    # 确定对象类型是停止条件还是 logits 处理器
                    object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                    # 抛出值错误,提示不允许自定义与默认相同类型的处理器或条件
                    raise ValueError(
                        f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
                        f" `.generate()`, but it has already been created with the values {default}. {default} has been"
                        " created by passing the corresponding arguments to generate or by the model's config default"
                        f" values. If you just want to change the default values of {object_type} consider passing"
                        f" them as arguments to `.generate()` instead of using a custom {object_type}."
                    )
        
        # 将自定义列表的内容扩展到默认列表中
        default_list.extend(custom_list)
        
        # 返回合并后的默认列表
        return default_list

    # 计算转移分数的函数
    def compute_transition_scores(
        self,
        sequences: torch.Tensor,  # 序列张量
        scores: Tuple[torch.Tensor],  # 分数元组
        beam_indices: Optional[torch.Tensor] = None,  # 光束索引的可选张量,默认为空
        normalize_logits: bool = False,  # 是否对 logits 进行归一化,默认为 False
    def _validate_model_class(self):
        """
        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
        right class to use.
        """
        # 检查当前模型是否能够生成文本
        if not self.can_generate():
            # 可生成的模型映射列表
            generate_compatible_mappings = [
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
                MODEL_FOR_VISION_2_SEQ_MAPPING,
                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
            ]
            generate_compatible_classes = set()
            # 遍历可生成的模型映射列表,获取支持的模型类名集合
            for model_mapping in generate_compatible_mappings:
                supported_models = model_mapping.get(type(self.config), default=None)
                if supported_models is not None:
                    generate_compatible_classes.add(supported_models.__name__)
            # 出现异常的错误信息
            exception_message = (
                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
                "it doesn't have a language model head."
            )
            # 如果存在兼容的模型类名集合,则添加到异常信息中
            if generate_compatible_classes:
                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
            # 抛出类型错误异常,包含详细的异常信息
            raise TypeError(exception_message)
    # 执行与生成长度相关的验证,包括警告和错误处理

    # 1. 针对参数化不良的最大长度警告
    if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
        # 如果使用了默认的 `max_length`(=20)来控制生成长度,会发出警告
        warnings.warn(
            f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
            "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
            "generation.",
            UserWarning,
        )
    
    # 如果输入的ids长度超过了指定的最大长度,会引发异常
    if input_ids_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
        raise ValueError(
            f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_length` or, better yet, setting `max_new_tokens`."
        )

    # 2. 由于不可行的参数组合,发出最小长度警告
    min_length_error_suffix = (
        " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
        "increase the maximum length."
    )
    if has_default_max_length:
        min_length_error_suffix += (
            f" Note that `max_length` is set to {generation_config.max_length}, its default value."
        )
    
    # 如果设定了最小长度,并且该长度大于最大长度,则发出警告
    if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
        warnings.warn(
            f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
            f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
            UserWarning,
        )
    
    # 如果设置了最小新token数量,并且计算后的最小长度超过了最大长度,则发出警告
    if generation_config.min_new_tokens is not None:
        min_length = generation_config.min_new_tokens + input_ids_length
        if min_length > generation_config.max_length:
            warnings.warn(
                f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
                f"added to the prompt length ({input_ids_length}), is larger than"
                f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
                UserWarning,
            )
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generates sequences based on the provided inputs and configuration.

        Args:
            inputs (Optional[torch.Tensor]): Input tensor for generation.
            generation_config (Optional[GenerationConfig]): Configuration for generation.
            logits_processor (Optional[LogitsProcessorList]): Processors for logits during generation.
            stopping_criteria (Optional[StoppingCriteriaList]): Criteria for stopping generation.
            prefix_allowed_tokens_fn (Optional[Callable[[int, torch.Tensor], List[int]]]): Function to allow tokens during generation.
            synced_gpus (Optional[bool]): Whether to synchronize generation across GPUs.
            assistant_model (Optional["PreTrainedModel"]): Model used for generation assistance.
            streamer (Optional["BaseStreamer"]): Streamer for generation.
            negative_prompt_ids (Optional[torch.Tensor]): IDs for negative prompts.
            negative_prompt_attention_mask (Optional[torch.Tensor]): Attention mask for negative prompts.
            **kwargs: Additional keyword arguments.

        Returns:
            dict: Dictionary containing generated sequences and other relevant outputs.
        """
        ...

    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
        """
        Returns whether there are still unfinished sequences on the specified device.

        Args:
            this_peer_finished (bool): Flag indicating if the current peer has finished generation.
            synced_gpus (bool): Whether generation is synchronized across GPUs.
            device (torch.device): Device on which generation is performed.

        Returns:
            bool: True if there are unfinished sequences, False otherwise.
        """
        if synced_gpus:
            # Under synced_gpus, ensure all GPUs complete their sequence generation.
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
            # Send 0.0 if this peer finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # Check if all peers finished (sum should be 0.0 if all finished)
            if this_peer_finished_flag.item() == 0.0:
                return False
        elif this_peer_finished:
            return False
        return True

    def contrastive_search(self, *args, **kwargs):
        """
        Deprecated method for performing contrastive search. Use `generate` or a custom generation loop instead.

        Args:
            *args: Positional arguments passed to `_contrastive_search`.
            **kwargs: Keyword arguments passed to `_contrastive_search`.

        Returns:
            Any: Result from `_contrastive_search`.
        """
        logger.warning_once(
            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        return self._contrastive_search(*args, **kwargs)

    @torch.no_grad()
    def _contrastive_search(
        self,
        input_ids: torch.LongTensor,
        top_k: Optional[int] = 1,
        penalty_alpha: Optional[float] = 0,
        logits_processor: Optional[LogitsProcessorList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        sequential: Optional[bool] = None,
        **model_kwargs,
    ):
        """
        Performs contrastive search to generate sequences based on the input_ids and additional arguments.

        Args:
            input_ids (torch.LongTensor): Input tensor containing token IDs.
            top_k (Optional[int]): Number of top-k results to consider.
            penalty_alpha (Optional[float]): Penalty factor for contrastive search.
            logits_processor (Optional[LogitsProcessorList]): Processors for logits during contrastive search.
            logits_warper (Optional[LogitsProcessorList]): Processors for logits warping during contrastive search.
            stopping_criteria (Optional[StoppingCriteriaList]): Criteria for stopping contrastive search.
            pad_token_id (Optional[int]): Token ID for padding.
            eos_token_id (Optional[Union[int, List[int]]]): Token ID(s) for end-of-sequence.
            output_attentions (Optional[bool]): Whether to output attention weights.
            output_hidden_states (Optional[bool]): Whether to output hidden states.
            output_scores (Optional[bool]): Whether to output scores.
            output_logits (Optional[bool]): Whether to output logits.
            return_dict_in_generate (Optional[bool]): Whether to return results in a dictionary format.
            synced_gpus (bool): Whether generation is synchronized across GPUs.
            streamer (Optional["BaseStreamer"]): Streamer for contrastive search.
            sequential (Optional[bool]): Whether to generate sequentially.
            **model_kwargs: Additional keyword arguments.

        Returns:
            Any: Result of contrastive search, typically sequences or generated outputs.
        """
        ...
    # 发出警告日志,提醒直接调用该方法已经被废弃,将在 v4.41 版本中移除,建议使用 `generate` 方法或自定义生成循环代替。
    def greedy_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        # 调用 `_greedy_search` 方法,并将所有传入的位置参数和关键字参数传递给它
        return self._greedy_search(*args, **kwargs)

    # 发出警告日志,提醒直接调用该方法已经被废弃,将在 v4.41 版本中移除,建议使用 `generate` 方法或自定义生成循环代替。
    def _greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ):
        # 方法实现略去,用于执行贪婪搜索算法或相关任务
        pass

    # 发出警告日志,提醒直接调用该方法已经被废弃,将在 v4.41 版本中移除,建议使用 `generate` 方法或自定义生成循环代替。
    def sample(self, *args, **kwargs):
        logger.warning_once(
            "Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        # 调用 `_sample` 方法,并将所有传入的位置参数和关键字参数传递给它
        return self._sample(*args, **kwargs)

    # 发出警告日志,提醒直接调用该方法已经被废弃,将在 v4.41 版本中移除,建议使用 `generate` 方法或自定义生成循环代替。
    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ):
        # 方法实现略去,用于执行采样或相关生成任务
        pass
    def _temporary_reorder_cache(self, past_key_values, beam_idx):
        """
        Temporary function to handle the different types of cache reordering processes while we roll out `Cache`.

        TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need
        for this function, with `Cache.reorder_cache` being the sole remaining code path
        """
        # 获取当前类名的小写形式
        model_class = self.__class__.__name__.lower()
        
        # 异常情况1:处理使用传统缓存格式的模型的代码路径
        if isinstance(past_key_values, (tuple, list)):
            past_key_values = self._reorder_cache(past_key_values, beam_idx)
        
        # 异常情况2:处理具有不同缓存格式的模型。这些模型目前仅限于 `DynamicCache`,直到它们的缓存格式标准化为止。
        elif "bloom" in model_class or "gptbigcode" in model_class:
            if not isinstance(past_key_values, DynamicCache):
                raise ValueError(
                    f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
                    "legacy tuple format or `DynamicCache`"
                )
            past_key_values = self._reorder_cache(past_key_values, beam_idx)
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        
        # 标准代码路径:使用 `Cache.reorder_cache`
        else:
            past_key_values.reorder_cache(beam_idx)
        
        return past_key_values

    def beam_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        return self._beam_search(*args, **kwargs)

    def _beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        sequential: Optional[bool] = None,
        **model_kwargs,
    ):
        """
        Perform beam search to generate sequences based on input_ids and beam_scorer.
        """
        logger.warning_once(
            "Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        return self._beam_search(*args, **kwargs)

    def beam_sample(self, *args, **kwargs):
        logger.warning_once(
            "Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        return self._beam_sample(*args, **kwargs)
    # 定义一个私有方法 `_beam_sample`,用于执行束搜索采样
    def _beam_sample(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ):
        # 具体功能的注释可以在方法内部详细描述
        pass

    # 警告用户 `group_beam_search` 方法即将在 v4.41 版本中移除,建议使用 `generate` 方法或自定义生成循环
    def group_beam_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        # 调用 `_group_beam_search` 方法来执行实际的束搜索操作
        return self._group_beam_search(*args, **kwargs)

    # 定义一个私有方法 `_group_beam_search`,用于执行束搜索
    def _group_beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        **model_kwargs,
    ):
        # 具体功能的注释可以在方法内部详细描述
        pass

    # 警告用户 `constrained_beam_search` 方法即将在 v4.41 版本中移除,建议使用 `generate` 方法或自定义生成循环
    def constrained_beam_search(self, *args, **kwargs):
        logger.warning_once(
            "Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
            "custom generation loop instead.",
        )
        # 调用 `_constrained_beam_search` 方法来执行实际的约束束搜索操作
        return self._constrained_beam_search(*args, **kwargs)

    # 定义一个私有方法 `_constrained_beam_search`,用于执行约束束搜索
    def _constrained_beam_search(
        self,
        input_ids: torch.LongTensor,
        constrained_beam_scorer: ConstrainedBeamSearchScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        output_logits: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = None,
        **model_kwargs,
    ):
        # 具体功能的注释可以在方法内部详细描述
        pass
    # 发出警告日志,提醒直接调用 `_assisted_decoding` 方法已不推荐,在 v4.41 版本中将被移除。建议使用 `generate` 方法或自定义生成循环。
    logger.warning_once(
        "Calling `_assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
        "custom generation loop instead.",
    )
    # 调用 `_assisted_decoding` 方法,将所有传入的位置参数和关键字参数传递给它,并返回其结果。
    return self._assisted_decoding(*args, **kwargs)
def _speculative_sampling(
    candidate_input_ids,
    candidate_logits,
    candidate_length,
    new_logits,
    last_assistant_token_is_eos,
    max_matches,
):
    """
    Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
    the selected tokens, as well as the number of candidate matches.

    NOTE: Unless otherwise stated, the variable names match those in the paper.
    """
    # Selects the last `candidate_length` tokens from `candidate_input_ids`
    new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]

    # Converts logits to probabilities and extracts assistant (q_i) and model (p_i) probabilities for selected tokens
    q = candidate_logits.softmax(dim=-1)
    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    p = new_logits.softmax(dim=-1)
    p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    probability_ratio = p_i / q_i

    # Determines which tokens to accept based on probability ratios
    r_i = torch.rand_like(probability_ratio)
    is_accepted = r_i <= probability_ratio

    # Computes the number of accepted tokens (`n_matches` in algorithm 1)
    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()

    # Ensures the generated sequence does not exceed `max_matches` or end with an EOS token
    if last_assistant_token_is_eos and n_matches == candidate_length:
        # Adjusts `n_matches` if the sequence ends with an EOS token
        n_matches -= 1
        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
    else:
        n_matches = min(n_matches, max_matches)

        # Selects the next token considering rejection and adjusts probabilities if needed
        gamma = min(candidate_logits.shape[1], max_matches)
        p_n_plus_1 = p[:, n_matches, :]
        if n_matches < gamma:
            q_n_plus_1 = q[:, n_matches, :]
            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
            p_prime.div_(p_prime.sum())
        else:
            p_prime = p_n_plus_1
        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

        # Constructs the final sequence of valid tokens
        if n_matches > 0:
            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
        else:
            valid_tokens = t

    return valid_tokens, n_matches
    # 给定多个生成的标记的解码器注意力或隐藏状态,将其拆分成一个元组,其中每个成员对应于单个生成的标记。
    """
    # 兼容性调整:在我们的生成函数中,第一次迭代包含了关于提示的注意力/隐藏状态。
    if len(outputs) == 0:
        # 初始化一个空的元组
        new_tuple = ()
        # 遍历新输出的每一层
        for layer in new_outputs:
            # 如果是解码器的注意力,使用当前长度和最后一维的大小;否则使用整个层的大小
            last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
            # 将当前层的片段添加到新元组中
            new_tuple += (layer[..., :cur_len, :last_dim_size],)
        # 将新元组作为一个元素添加到输出元组中
        outputs += (new_tuple,)
        # 更新当前长度变量,因为第一次迭代包含了提示 + 1个生成的标记
        cur_len += 1
        # 更新添加的长度变量
        added_len -= cur_len
    
    # 对于每个额外添加的长度
    for i in range(added_len):
        # 初始化一个空的元组
        new_tuple = ()
        # 遍历新输出的每一层
        for layer in new_outputs:
            # 如果是解码器的注意力,使用当前长度加上i和最后一维的大小;否则使用整个层的大小
            last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
            # 将当前层的片段添加到新元组中
            new_tuple += (layer[..., i : i + 1, :last_dim_size],)
        # 将新元组作为一个元素添加到输出元组中
        outputs += (new_tuple,)
    # 返回输出元组
    return outputs
# 根据上下文隐藏状态的每个向量的L2范数归一化,使其长度为1,以便计算余弦相似度
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)

# 根据下一个隐藏状态的每个向量的L2范数归一化,使其长度为1,以便计算余弦相似度
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)

# 计算上下文隐藏状态与下一个隐藏状态之间的余弦相似度矩阵,将维度调整为[B*K, S]
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)

# 在余弦相似度矩阵的最后一个维度上取最大值,得到每个样本的最大相似度,形状为[B*K]
degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)

# 将下一个顶部K个候选项的概率视图调整为一维数组,形状为[B*K]
next_top_k_probs = next_top_k_probs.view(-1)

# 计算对比分数,根据论文中的对比框架计算每个候选项的分数
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty

# 将对比分数按照beam_width分割,形状调整为[B, K]的张量
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))

# 在每行中选择最高分数对应的索引,形状为[B]
_, selected_idx = contrastive_score.max(dim=-1)

# 返回每个批次中最佳候选项的索引
return selected_idx



# 处理数据分割的函数,根据数据类型分别处理不同情况的数据分割
def _split(data, full_batch_size: int, split_size: int = None):
    if data is None:
        # 如果数据为None,则返回与分割大小对应的None列表
        return [None] * (full_batch_size // split_size)
    if isinstance(data, torch.Tensor):
        # 如果数据为Tensor,则按照分割大小分割Tensor,返回Tensor列表
        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
    elif isinstance(data, tuple):
        # 如果数据为元组,根据元组中元素的类型进行不同的分割处理
        if isinstance(data[0], tuple):
            # 如果元组中的元素也是元组,则按照分割大小分割每个元组中的Tensor,返回元组列表的元组列表
            return [
                tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
                for i in range(0, full_batch_size, split_size)
            ]
        else:
            # 如果元组中的元素不是元组,则按照分割大小分割每个Tensor,返回元组列表
            return [
                tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
                for i in range(0, full_batch_size, split_size)
            ]
    else:
        # 如果数据类型不符合预期,则引发值错误异常
        raise ValueError(f"Unexpected attribute type: {type(data)}")



# 将模型输入(可能是ModelOutput或Dict类型)按照指定的分割大小拆分成相同类型的对象列表
def _split_model_inputs(
    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
) -> List[Union[ModelOutput, Dict]]:
    """
    Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
    size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
    previous forward pass.
    """
    # 如果 model_input 为 None,则返回一个 Nones 列表
    # 在 Whisper 中,encoder_outputs 为 None 时会发生这种情况
    if model_input is None:
        return [model_input] * (full_batch_size // split_size)
    # 从对象中推断出类
    model_output_cls = type(model_input)
    if (full_batch_size % split_size) != 0:
        # 如果 full_batch_size 不能被 split_size 整除,则引发 ValueError
        raise ValueError("`full_batch_size` must be divisible by `split_size`")

    if split_size > full_batch_size:
        # 如果 split_size 大于 full_batch_size,则引发 ValueError
        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")

    # 用于拆分张量或张量的元组的辅助函数

    # 查找所有数据类字段(例如,last_hidden_state,pooler_output 等),并对它们进行拆分
    keys = (
        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
    )
    # 仅保留在 model_input 中的键
    keys = [k for k in keys if k in model_input]
    # 在这里,我们可以有四种类型的值:张量、张量的元组和布尔值,以及 encoder_outputs,后者是一个 ModelOutput 对象。
    # 布尔值不应该被拆分,而应该为每个拆分复制
    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
    keys_to_ignore = ["cache_position", "encoder_outputs"]
    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

    # 拆分张量和张量的元组
    data_split_list = [
        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
        for i in range(full_batch_size // split_size)
    ]
    # 布尔值是相同的,每个拆分中都会复制
    bool_data = {k: model_input[k] for k in bool_keys}
    # encoder_outputs 是一个 ModelOutput 对象,应该单独拆分
    if "encoder_outputs" in model_input:
        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
        data_split_list = [
            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
        ]

    # 将列表中的每个字典转换为推断类的对象
    split_model_inputs: List[Union[ModelOutput, Dict]] = [
        model_output_cls(**data_split, **bool_data) for data_split in data_split_list
    ]

    return split_model_inputs
# 将给定的 ModelOutput 对象列表沿着 batch_size 维度堆叠起来。该函数推断出列表中的具体 ModelOutput 子类。
def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
    """
    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
    specific ModelOutput subclass from the list provided.
    """
    # 如果输入的列表为空,则抛出数值错误
    if not model_outputs:
        raise ValueError("Input list is empty.")

    # 推断出列表中第一个对象的类
    model_output_cls = type(model_outputs[0])

    # 确保所有对象都是同一类型
    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
        raise ValueError("All elements in the list should be of the same type.")

    # 辅助函数,用于连接张量或张量元组
    def _concat(data):
        """
        Reverse of `_split` function above.
        """
        # 如果数据中任意元素为 None,则返回 None
        if any(data is None for data in data):
            return None
        # 如果第一个元素是 torch.Tensor
        if isinstance(data[0], torch.Tensor):
            # 沿着 dim=0 连接所有张量
            return torch.cat(data, dim=0)
        # 如果第一个元素是元组
        elif isinstance(data[0], tuple):
            # 如果元组的元素也是元组(例如我们之前的示例中的 past_key_values)
            if isinstance(data[0][0], tuple):
                # 对每个元组的每个元素,沿着 dim=0 连接所有张量
                return tuple(
                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                    for i in range(len(data[0]))
                )
            else:
                # 否则,对元组中的每个元素,沿着 dim=0 连接所有张量
                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
        # 如果第一个元素是整数或浮点数,返回一个张量
        elif isinstance(data[0], (int, float)):
            return torch.tensor(data)
        else:
            # 抛出数值错误,显示意外的属性类型
            raise ValueError(f"Unexpected attribute type: {type(data[0])}")

    # 使用字典推导式,从所有对象中收集属性并连接它们
    concatenated_data = {
        # 对于每个属性 k,在所有模型输出对象中,获取属性 k 的值并连接它们
        k: _concat([getattr(model_output, k) for model_output in model_outputs])
        for k in model_output_cls.__dataclass_fields__.keys()
    }

    # 返回一个新的推断类对象,其中包含连接后的属性
    return model_output_cls(**concatenated_data)

.\generation\__init__.py

# 引入类型检查模块的条件语句,用于确定当前环境是否支持类型检查
from typing import TYPE_CHECKING

# 引入必要的依赖和模块,用于检查和延迟加载
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available

# 定义需要导入的模块结构字典
_import_structure = {
    "configuration_utils": ["GenerationConfig", "GenerationMode"],  # 配置工具模块
    "streamers": ["TextIteratorStreamer", "TextStreamer"],  # 数据流处理模块
}

# 尝试导入 torch 模块,如果不可用则抛出 OptionalDependencyNotAvailable 异常
try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # 如果 torch 可用,则添加以下模块到导入结构字典
    _import_structure["beam_constraints"] = [
        "Constraint",  # 约束条件模块
        "ConstraintListState",  # 约束条件列表状态模块
        "DisjunctiveConstraint",  # 分离约束模块
        "PhrasalConstraint",  # 短语约束模块
    ]
    _import_structure["beam_search"] = [
        "BeamHypotheses",  # 搜索假设模块
        "BeamScorer",  # 搜索评分器模块
        "BeamSearchScorer",  # 搜索评分器模块
        "ConstrainedBeamSearchScorer",  # 约束搜索评分器模块
    ]
    _import_structure["candidate_generator"] = [
        "AssistedCandidateGenerator",  # 候选生成辅助模块
        "CandidateGenerator",  # 候选生成器模块
        "PromptLookupCandidateGenerator",  # 提示查找候选生成器模块
    ]
    _import_structure["logits_process"] = [
        "AlternatingCodebooksLogitsProcessor",  # 替换码本逻辑处理器模块
        "ClassifierFreeGuidanceLogitsProcessor",  # 免分类器引导逻辑处理器模块
        "EncoderNoRepeatNGramLogitsProcessor",  # 编码器无重复 n-gram 逻辑处理器模块
        "EncoderRepetitionPenaltyLogitsProcessor",  # 编码器重复惩罚逻辑处理器模块
        "EpsilonLogitsWarper",  # Epsilon 逻辑扭曲器模块
        "EtaLogitsWarper",  # Eta 逻辑扭曲器模块
        "ExponentialDecayLengthPenalty",  # 指数衰减长度惩罚模块
        "ForcedBOSTokenLogitsProcessor",  # 强制 BOS 标记逻辑处理器模块
        "ForcedEOSTokenLogitsProcessor",  # 强制 EOS 标记逻辑处理器模块
        "ForceTokensLogitsProcessor",  # 强制令牌逻辑处理器模块
        "HammingDiversityLogitsProcessor",  # 汉明多样性逻辑处理器模块
        "InfNanRemoveLogitsProcessor",  # 无穷大和无效值移除逻辑处理器模块
        "LogitNormalization",  # Logit 归一化模块
        "LogitsProcessor",  # Logits 处理器模块
        "LogitsProcessorList",  # Logits 处理器列表模块
        "LogitsWarper",  # Logits 扭曲器模块
        "MinLengthLogitsProcessor",  # 最小长度逻辑处理器模块
        "MinNewTokensLengthLogitsProcessor",  # 最小新令牌长度逻辑处理器模块
        "NoBadWordsLogitsProcessor",  # 无不良词语逻辑处理器模块
        "NoRepeatNGramLogitsProcessor",  # 无重复 n-gram 逻辑处理器模块
        "PrefixConstrainedLogitsProcessor",  # 前缀约束逻辑处理器模块
        "RepetitionPenaltyLogitsProcessor",  # 重复惩罚逻辑处理器模块
        "SequenceBiasLogitsProcessor",  # 序列偏置逻辑处理器模块
        "SuppressTokensLogitsProcessor",  # 抑制令牌逻辑处理器模块
        "SuppressTokensAtBeginLogitsProcessor",  # 在开头抑制令牌逻辑处理器模块
        "TemperatureLogitsWarper",  # 温度逻辑扭曲器模块
        "TopKLogitsWarper",  # Top-K 逻辑扭曲器模块
        "TopPLogitsWarper",  # Top-P 逻辑扭曲器模块
        "TypicalLogitsWarper",  # 典型逻辑扭曲器模块
        "UnbatchedClassifierFreeGuidanceLogitsProcessor",  # 未分批免分类器引导逻辑处理器模块
        "WhisperTimeStampLogitsProcessor",  # Whisper 时间戳逻辑处理器模块
    ]
    # 将停止条件模块的类名列表添加到_import_structure字典中的"stopping_criteria"键下
    _import_structure["stopping_criteria"] = [
        "MaxNewTokensCriteria",  # 最大新标记数条件
        "MaxLengthCriteria",  # 最大长度条件
        "MaxTimeCriteria",  # 最大时间条件
        "StoppingCriteria",  # 停止条件基类
        "StoppingCriteriaList",  # 停止条件列表
        "validate_stopping_criteria",  # 验证停止条件函数
    ]
    
    # 将实用工具模块的类名列表添加到_import_structure字典中的"utils"键下
    _import_structure["utils"] = [
        "GenerationMixin",  # 生成混合类
        "GreedySearchEncoderDecoderOutput",  # 贪婪搜索编码器解码器输出
        "GreedySearchDecoderOnlyOutput",  # 贪婪搜索仅解码器输出
        "SampleEncoderDecoderOutput",  # 样本编码器解码器输出
        "SampleDecoderOnlyOutput",  # 样本仅解码器输出
        "BeamSearchEncoderDecoderOutput",  # Beam搜索编码器解码器输出
        "BeamSearchDecoderOnlyOutput",  # Beam搜索仅解码器输出
        "BeamSampleEncoderDecoderOutput",  # Beam样本编码器解码器输出
        "BeamSampleDecoderOnlyOutput",  # Beam样本仅解码器输出
        "ContrastiveSearchEncoderDecoderOutput",  # 对比搜索编码器解码器输出
        "ContrastiveSearchDecoderOnlyOutput",  # 对比搜索仅解码器输出
        "GenerateBeamDecoderOnlyOutput",  # 生成Beam解码器输出
        "GenerateBeamEncoderDecoderOutput",  # 生成Beam编码器解码器输出
        "GenerateDecoderOnlyOutput",  # 生成仅解码器输出
        "GenerateEncoderDecoderOutput",  # 生成编码器解码器输出
    ]
# 尝试检查是否可以使用 TensorFlow 库
try:
    # 如果 TensorFlow 不可用,引发 OptionalDependencyNotAvailable 异常
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果引发了 OptionalDependencyNotAvailable 异常,什么都不做,继续执行下一个代码块
    pass
else:
    # 如果没有引发异常,则将以下 TensorFlow 相关类添加到 _import_structure 字典中
    _import_structure["tf_logits_process"] = [
        "TFForcedBOSTokenLogitsProcessor",
        "TFForcedEOSTokenLogitsProcessor",
        "TFForceTokensLogitsProcessor",
        "TFLogitsProcessor",
        "TFLogitsProcessorList",
        "TFLogitsWarper",
        "TFMinLengthLogitsProcessor",
        "TFNoBadWordsLogitsProcessor",
        "TFNoRepeatNGramLogitsProcessor",
        "TFRepetitionPenaltyLogitsProcessor",
        "TFSuppressTokensAtBeginLogitsProcessor",
        "TFSuppressTokensLogitsProcessor",
        "TFTemperatureLogitsWarper",
        "TFTopKLogitsWarper",
        "TFTopPLogitsWarper",
    ]
    _import_structure["tf_utils"] = [
        "TFGenerationMixin",
        "TFGreedySearchDecoderOnlyOutput",
        "TFGreedySearchEncoderDecoderOutput",
        "TFSampleEncoderDecoderOutput",
        "TFSampleDecoderOnlyOutput",
        "TFBeamSearchEncoderDecoderOutput",
        "TFBeamSearchDecoderOnlyOutput",
        "TFBeamSampleEncoderDecoderOutput",
        "TFBeamSampleDecoderOnlyOutput",
        "TFContrastiveSearchEncoderDecoderOutput",
        "TFContrastiveSearchDecoderOnlyOutput",
    ]

# 尝试检查是否可以使用 Flax 库
try:
    # 如果 Flax 不可用,引发 OptionalDependencyNotAvailable 异常
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # 如果引发了 OptionalDependencyNotAvailable 异常,什么都不做,继续执行下一个代码块
    pass
else:
    # 如果没有引发异常,则将以下 Flax 相关类添加到 _import_structure 字典中
    _import_structure["flax_logits_process"] = [
        "FlaxForcedBOSTokenLogitsProcessor",
        "FlaxForcedEOSTokenLogitsProcessor",
        "FlaxForceTokensLogitsProcessor",
        "FlaxLogitsProcessor",
        "FlaxLogitsProcessorList",
        "FlaxLogitsWarper",
        "FlaxMinLengthLogitsProcessor",
        "FlaxSuppressTokensAtBeginLogitsProcessor",
        "FlaxSuppressTokensLogitsProcessor",
        "FlaxTemperatureLogitsWarper",
        "FlaxTopKLogitsWarper",
        "FlaxTopPLogitsWarper",
        "FlaxWhisperTimeStampLogitsProcessor",
    ]
    _import_structure["flax_utils"] = [
        "FlaxGenerationMixin",
        "FlaxGreedySearchOutput",
        "FlaxSampleOutput",
        "FlaxBeamSearchOutput",
    ]

# 如果在类型检查模式下
if TYPE_CHECKING:
    # 从相关模块导入特定类和函数
    from .configuration_utils import GenerationConfig, GenerationMode
    from .streamers import TextIteratorStreamer, TextStreamer
    
    try:
        # 如果 Torch 不可用,引发 OptionalDependencyNotAvailable 异常
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果引发了 OptionalDependencyNotAvailable 异常,什么都不做
        pass
    # 否则,从本地的beam_constraints模块中导入多个约束类和对象
    from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
    # 从本地的beam_search模块中导入多个与beam搜索相关的类和对象
    from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
    # 从本地的candidate_generator模块中导入多个候选生成器类和对象
    from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
    # 从本地的logits_process模块中导入多个logits处理类和对象
    from .logits_process import (
        AlternatingCodebooksLogitsProcessor,               # 处理交替码簿的logits处理器
        ClassifierFreeGuidanceLogitsProcessor,            # 无分类器指导的logits处理器
        EncoderNoRepeatNGramLogitsProcessor,              # 编码器不重复n-gram的logits处理器
        EncoderRepetitionPenaltyLogitsProcessor,         # 编码器重复惩罚的logits处理器
        EpsilonLogitsWarper,                             # Epsilon的logits调节器
        EtaLogitsWarper,                                 # Eta的logits调节器
        ExponentialDecayLengthPenalty,                    # 指数衰减长度惩罚
        ForcedBOSTokenLogitsProcessor,                   # 强制BOS标记的logits处理器
        ForcedEOSTokenLogitsProcessor,                   # 强制EOS标记的logits处理器
        ForceTokensLogitsProcessor,                      # 强制token的logits处理器
        HammingDiversityLogitsProcessor,                 # Hamming多样性的logits处理器
        InfNanRemoveLogitsProcessor,                     # 移除Inf和NaN的logits处理器
        LogitNormalization,                              # logits归一化处理器
        LogitsProcessor,                                 # logits处理器基类
        LogitsProcessorList,                             # logits处理器列表
        LogitsWarper,                                    # logits调节器基类
        MinLengthLogitsProcessor,                        # 最小长度的logits处理器
        MinNewTokensLengthLogitsProcessor,               # 最小新token长度的logits处理器
        NoBadWordsLogitsProcessor,                       # 无不良词语的logits处理器
        NoRepeatNGramLogitsProcessor,                    # 不重复n-gram的logits处理器
        PrefixConstrainedLogitsProcessor,                # 前缀约束的logits处理器
        RepetitionPenaltyLogitsProcessor,                # 重复惩罚的logits处理器
        SequenceBiasLogitsProcessor,                     # 序列偏置的logits处理器
        SuppressTokensAtBeginLogitsProcessor,            # 在开头抑制token的logits处理器
        SuppressTokensLogitsProcessor,                   # 抑制token的logits处理器
        TemperatureLogitsWarper,                         # 温度的logits调节器
        TopKLogitsWarper,                                # 前K个logits调节器
        TopPLogitsWarper,                                # Top-P的logits调节器
        TypicalLogitsWarper,                             # 典型的logits调节器
        UnbatchedClassifierFreeGuidanceLogitsProcessor,  # 无批处理分类器指导的logits处理器
        WhisperTimeStampLogitsProcessor,                 # Whisper时间戳的logits处理器
    )
    # 从本地的stopping_criteria模块中导入多个停止标准类和函数
    from .stopping_criteria import (
        MaxLengthCriteria,          # 最大长度标准
        MaxNewTokensCriteria,       # 最大新token标准
        MaxTimeCriteria,            # 最大时间标准
        StoppingCriteria,           # 停止标准基类
        StoppingCriteriaList,       # 停止标准列表
        validate_stopping_criteria, # 验证停止标准的函数
    )
    # 从本地的utils模块中导入多个实用类和对象,用于不同类型的解码和编码输出
    from .utils import (
        BeamSampleDecoderOnlyOutput,                    # Beam采样仅解码器输出
        BeamSampleEncoderDecoderOutput,                 # Beam采样编码器-解码器输出
        BeamSearchDecoderOnlyOutput,                    # Beam搜索仅解码器输出
        BeamSearchEncoderDecoderOutput,                 # Beam搜索编码器-解码器输出
        ContrastiveSearchDecoderOnlyOutput,             # 对比搜索仅解码器输出
        ContrastiveSearchEncoderDecoderOutput,          # 对比搜索编码器-解码器输出
        GenerateBeamDecoderOnlyOutput,                  # 生成Beam仅解码器输出
        GenerateBeamEncoderDecoderOutput,               # 生成Beam编码器-解码器输出
        GenerateDecoderOnlyOutput,                      # 生成仅解码器输出
        GenerateEncoderDecoderOutput,                   # 生成编码器-解码器输出
        GenerationMixin,                               # 生成Mixin
        GreedySearchDecoderOnlyOutput,                  # 贪婪搜索仅解码器输出
        GreedySearchEncoderDecoderOutput,               # 贪婪搜索编码器-解码器输出
        SampleDecoderOnlyOutput,                        # 采样仅解码器输出
        SampleEncoderDecoderOutput,                     # 采样编码器-解码器输出
    )

    # 尝试检查是否存在TensorFlow依赖,若不存在则引发OptionalDependencyNotAvailable异常
    try:
        if not is_tf_available():  # 如果TensorFlow不可用
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:  # 捕获OptionalDependencyNotAvailable异常
        pass  # 什么都不做,继续执行后续代码
    else:
        # 导入针对 TensorFlow 的 logits 处理模块
        from .tf_logits_process import (
            TFForcedBOSTokenLogitsProcessor,         # 强制开头标记的 logits 处理器
            TFForcedEOSTokenLogitsProcessor,         # 强制结尾标记的 logits 处理器
            TFForceTokensLogitsProcessor,            # 强制标记的 logits 处理器
            TFLogitsProcessor,                       # logits 处理器基类
            TFLogitsProcessorList,                   # logits 处理器列表
            TFLogitsWarper,                          # logits 调整器
            TFMinLengthLogitsProcessor,              # 最小长度的 logits 处理器
            TFNoBadWordsLogitsProcessor,             # 无不良词语的 logits 处理器
            TFNoRepeatNGramLogitsProcessor,          # 无重复 n-gram 的 logits 处理器
            TFRepetitionPenaltyLogitsProcessor,      # 重复惩罚的 logits 处理器
            TFSuppressTokensAtBeginLogitsProcessor,  # 开头抑制标记的 logits 处理器
            TFSuppressTokensLogitsProcessor,         # 抑制标记的 logits 处理器
            TFTemperatureLogitsWarper,               # 温度调整器
            TFTopKLogitsWarper,                      # 基于 top-k 的 logits 调整器
            TFTopPLogitsWarper,                      # 基于 top-p 的 logits 调整器
        )
        # 导入针对 TensorFlow 的实用工具模块
        from .tf_utils import (
            TFBeamSampleDecoderOnlyOutput,           # 仅解码器的 Beam Sample 输出
            TFBeamSampleEncoderDecoderOutput,        # 编码器解码器的 Beam Sample 输出
            TFBeamSearchDecoderOnlyOutput,           # 仅解码器的 Beam Search 输出
            TFBeamSearchEncoderDecoderOutput,        # 编码器解码器的 Beam Search 输出
            TFContrastiveSearchDecoderOnlyOutput,    # 仅解码器的对比搜索输出
            TFContrastiveSearchEncoderDecoderOutput, # 编码器解码器的对比搜索输出
            TFGenerationMixin,                       # 生成混合类
            TFGreedySearchDecoderOnlyOutput,         # 仅解码器的 Greedy Search 输出
            TFGreedySearchEncoderDecoderOutput,      # 编码器解码器的 Greedy Search 输出
            TFSampleDecoderOnlyOutput,               # 仅解码器的 Sample 输出
            TFSampleEncoderDecoderOutput,            # 编码器解码器的 Sample 输出
        )

    try:
        # 检查 Flax 是否可用,如果不可用则抛出异常
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        # 如果 Flax 不可用,忽略该异常
        pass
    else:
        # 导入针对 Flax 的 logits 处理模块
        from .flax_logits_process import (
            FlaxForcedBOSTokenLogitsProcessor,           # 强制开头标记的 logits 处理器
            FlaxForcedEOSTokenLogitsProcessor,           # 强制结尾标记的 logits 处理器
            FlaxForceTokensLogitsProcessor,              # 强制标记的 logits 处理器
            FlaxLogitsProcessor,                         # logits 处理器基类
            FlaxLogitsProcessorList,                     # logits 处理器列表
            FlaxLogitsWarper,                            # logits 调整器
            FlaxMinLengthLogitsProcessor,                # 最小长度的 logits 处理器
            FlaxSuppressTokensAtBeginLogitsProcessor,    # 开头抑制标记的 logits 处理器
            FlaxSuppressTokensLogitsProcessor,           # 抑制标记的 logits 处理器
            FlaxTemperatureLogitsWarper,                 # 温度调整器
            FlaxTopKLogitsWarper,                        # 基于 top-k 的 logits 调整器
            FlaxTopPLogitsWarper,                        # 基于 top-p 的 logits 调整器
            FlaxWhisperTimeStampLogitsProcessor,         # Whisper 时间戳的 logits 处理器
        )
        # 导入针对 Flax 的实用工具模块
        from .flax_utils import (
            FlaxBeamSearchOutput,                       # Flax Beam Search 输出
            FlaxGenerationMixin,                        # 生成混合类
            FlaxGreedySearchOutput,                     # Flax Greedy Search 输出
            FlaxSampleOutput,                           # Flax Sample 输出
        )
else:
    # 导入 sys 模块,用于动态设置当前模块的引用
    import sys
    # 设置当前模块的引用,将其指向 _LazyModule 实例化的对象
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

.\generation_flax_utils.py

# 设置编码格式为UTF-8,确保脚本可以正确处理各种字符集
# 版权声明,指出代码的版权归属及使用限制
# 版权声明,版权属于Google AI Flax团队和HuggingFace Inc.团队,以及NVIDIA CORPORATION
#
# 根据Apache License, Version 2.0的许可证,除非符合许可证条款,否则不得使用此文件
# 可以在以下链接获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,否则不得将此软件分发
# 此软件按“原样”提供,没有任何明示或暗示的担保或条件
# 请参阅许可证以了解具体的使用条款和限制

# 导入警告模块,用于发出警告信息
import warnings

# 从.generation模块中导入FlaxGenerationMixin类
from .generation import FlaxGenerationMixin

# 定义一个名为FlaxGenerationMixin的类,继承自FlaxGenerationMixin类
class FlaxGenerationMixin(FlaxGenerationMixin):
    # 在导入时发出警告信息,提醒该导入方式即将被弃用
    warnings.warn(
        "Importing `FlaxGenerationMixin` from `src/transformers/generation_flax_utils.py` is deprecated and will "
        "be removed in Transformers v4.40. Import as `from transformers import FlaxGenerationMixin` instead.",
        FutureWarning,
    )

.\generation_tf_utils.py

# 设置文件编码为 UTF-8
# 版权声明,版权归谷歌 AI 语言团队和 HuggingFace 公司所有,以及 NVIDIA 公司所有
# 根据 Apache 许可证 2.0 版本,可以在遵守许可证的前提下使用本文件
# 可以在以下网址获取许可证的副本:http://www.apache.org/licenses/LICENSE-2.0
# 如果不符合适用法律或书面协议的要求,本软件按 "原样" 分发,没有任何形式的担保或条件
# 有关详细信息,请参阅许可证

# 导入警告模块
import warnings

# 从指定模块导入 TFGenerationMixin 类
# 这里出现了一个命名冲突,因为在当前作用域中的 TFGenerationMixin 已经存在
# 为了避免冲突,应该考虑重命名或者避免同名导入

# 创建 TFGenerationMixin 的子类,警告在导入时显示
# 警告提示,从 'src/transformers/generation_tf_utils.py' 导入 'TFGenerationMixin' 已经被弃用
# 在 Transformers v4.40 中将会移除,建议改为从 'transformers' 直接导入 'TFGenerationMixin'
# 使用 FutureWarning 类型显示警告信息
warnings.warn(
    "Importing `TFGenerationMixin` from `src/transformers/generation_tf_utils.py` is deprecated and will "
    "be removed in Transformers v4.40. Import as `from transformers import TFGenerationMixin` instead.",
    FutureWarning,
)

.\generation_utils.py

# 导入警告模块
import warnings

# 从generation模块中导入GenerationMixin类
from .generation import GenerationMixin

# 定义GenerationMixin类,继承自GenerationMixin类
class GenerationMixin(GenerationMixin):
    # 在导入时发出警告,提示正在从旧路径导入GenerationMixin,该功能将在未来版本中移除
    warnings.warn(
        "Importing `GenerationMixin` from `src/transformers/generation_utils.py` is deprecated and will "
        "be removed in Transformers v4.40. Import as `from transformers import GenerationMixin` instead.",
        FutureWarning,
    )

.\hf_argparser.py

# 版权声明及许可信息
#
# 在 Apache 许可证 2.0 版本下使用此文件的声明,表示除非符合许可证,否则不得使用此文件。
# 您可以在以下网址获得许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据“原样”分发的软件是根据许可证分发的,
# 没有任何形式的明示或暗示担保或条件。
# 有关更多详细信息,请参阅许可证。
#

import dataclasses  # 导入 dataclasses 模块
import json  # 导入 json 模块
import sys  # 导入 sys 模块
import types  # 导入 types 模块
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError  # 从 argparse 模块导入指定内容
from copy import copy  # 导入 copy 函数
from enum import Enum  # 导入 Enum 类
from inspect import isclass  # 导入 isclass 函数
from pathlib import Path  # 导入 Path 类
from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints  # 导入 typing 模块中指定内容

import yaml  # 导入 yaml 模块


DataClass = NewType("DataClass", Any)  # 定义 DataClass 类型别名
DataClassType = NewType("DataClassType", Any)  # 定义 DataClassType 类型别名


def string_to_bool(v):
    """
    解析字符串表示的布尔值。

    Args:
        v (str): 输入的字符串值。

    Returns:
        bool: 如果字符串表示真值,则返回 True;否则返回 False。

    Raises:
        ArgumentTypeError: 如果无法解析字符串为布尔值,抛出异常。
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise ArgumentTypeError(
            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
        )


def make_choice_type_function(choices: list) -> Callable[[str], Any]:
    """
    创建从每个选择字符串表示到实际值的映射函数。用于支持单个参数的多个值类型。

    Args:
        choices (list): 选择列表。

    Returns:
        Callable[[str], Any]: 从字符串表示到每个选择的实际值的映射函数。
    """
    str_to_choice = {str(choice): choice for choice in choices}
    return lambda arg: str_to_choice.get(arg, arg)


def HfArg(
    *,
    aliases: Union[str, List[str]] = None,
    help: str = None,
    default: Any = dataclasses.MISSING,
    default_factory: Callable[[], Any] = dataclasses.MISSING,
    metadata: dict = None,
    **kwargs,
) -> dataclasses.Field:
    """
    参数辅助函数,允许使用简洁的语法为 `HfArgumentParser` 创建数据类字段。

    Example comparing the use of `HfArg` and `dataclasses.field`:
    示例比较了 `HfArg` 和 `dataclasses.field` 的使用:
    ```
    @dataclass
    class Args:
        regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
        hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
    ```
    """
    pass  # HfArg 函数主体为空,实现在示例中展示
    def make_field(aliases=None, help=None, default=dataclasses.MISSING, default_factory=dataclasses.MISSING, metadata=None, **kwargs):
        """
        Construct a `dataclasses.Field` object with specified properties.
    
        Args:
            aliases (Union[str, List[str]], optional):
                Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
                Defaults to None.
            help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
            default (Any, optional):
                Default value for the argument. If not default or default_factory is specified, the argument is required.
                Defaults to dataclasses.MISSING.
            default_factory (Callable[[], Any], optional):
                The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
                default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
                Defaults to dataclasses.MISSING.
            metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
    
        Returns:
            Field: A `dataclasses.Field` with the desired properties.
        """
        if metadata is None:
            # 如果 metadata 参数为 None,则创建一个空的字典,以避免在函数签名中使用默认参数,因为字典是可变的且在函数调用间共享
            metadata = {}
        if aliases is not None:
            # 如果传入了 aliases 参数,则将其添加到 metadata 字典中
            metadata["aliases"] = aliases
        if help is not None:
            # 如果传入了 help 参数,则将其添加到 metadata 字典中
            metadata["help"] = help
    
        # 创建并返回一个 `dataclasses.Field` 对象,传入指定的参数和 metadata 字典
        return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
# 定义一个名为 HfArgumentParser 的类,它是 argparse.ArgumentParser 的子类
class HfArgumentParser(ArgumentParser):
    """
    This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.

    The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
    arguments to the parser after initialization and you'll get the output back after parsing as an additional
    namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
    """

    # 定义一个名为 dataclass_types 的实例变量,用来存储数据类类型的可迭代对象
    dataclass_types: Iterable[DataClassType]

    # 初始化方法,接收 dataclass_types 和其他参数
    def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
        """
        Args:
            dataclass_types:
                Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
            kwargs (`Dict[str, Any]`, *optional*):
                Passed to `argparse.ArgumentParser()` in the regular way.
        """
        # 如果 kwargs 中没有指定 formatter_class,则设置为 ArgumentDefaultsHelpFormatter
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
        # 调用父类 ArgumentParser 的初始化方法,传入 kwargs
        super().__init__(**kwargs)
        # 如果 dataclass_types 是单个数据类而不是列表,则转换为列表
        if dataclasses.is_dataclass(dataclass_types):
            dataclass_types = [dataclass_types]
        # 将 dataclass_types 转换为列表后存储在 self.dataclass_types 中
        self.dataclass_types = list(dataclass_types)
        # 遍历每个数据类类型,并为其添加参数到 argparse.ArgumentParser 实例中
        for dtype in self.dataclass_types:
            self._add_dataclass_arguments(dtype)

    # 静态方法,用来添加数据类的参数到 argparse.ArgumentParser 实例中
    @staticmethod
    # 将数据类的参数添加到命令行解析器中
    def _add_dataclass_arguments(self, dtype: DataClassType):
        # 检查数据类是否定义了参数组名称,如果是,则创建一个新的参数组;否则,使用当前解析器
        if hasattr(dtype, "_argument_group_name"):
            parser = self.add_argument_group(dtype._argument_group_name)
        else:
            parser = self

        try:
            # 获取数据类字段的类型提示字典
            type_hints: Dict[str, type] = get_type_hints(dtype)
        except NameError:
            # 如果类型解析失败,通常是由于数据类不在全局范围内定义或使用了延迟注释的特性
            raise RuntimeError(
                f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
                "removing line of `from __future__ import annotations` which opts in Postponed "
                "Evaluation of Annotations (PEP 563)"
            )
        except TypeError as ex:
            # 当 Python 版本低于 3.10 且涉及到 union 类型时,给出详细的错误信息和建议
            if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
                python_version = ".".join(map(str, sys.version_info[:3]))
                raise RuntimeError(
                    f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
                    "line of `from __future__ import annotations` which opts in union types as "
                    "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
                    "support Python versions that lower than 3.10, you need to use "
                    "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
                    "`X | None`."
                ) from ex
            raise

        # 遍历数据类的字段,并解析每个需要初始化的字段
        for field in dataclasses.fields(dtype):
            if not field.init:
                continue  # 跳过不需要初始化的字段
            # 将字段的类型设定为从类型提示中获取的类型
            field.type = type_hints[field.name]
            # 调用私有方法,将数据类字段解析到命令行解析器中
            self._parse_dataclass_field(parser, field)

    # 解析命令行参数到数据类对象中
    def parse_args_into_dataclasses(
        self,
        args=None,
        return_remaining_strings=False,
        look_for_args_file=True,
        args_filename=None,
        args_file_flag=None,
    def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead uses a dict and populating the dataclass
        types.

        Args:
            args (`dict`):
                dict containing config values
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.

        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they were passed to the initializer.
        """
        # 获取所有传入参数字典的键集合
        unused_keys = set(args.keys())
        # 初始化空列表,用于存储解析后的数据类实例
        outputs = []
        # 遍历数据类类型列表
        for dtype in self.dataclass_types:
            # 获取数据类字段的名称集合,仅包括可以初始化的字段
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            # 从参数字典中选取与数据类字段匹配的键值对
            inputs = {k: v for k, v in args.items() if k in keys}
            # 从未使用的键集合中移除已使用的键
            unused_keys.difference_update(inputs.keys())
            # 使用选取的键值对初始化数据类对象
            obj = dtype(**inputs)
            # 将初始化后的数据类对象添加到输出列表
            outputs.append(obj)
        # 如果不允许额外的键存在且有未使用的键,抛出异常
        if not allow_extra_keys and unused_keys:
            raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
        # 返回包含所有数据类实例的元组
        return tuple(outputs)

    def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.

        Args:
            json_file (`str` or `os.PathLike`):
                File name of the json file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they were passed to the initializer.
        """
        # 打开并读取 JSON 文件
        with open(Path(json_file), encoding="utf-8") as open_json_file:
            data = json.loads(open_json_file.read())
        # 使用 parse_dict 方法解析 JSON 数据,并返回结果
        outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
        # 返回包含所有数据类实例的元组
        return tuple(outputs)
    # 定义一个方法用于解析 YAML 文件,并返回一个元组,其中包含数据类实例。
    def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file (`str` or `os.PathLike`):
                File name of the yaml file to parse
            allow_extra_keys (`bool`, *optional*, defaults to `False`):
                Defaults to False. If False, will raise an exception if the json file contains keys that are not
                parsed.

        Returns:
            Tuple consisting of:
                - the dataclass instances in the same order as they were passed to the initializer.
        """
        # 使用 pathlib 读取 YAML 文件的文本内容,然后通过 yaml.safe_load 转换为 Python 对象
        outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
        # 返回一个包含所有数据类实例的元组
        return tuple(outputs)

.\hyperparameter_search.py

# 从 integrations 模块中导入必要的函数和变量
from .integrations import (
    is_optuna_available,
    is_ray_tune_available,
    is_sigopt_available,
    is_wandb_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    run_hp_search_sigopt,
    run_hp_search_wandb,
)
# 从 trainer_utils 模块中导入必要的类和函数
from .trainer_utils import (
    HPSearchBackend,
    default_hp_space_optuna,
    default_hp_space_ray,
    default_hp_space_sigopt,
    default_hp_space_wandb,
)
# 从 utils 模块中导入 logging 函数
from .utils import logging

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)

# 定义超参数搜索后端基类
class HyperParamSearchBackendBase:
    name: str
    pip_package: str = None

    @staticmethod
    def is_available():
        # 抽象方法,子类需要实现该方法来检查后端是否可用
        raise NotImplementedError

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        # 抽象方法,子类需要实现该方法来执行超参数搜索
        raise NotImplementedError

    def default_hp_space(self, trial):
        # 抽象方法,子类需要实现该方法来定义默认的超参数空间
        raise NotImplementedError

    def ensure_available(self):
        # 确保后端可用,否则抛出运行时异常
        if not self.is_available():
            raise RuntimeError(
                f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
            )

    @classmethod
    def pip_install(cls):
        # 返回安装当前后端所需的 pip 命令字符串
        return f"`pip install {cls.pip_package or cls.name}`"


# 定义 Optuna 后端类,继承自 HyperParamSearchBackendBase
class OptunaBackend(HyperParamSearchBackendBase):
    name = "optuna"

    @staticmethod
    def is_available():
        # 检查 Optuna 是否可用
        return is_optuna_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        # 使用 Optuna 执行超参数搜索
        return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        # 返回 Optuna 的默认超参数空间
        return default_hp_space_optuna(trial)


# 定义 Ray Tune 后端类,继承自 HyperParamSearchBackendBase
class RayTuneBackend(HyperParamSearchBackendBase):
    name = "ray"
    pip_package = "'ray[tune]'"

    @staticmethod
    def is_available():
        # 检查 Ray Tune 是否可用
        return is_ray_tune_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        # 使用 Ray Tune 执行超参数搜索
        return run_hp_search_ray(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        # 返回 Ray Tune 的默认超参数空间
        return default_hp_space_ray(trial)


# 定义 SigOpt 后端类,继承自 HyperParamSearchBackendBase
class SigOptBackend(HyperParamSearchBackendBase):
    name = "sigopt"

    @staticmethod
    def is_available():
        # 检查 SigOpt 是否可用
        return is_sigopt_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        # 使用 SigOpt 执行超参数搜索
        return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        # 返回 SigOpt 的默认超参数空间
        return default_hp_space_sigopt(trial)


# 定义 Wandb 后端类,继承自 HyperParamSearchBackendBase
class WandbBackend(HyperParamSearchBackendBase):
    name = "wandb"

    @staticmethod
    def is_available():
        # 检查 Wandb 是否可用
        return is_wandb_available()

    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        # 使用 Wandb 执行超参数搜索
        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)

    def default_hp_space(self, trial):
        # 返回 Wandb 的默认超参数空间
        return default_hp_space_wandb(trial)
    # 定义静态方法,用于检查是否安装了 Weights & Biases 库
    @staticmethod
    def is_available():
        # 调用 is_wandb_available 函数,检查 Weights & Biases 库是否可用
        return is_wandb_available()

    # 定义方法,用于运行超参数搜索
    def run(self, trainer, n_trials: int, direction: str, **kwargs):
        # 调用 run_hp_search_wandb 函数,运行基于 Weights & Biases 的超参数搜索
        return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)

    # 定义方法,返回默认的超参数空间
    def default_hp_space(self, trial):
        # 调用 default_hp_space_wandb 函数,返回基于 Weights & Biases 的默认超参数空间
        return default_hp_space_wandb(trial)
# 创建一个字典,将各个超参数搜索后端与其名称关联起来
ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
    HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
}

# 定义一个函数,用于获取默认的超参数搜索后端名称
def default_hp_search_backend() -> str:
    # 获取所有可用的超参数搜索后端
    available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
    
    # 如果至少有一个可用的后端,则选择第一个作为默认值
    if len(available_backends) > 0:
        name = available_backends[0].name
        
        # 如果有多个可用的后端,记录日志并使用第一个作为默认
        if len(available_backends) > 1:
            logger.info(
                f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
            )
        
        # 返回选定的后端名称
        return name
    
    # 如果没有可用的后端,则抛出运行时错误,并给出安装信息
    raise RuntimeError(
        "No hyperparameter search backend available.\n"
        + "\n".join(
            f" - To install {backend.name} run {backend.pip_install()}"
            for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
        )
    )

.\image_processing_utils.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy  # 导入深拷贝模块
import json  # 导入 JSON 模块
import os  # 导入操作系统功能模块
import warnings  # 导入警告模块
from io import BytesIO  # 从 io 模块导入 BytesIO 类
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union  # 导入类型提示模块

import numpy as np  # 导入 NumPy 模块
import requests  # 导入请求模块

from .dynamic_module_utils import custom_object_save  # 从当前包导入动态模块工具中的 custom_object_save 函数
from .feature_extraction_utils import BatchFeature as BaseBatchFeature  # 从当前包导入特征提取工具中的 BatchFeature 类并重命名为 BaseBatchFeature
from .image_transforms import center_crop, normalize, rescale  # 从当前包导入图像转换模块中的三个函数
from .image_utils import ChannelDimension  # 从当前包导入图像工具模块中的 ChannelDimension 类
from .utils import (
    IMAGE_PROCESSOR_NAME,  # 从当前包导入工具模块中的 IMAGE_PROCESSOR_NAME 常量
    PushToHubMixin,  # 从当前包导入工具模块中的 PushToHubMixin 类
    add_model_info_to_auto_map,  # 从当前包导入工具模块中的 add_model_info_to_auto_map 函数
    cached_file,  # 从当前包导入工具模块中的 cached_file 函数
    copy_func,  # 从当前包导入工具模块中的 copy_func 函数
    download_url,  # 从当前包导入工具模块中的 download_url 函数
    is_offline_mode,  # 从当前包导入工具模块中的 is_offline_mode 函数
    is_remote_url,  # 从当前包导入工具模块中的 is_remote_url 函数
    is_vision_available,  # 从当前包导入工具模块中的 is_vision_available 函数
    logging,  # 从当前包导入工具模块中的 logging 模块
)


if is_vision_available():  # 如果视觉功能可用
    from PIL import Image  # 从 PIL 模块导入 Image 类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


# TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
# We override the class string here, but logic is the same.
class BatchFeature(BaseBatchFeature):
    r"""
    Holds the output of the image processor specific `__call__` methods.

    This class is derived from a python dictionary and can be used as a dictionary.

    Args:
        data (`dict`):
            Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
        tensor_type (`Union[None, str, TensorType]`, *optional*):
            You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
            initialization.
    """


# TODO: (Amy) - factor out the common parts of this and the feature extractor
class ImageProcessingMixin(PushToHubMixin):
    """
    This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
    extractors.
    """

    _auto_class = None
    # 初始化方法,用于设置对象的属性
    def __init__(self, **kwargs):
        """Set elements of `kwargs` as attributes."""
        # 由于图片处理现在使用 `XXXImageProcessor`,不再使用 `XXXFeatureExtractor`,因此删除此属性
        kwargs.pop("feature_extractor_type", None)
        
        # 将 "processor_class" 弹出并保存为私有属性 `_processor_class`
        self._processor_class = kwargs.pop("processor_class", None)
        
        # 遍历剩余的关键字参数,设置为对象的属性
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                # 如果设置属性失败,则记录错误日志并抛出异常
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

    # 设置处理器类的方法,将传入的字符串参数 `processor_class` 设置为 `_processor_class` 属性
    def _set_processor_class(self, processor_class: str):
        """Sets processor class as an attribute."""
        self._processor_class = processor_class

    # 类方法,用于从预训练模型名或路径创建实例
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
        [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the image processor JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        use_auth_token = kwargs.pop("use_auth_token", None)  # 获取并移除 use_auth_token 参数

        if use_auth_token is not None:  # 如果 use_auth_token 不为 None,发出警告
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if kwargs.get("token", None) is not None:  # 如果同时指定了 token 和 use_auth_token,则抛出错误
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            kwargs["token"] = use_auth_token  # 将 use_auth_token 赋给 token 参数

        if os.path.isfile(save_directory):  # 如果 save_directory 是一个文件路径,则抛出错误
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)  # 创建 save_directory 目录,如果不存在则创建

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)  # 获取并移除 commit_message 参数
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])  # 获取并移除 repo_id 参数,如果不存在则默认为 save_directory 的最后一部分名称
            repo_id = self._create_repo(repo_id, **kwargs)  # 创建或获取指定名称的 repository

            # 获取保存目录中文件的时间戳列表
            files_timestamps = self._get_files_timestamps(save_directory)

        # 如果有自定义配置 (_auto_class 不为 None),将当前对象以及其配置保存到目录中
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self)

        # 将图像处理器对象保存为 JSON 文件
        output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
        self.to_json_file(output_image_processor_file)
        logger.info(f"Image processor saved in {output_image_processor_file}")

        if push_to_hub:
            # 上传修改后的文件到指定的 repository
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=kwargs.get("token"),
            )

        # 返回保存的文件路径列表
        return [output_image_processor_file]

    @classmethod
    @classmethod
    def get_image_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ):
        """
        Creates a dictionary of parameters (`image_processor_dict`) needed to instantiate an image processor.

        Args:
            cls: Class method descriptor.
            pretrained_model_name_or_path (Union[str, os.PathLike]):
                Name or path of the pretrained model for the image processor.
            kwargs (Dict[str, Any]):
                Additional keyword arguments to customize the image processor.

        Returns:
            Dict[str, Any]: Dictionary of parameters (`image_processor_dict`) required to instantiate
                            the image processor.
        """
        image_processor_dict = {
            "pretrained_model_name_or_path": pretrained_model_name_or_path,
            **kwargs
        }
        return image_processor_dict

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Instantiates an image processor object from a dictionary of parameters.

        Args:
            image_processor_dict (Dict[str, Any]):
                Dictionary containing parameters to instantiate the image processor.
                Typically obtained from a pretrained checkpoint using `to_dict` method.
            kwargs (Dict[str, Any]):
                Additional parameters to initialize the image processor object.

        Returns:
            ImageProcessingMixin: The instantiated image processor object.
        """
        image_processor_dict = image_processor_dict.copy()
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        # Ensure `size` and `crop_size` are correctly set from kwargs if provided
        if "size" in kwargs and "size" in image_processor_dict:
            image_processor_dict["size"] = kwargs.pop("size")
        if "crop_size" in kwargs and "crop_size" in image_processor_dict:
            image_processor_dict["crop_size"] = kwargs.pop("crop_size")

        # Instantiate the image processor object
        image_processor = cls(**image_processor_dict)

        # Update image_processor attributes with remaining kwargs if applicable
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(image_processor, key):
                setattr(image_processor, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        # Log information about the instantiated image processor
        logger.info(f"Image processor {image_processor}")

        # Return the instantiated image processor object with optional unused kwargs
        if return_unused_kwargs:
            return image_processor, kwargs
        else:
            return image_processor

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the instance attributes of this image processor to a Python dictionary.

        Returns:
            Dict[str, Any]: Dictionary containing all attributes of the image processor instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["image_processor_type"] = self.__class__.__name__

        return output
    def from_json_file(cls, json_file: Union[str, os.PathLike]):
        """
        从包含参数的 JSON 文件路径实例化一个 `~image_processing_utils.ImageProcessingMixin` 类型的图像处理器。

        Args:
            json_file (`str` or `os.PathLike`):
                包含参数的 JSON 文件路径。

        Returns:
            `~image_processing_utils.ImageProcessingMixin` 类型的图像处理器:从指定 JSON 文件实例化的图像处理器对象。
        """
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        # 将 JSON 文本解析为字典
        image_processor_dict = json.loads(text)
        # 使用解析出的字典参数实例化当前类对象
        return cls(**image_processor_dict)

    def to_json_string(self) -> str:
        """
        将当前实例序列化为 JSON 字符串。

        Returns:
            `str`: 包含当前特征提取器实例所有属性的 JSON 格式字符串。
        """
        # 将当前实例转换为字典形式
        dictionary = self.to_dict()

        # 如果值为 numpy 数组,则转换为列表形式
        for key, value in dictionary.items():
            if isinstance(value, np.ndarray):
                dictionary[key] = value.tolist()

        # 确保私有名称 "_processor_class" 被正确保存为 "processor_class"
        _processor_class = dictionary.pop("_processor_class", None)
        if _processor_class is not None:
            dictionary["processor_class"] = _processor_class

        # 将字典转换为格式化的 JSON 字符串,并进行排序
        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        将当前实例保存到 JSON 文件中。

        Args:
            json_file_path (`str` or `os.PathLike`):
                将保存此图像处理器实例参数的 JSON 文件路径。
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            # 将实例转换为 JSON 字符串并写入文件
            writer.write(self.to_json_string())

    def __repr__(self):
        """
        返回当前实例的字符串表示形式。

        Returns:
            `str`: 包含当前实例 JSON 格式化字符串的类名。
        """
        return f"{self.__class__.__name__} {self.to_json_string()}"

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
        """
        使用给定的自动类注册此类。这仅适用于自定义图像处理器,因为库中的图像处理器已与 `AutoImageProcessor` 映射。

        <Tip warning={true}>
        此 API 是实验性的,可能在未来版本中有些微的破坏性更改。
        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, 默认为 `"AutoImageProcessor"`):
                要将此新图像处理器注册到的自动类。
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        # 导入自动模块
        import transformers.models.auto as auto_module

        # 检查是否存在指定的自动类
        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} 不是有效的自动类。")

        # 将自动类名称存储在 `_auto_class` 属性中
        cls._auto_class = auto_class
    def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
        """
        Convert a single or a list of URLs into corresponding `PIL.Image` objects.

        If a single URL is passed, the return value will be a single object. If a list is passed, a list of objects is
        returned.
        """
        # 设置 HTTP 请求的头部信息,模拟浏览器行为
        headers = {
            "User-Agent": (
                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
                " Safari/537.36"
            )
        }
        
        # 如果传入的参数是列表,则递归调用 fetch_images 处理列表中的每个 URL
        if isinstance(image_url_or_urls, list):
            return [self.fetch_images(x) for x in image_url_or_urls]
        # 如果传入的参数是字符串,则发送 HTTP 请求获取图片内容,并返回 PIL.Image 对象
        elif isinstance(image_url_or_urls, str):
            # 发送带有自定义头部信息的 HTTP GET 请求
            response = requests.get(image_url_or_urls, stream=True, headers=headers)
            # 如果响应状态码不是 200,则抛出异常
            response.raise_for_status()
            # 将响应内容封装为 PIL.Image 对象并返回
            return Image.open(BytesIO(response.content))
        else:
            # 如果传入的既不是字符串也不是列表,则抛出值错误异常
            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
class BaseImageProcessor(ImageProcessingMixin):
    # 初始化函数,调用父类的初始化方法
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # 调用对象时调用的方法,用于预处理单张或批量图片
    def __call__(self, images, **kwargs) -> BatchFeature:
        """Preprocess an image or a batch of images."""
        return self.preprocess(images, **kwargs)

    # 预处理方法的抽象定义,子类必须实现具体逻辑
    def preprocess(self, images, **kwargs) -> BatchFeature:
        raise NotImplementedError("Each image processor must implement its own preprocess method")

    # 图片按比例缩放的方法
    def rescale(
        self,
        image: np.ndarray,
        scale: float,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Rescale an image by a scale factor. image = image * scale.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`float`):
                The scaling factor to rescale pixel values by.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The rescaled image.
        """
        return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)

    # 图片标准化的方法
    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Normalize an image by subtracting mean and dividing by standard deviation.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `Iterable[float]`):
                Mean value(s) for normalization.
            std (`float` or `Iterable[float]`):
                Standard deviation value(s) for normalization.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The normalized image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs)
    ) -> np.ndarray:
        """
        Normalize an image. image = (image - image_mean) / image_std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `Iterable[float]`):
                Image mean to use for normalization.
            std (`float` or `Iterable[float]`):
                Image standard deviation to use for normalization.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: The normalized image.
        """
        # 调用 `normalize` 函数对图像进行归一化处理,并返回处理后的图像
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
        ):
        """
        Perform center cropping on the image.

        Args:
            image (`np.ndarray`):
                Image to crop.
            size (`Dict[str, int]`):
                Dictionary containing the target size for cropping, with keys 'height' and 'width'.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.

        Returns:
            `np.ndarray`: Cropped image.
        """
        # 执行图像的中心裁剪操作,并返回裁剪后的图像
        # 使用给定的尺寸参数对图像进行中心裁剪
        return center_crop(
            image, size=size, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
    ) -> np.ndarray:
        """
        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
        any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        # 根据传入的 size 参数,获取确保其为字典格式的大小信息
        size = get_size_dict(size)
        # 检查 size 字典中是否包含 'height' 和 'width' 键,若不包含则引发 ValueError 异常
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        # 调用 center_crop 函数,对输入的 image 进行中心裁剪,并返回裁剪后的图像
        return center_crop(
            image,
            size=(size["height"], size["width"]),
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
# 定义一个元组,包含多个集合,每个集合都是合法的尺寸字典的键集合
VALID_SIZE_DICT_KEYS = ({"height", "width"}, {"shortest_edge"}, {"shortest_edge", "longest_edge"}, {"longest_edge"})


def is_valid_size_dict(size_dict):
    # 判断输入的 size_dict 是否为字典类型,如果不是则返回 False
    if not isinstance(size_dict, dict):
        return False

    # 获取 size_dict 的键集合
    size_dict_keys = set(size_dict.keys())
    # 遍历预定义的合法尺寸字典键集合
    for allowed_keys in VALID_SIZE_DICT_KEYS:
        # 如果 size_dict 的键集合与某个预定义的合法键集合相同,则返回 True
        if size_dict_keys == allowed_keys:
            return True
    # 如果遍历完所有合法键集合后未找到匹配的,则返回 False
    return False


def convert_to_size_dict(
    size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
):
    # 默认情况下,如果 size 是整数且 default_to_square 为 True,则返回一个表示正方形尺寸的字典
    if isinstance(size, int) and default_to_square:
        if max_size is not None:
            raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
        return {"height": size, "width": size}
    
    # 在其他配置下,如果 size 是整数且 default_to_square 为 False,则返回一个表示最短边长度的字典
    elif isinstance(size, int) and not default_to_square:
        size_dict = {"shortest_edge": size}
        if max_size is not None:
            size_dict["longest_edge"] = max_size
        return size_dict
    
    # 如果 size 是元组且 height_width_order 为 True,则返回一个表示高度和宽度的字典
    elif isinstance(size, (tuple, list)) and height_width_order:
        return {"height": size[0], "width": size[1]}
    
    # 如果 size 是元组且 height_width_order 为 False,则返回一个表示高度和宽度的字典(顺序相反)
    elif isinstance(size, (tuple, list)) and not height_width_order:
        return {"height": size[1], "width": size[0]}
    
    # 如果 size 为 None 且 max_size 不为 None,则返回一个表示最长边长度的字典
    elif size is None and max_size is not None:
        if default_to_square:
            raise ValueError("Cannot specify both default_to_square=True and max_size")
        return {"longest_edge": max_size}

    # 如果 size 不满足以上任何条件,则抛出异常
    raise ValueError(f"Could not convert size input to size dict: {size}")


def get_size_dict(
    size: Union[int, Iterable[int], Dict[str, int]] = None,
    max_size: Optional[int] = None,
    height_width_order: bool = True,
    default_to_square: bool = True,
    param_name="size",
) -> dict:
    """
    Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
    compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
    width) or (width, height) format.

    - If `size` is tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
    size[0]}` if `height_width_order` is `False`.
    - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
    - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
      is set, it is added to the dict as `{"longest_edge": max_size}`.
    """
    # 调用 convert_to_size_dict 函数,将 size 转换为合适的尺寸字典
    return convert_to_size_dict(size, max_size, default_to_square, height_width_order)
    """
    Casts the `size` parameter into a standardized size dictionary.

    Args:
        size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
            The `size` parameter to be cast into a size dictionary.
        max_size (`Optional[int]`, *optional*):
            The `max_size` parameter to be cast into a size dictionary.
        height_width_order (`bool`, *optional*, defaults to `True`):
            If `size` is a tuple, specifies whether it's in (height, width) or (width, height) order.
        default_to_square (`bool`, *optional*, defaults to `True`):
            If `size` is an int, specifies whether to default to a square image or not.
    """
    # 如果 `size` 不是字典类型,则调用函数将其转换为标准化的大小字典
    if not isinstance(size, dict):
        size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
        # 记录日志,指出参数应该是一个包含指定键集合的字典,如果不是则进行了转换
        logger.info(
            f"{param_name} should be a dictionary on of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
            f" Converted to {size_dict}.",
        )
    else:
        size_dict = size

    # 检查生成的大小字典是否有效,如果不是则抛出 ValueError 异常
    if not is_valid_size_dict(size_dict):
        raise ValueError(
            f"{param_name} must have one of the following set of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
        )
    # 返回标准化后的大小字典
    return size_dict
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    This is done by calculating the effective and wasted resolution for each possible resolution.

    The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.

    Args:
        original_size (tuple):
            The original size of the image in the format (height, width).
        possible_resolutions (list):
            A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].

    Returns:
        tuple: The best fit resolution in the format (height, width).
    """
    # 解包原始尺寸
    original_height, original_width = original_size
    # 初始化最佳匹配为None
    best_fit = None
    # 初始化最大有效分辨率为0
    max_effective_resolution = 0
    # 初始化最小浪费分辨率为无穷大
    min_wasted_resolution = float("inf")

    # 遍历可能的分辨率
    for height, width in possible_resolutions:
        # 计算缩放比例
        scale = min(width / original_width, height / original_height)
        # 计算缩小后的宽度和高度
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        # 计算有效分辨率
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        # 计算浪费分辨率
        wasted_resolution = (width * height) - effective_resolution

        # 更新最佳匹配条件:如果有效分辨率大于最大有效分辨率,或者有效分辨率相等且浪费分辨率小于最小浪费分辨率
        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (height, width)

    # 返回最佳匹配分辨率
    return best_fit

# 下面是一个稍微不同的注释块
ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
if ImageProcessingMixin.push_to_hub.__doc__ is not None:
    # 格式化文档字符串,替换对象描述中的占位符
    ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
        object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
    )

.\image_transforms.py

# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np

from .image_utils import (
    ChannelDimension,
    ImageInput,
    get_channel_dimension_axis,
    get_image_size,
    infer_channel_dimension_format,
)
from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
from .utils.import_utils import (
    is_flax_available,
    is_tf_available,
    is_torch_available,
    is_vision_available,
    requires_backends,
)

if is_vision_available():
    import PIL

    from .image_utils import PILImageResampling

if is_torch_available():
    import torch

if is_tf_available():
    import tensorflow as tf

if is_flax_available():
    import jax.numpy as jnp


def to_channel_dimension_format(
    image: np.ndarray,
    channel_dim: Union[ChannelDimension, str],
    input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
    """
    Converts `image` to the channel dimension format specified by `channel_dim`.

    Args:
        image (`numpy.ndarray`):
            The image to have its channel dimension set.
        channel_dim (`ChannelDimension`):
            The channel dimension format to use.
        input_channel_dim (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.

    Returns:
        `np.ndarray`: The image with the channel dimension set to `channel_dim`.
    """
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

    if input_channel_dim is None:
        input_channel_dim = infer_channel_dimension_format(image)

    target_channel_dim = ChannelDimension(channel_dim)
    if input_channel_dim == target_channel_dim:
        return image

    if target_channel_dim == ChannelDimension.FIRST:
        image = image.transpose((2, 0, 1))  # Reorder dimensions to put channels first
    elif target_channel_dim == ChannelDimension.LAST:
        image = image.transpose((1, 2, 0))  # Reorder dimensions to put channels last
    else:
        raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))

    return image


def rescale(
    image: np.ndarray,
    scale: float,
    data_format: Optional[ChannelDimension] = None,
    dtype: np.dtype = np.float32,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Rescales the input `image` by a factor of `scale`.

    Args:
        image (`numpy.ndarray`):
            The image to be rescaled.
        scale (`float`):
            The scaling factor to be applied to the image.
        data_format (`ChannelDimension`, *optional*):
            The desired channel dimension format of the output image.
        dtype (`np.dtype`, *optional*):
            The desired data type of the output image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred.

    Returns:
        `np.ndarray`: The rescaled image.
    """
    # 按比例 `scale` 重新调整 `image` 的大小。

    # 检查输入参数 `image` 是否为 `np.ndarray` 类型,否则引发 ValueError 异常
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

    # 将 `image` 按比例 `scale` 进行重新调整
    rescaled_image = image * scale

    # 如果提供了 `data_format` 参数,则将 `rescaled_image` 转换为指定的通道维度格式
    if data_format is not None:
        rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)

    # 将 `rescaled_image` 转换为指定的数据类型 `dtype`
    rescaled_image = rescaled_image.astype(dtype)

    # 返回重新调整大小后的图像 `rescaled_image`
    return rescaled_image
# 检查输入的图像是否需要在转换为 PIL 图像之前进行重新缩放
def _rescale_for_pil_conversion(image):
    if image.dtype == np.uint8:
        # 如果图像类型为 np.uint8,则无需重新缩放
        do_rescale = False
    elif np.allclose(image, image.astype(int)):
        if np.all(0 <= image) and np.all(image <= 255):
            # 如果图像的所有值都在 [0, 255] 范围内,则无需重新缩放
            do_rescale = False
        else:
            # 抛出异常,因为图像包含超出 [0, 255] 范围的值
            raise ValueError(
                "The image to be converted to a PIL image contains values outside the range [0, 255], "
                f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
            )
    elif np.all(0 <= image) and np.all(image <= 1):
        # 如果图像的所有值都在 [0, 1] 范围内,则需要重新缩放
        do_rescale = True
    else:
        # 抛出异常,因为图像包含超出 [0, 1] 范围的值
        raise ValueError(
            "The image to be converted to a PIL image contains values outside the range [0, 1], "
            f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
        )
    return do_rescale


# 将输入的图像转换为 PIL.Image.Image 格式,并且如果需要,则重新缩放并将通道维度移到最后一个维度
def to_pil_image(
    image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
    do_rescale: Optional[bool] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
    """
    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
    needed.

    Args:
        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
            The image to convert to the `PIL.Image` format.
        do_rescale (`bool`, *optional*):
            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
            to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
            and `False` otherwise.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `PIL.Image.Image`: The converted image.
    """
    # 确保所需的后端已加载
    requires_backends(to_pil_image, ["vision"])

    if isinstance(image, PIL.Image.Image):
        return image

    # 将所有张量转换为 numpy 数组,以便转换为 PIL 图像
    if is_torch_tensor(image) or is_tf_tensor(image):
        image = image.numpy()
    elif is_jax_tensor(image):
        image = np.array(image)
    elif not isinstance(image, np.ndarray):
        # 抛出异常,因为不支持的输入图像类型
        raise ValueError("Input image type not supported: {}".format(type(image)))

    # 如果通道维度已经移动到第一维度,我们将其放回到最后一个维度
    image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)

    # 如果只有一个通道,我们压缩它,因为 PIL 不能处理非压缩的单通道图像
    image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
    # 如果需要将图像转换为 PIL.Image 格式,确保其像素值在 0 到 255 之间。
    do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale
    # 如果需要进行像素值的重新缩放,则调用 rescale 函数将图像像素值缩放到 0 到 255 的范围内。
    if do_rescale:
        image = rescale(image, 255)
    # 将图像的数据类型转换为 np.uint8,确保图像的像素值范围在 0 到 255 之间。
    image = image.astype(np.uint8)
    # 根据图像的 numpy 数组创建一个 PIL.Image 对象,并返回该对象。
    return PIL.Image.fromarray(image)
# 导入必要的库和模块
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
def get_resize_output_image_size(
    input_image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int], Tuple[int]],
    default_to_square: bool = True,
    max_size: Optional[int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> tuple:
    """
    Find the target (height, width) dimension of the output image after resizing given the input image and the desired
    size.

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        size (`int` or `Tuple[int, int]` or List[int] or Tuple[int]):
            The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
            this.

            If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
            `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
            number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
        default_to_square (`bool`, *optional*, defaults to `True`):
            How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
            (`size`,`size`). If set to `False`, will replicate
            [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
            with support for resizing only the smallest edge and providing an optional `max_size`.
        max_size (`int`, *optional*):
            The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
            than `max_size` after being resized according to `size`, then the image is resized again so that the longer
            edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
            than `size`. Only used if `default_to_square` is `False`.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.

    Returns:
        `tuple`: The target (height, width) dimension of the output image after resizing.
    """

    # 如果 `size` 是一个元组或列表
    if isinstance(size, (tuple, list)):
        # 如果 `size` 的长度为2,直接返回元组形式的大小
        if len(size) == 2:
            return tuple(size)
        # 如果 `size` 的长度为1,执行和整数大小相同的逻辑
        elif len(size) == 1:
            size = size[0]
        else:
            raise ValueError("size must have 1 or 2 elements if it is a list or tuple")

    # 如果默认需要输出为正方形
    if default_to_square:
        return (size, size)

    # 获取输入图像的高度和宽度
    height, width = get_image_size(input_image, input_data_format)
    # 确定较短和较长的边
    short, long = (width, height) if width <= height else (height, width)
    # 请求的新的较短边的大小
    requested_new_short = size
    # 从请求中获取新的短边和长边尺寸,计算新的长边尺寸为请求的新短边尺寸乘以长边与短边的比例
    new_short, new_long = requested_new_short, int(requested_new_short * long / short)

    # 如果设置了最大尺寸限制
    if max_size is not None:
        # 如果最大尺寸小于或等于请求的新短边尺寸,抛出值错误异常
        if max_size <= requested_new_short:
            raise ValueError(
                f"max_size = {max_size} must be strictly greater than the requested "
                f"size for the smaller edge size = {size}"
            )
        # 如果新的长边超过了最大尺寸,调整新的短边和长边的尺寸比例,并将长边限制为最大尺寸
        if new_long > max_size:
            new_short, new_long = int(max_size * new_short / new_long), max_size

    # 根据宽度和高度的比较,返回调整后的长短边尺寸元组
    return (new_long, new_short) if width <= height else (new_short, new_long)
    """
    使用 PIL 库将 `image` 调整大小为 `size` 指定的尺寸。

    Args:
        image (`np.ndarray`):
            要调整大小的图像。
        size (`Tuple[int, int]`):
            用于调整图像大小的尺寸。
        resample (`int`, *optional*, 默认为 `PILImageResampling.BILINEAR`):
            用于重采样的滤波器。
        reducing_gap (`int`, *optional*):
            通过两步骤优化图像调整大小。`reducing_gap` 越大,结果越接*公*重采样。详细信息请参考 Pillow 文档。
        data_format (`ChannelDimension`, *optional*):
            输出图像的通道维度格式。如果未设置,将从输入中推断格式。
        return_numpy (`bool`, *optional*, 默认为 `True`):
            是否将调整大小后的图像作为 numpy 数组返回。如果为 False,则返回 `PIL.Image.Image` 对象。
        input_data_format (`ChannelDimension`, *optional*):
            输入图像的通道维度格式。如果未设置,将从输入中推断格式。

    Returns:
        `np.ndarray`: 调整大小后的图像。
    """
    requires_backends(resize, ["vision"])

    # 如果未指定 resample 方法,则默认使用 BILINEAR 方法
    resample = resample if resample is not None else PILImageResampling.BILINEAR

    # 检查 size 是否包含两个元素,否则抛出 ValueError
    if not len(size) == 2:
        raise ValueError("size must have 2 elements")

    # 对于所有转换,我们希望保持与输入图像相同的数据格式,除非另有指定。
    # PIL 调整大小后的图像始终将通道放在最后,因此首先找到输入格式。
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    data_format = input_data_format if data_format is None else data_format

    # 为了保持与以前图像特征提取器中所做的调整大小的向后兼容性,我们使用 Pillow 库调整大小图像,然后转换回 numpy 数组。
    do_rescale = False
    if not isinstance(image, PIL.Image.Image):
        # 如果输入图像不是 PIL.Image.Image 对象,则进行相应的转换
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
    
    # 提取出 size 的高度和宽度
    height, width = size
    # PIL 图像的大小顺序为 (宽度, 高度)
    resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
    # 如果需要返回一个 NumPy 数组,则将 resized_image 转换为 NumPy 数组类型
    resized_image = np.array(resized_image)
    # 如果输入图像的通道维度为 1,在转换为 PIL 图像时会丢失通道维度,因此需要在必要时添加回来
    resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
    # 在从 PIL 图像转换后,图像始终处于通道最后的格式
    resized_image = to_channel_dimension_format(
        resized_image, data_format, input_channel_dim=ChannelDimension.LAST
    )
    # 如果在转换为 PIL 图像之前对图像进行了 [0, 255] 范围内的重新缩放,则需要将其重新缩放回原始范围
    resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
    # 返回处理后的图像
    return resized_image
# 定义函数 `center_crop`,用于对图像进行中心裁剪操作,返回裁剪后的图像
def center_crop(
    image: np.ndarray,
    size: Tuple[int, int],
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
    return_numpy: Optional[bool] = None,
) -> np.ndarray:
    """
    Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
    the size given, it will be padded (so the returned result will always be of size `size`).

    Args:
        image (`np.ndarray`):
            The input image to be cropped.
        size (`Tuple[int, int]`):
            The desired output size after cropping, specified as (height, width).
        data_format (`Union[str, ChannelDimension]`, *optional*):
            The channel dimension format of the output image. If unset, will use the inferred format from the input.
        input_data_format (`Union[str, ChannelDimension]`, *optional*):
            The channel dimension format of the input image. If unset, will use the inferred format from the input.
        return_numpy (`bool`, *optional*):
            Deprecated parameter. If provided, this should be set to `True`.

    Returns:
        `np.ndarray`: The cropped image of the specified `size`.
    """
    """
    Args:
        image (`np.ndarray`):
            The image to crop.
        size (`Tuple[int, int]`):
            The target size for the cropped image.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
        return_numpy (`bool`, *optional*):
            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
            previous ImageFeatureExtractionMixin method.
                - Unset: will return the same type as the input image.
                - `True`: will return a numpy array.
                - `False`: will return a `PIL.Image.Image` object.
    Returns:
        `np.ndarray`: The cropped image.
    """
    requires_backends(center_crop, ["vision"])

    # Warn about deprecation of `return_numpy` parameter
    if return_numpy is not None:
        warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)

    # Determine whether to return numpy array based on `return_numpy` parameter
    return_numpy = True if return_numpy is None else return_numpy

    # Validate input image type
    if not isinstance(image, np.ndarray):
        raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

    # Validate size parameter
    if not isinstance(size, Iterable) or len(size) != 2:
        raise ValueError("size must have 2 elements representing the height and width of the output image")

    # Determine input data format if not explicitly provided
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    output_data_format = data_format if data_format is not None else input_data_format

    # Convert image to (C, H, W) format if necessary
    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)

    # Get original image dimensions in channels-first format
    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
    crop_height, crop_width = size
    crop_height, crop_width = int(crop_height), int(crop_width)

    # Calculate top-left corner coordinates of the crop area
    top = (orig_height - crop_height) // 2
    bottom = top + crop_height
    left = (orig_width - crop_width) // 2
    right = left + crop_width

    # Check if the calculated crop area is within image boundaries
    # 如果裁剪区域在图片边界内,则直接裁剪
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        # 根据给定的裁剪区域对图像进行裁剪
        image = image[..., top:bottom, left:right]
        # 调整图像的通道维度格式为指定的输出格式,通道维度置于最前面
        image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
        # 返回裁剪后的图像
        return image

    # 否则,如果图像太小,需要进行填充处理
    new_height = max(crop_height, orig_height)
    new_width = max(crop_width, orig_width)
    # 构建新图像的形状,保留除了最后两个维度外的所有维度,并添加新的高度和宽度维度
    new_shape = image.shape[:-2] + (new_height, new_width)
    # 创建与原图像相同形状的全零数组作为新图像
    new_image = np.zeros_like(image, shape=new_shape)

    # 计算需要填充的边界
    top_pad = (new_height - orig_height) // 2
    bottom_pad = top_pad + orig_height
    left_pad = (new_width - orig_width) // 2
    right_pad = left_pad + orig_width
    # 在新图像的指定位置进行填充
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # 更新裁剪区域的边界位置
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    # 根据更新后的裁剪边界对新图像进行进一步裁剪
    new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
    # 调整图像的通道维度格式为指定的输出格式,通道维度置于最前面
    new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)

    # 如果不需要返回 NumPy 数组,则转换成 PIL 图像格式
    if not return_numpy:
        new_image = to_pil_image(new_image)

    # 返回处理后的图像
    return new_image
# 将中心格式的边界框转换为角点格式的边界框(使用 PyTorch 张量)
def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
    # 从中心格式的边界框张量中解绑出中心坐标和宽度、高度信息
    center_x, center_y, width, height = bboxes_center.unbind(-1)
    # 计算角点格式的边界框张量:左上角 x、左上角 y、右下角 x、右下角 y
    bbox_corners = torch.stack(
        [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
        dim=-1,
    )
    return bbox_corners


# 将中心格式的边界框转换为角点格式的边界框(使用 NumPy 数组)
def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
    # 从中心格式的边界框数组中解绑出中心坐标和宽度、高度信息
    center_x, center_y, width, height = bboxes_center.T
    # 计算角点格式的边界框数组:左上角 x、左上角 y、右下角 x、右下角 y
    bboxes_corners = np.stack(
        [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
        axis=-1,
    )
    return bboxes_corners


# 将中心格式的边界框转换为角点格式的边界框(使用 TensorFlow 张量)
def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
    # 从中心格式的边界框张量中解绑出中心坐标和宽度、高度信息
    center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
    # 计算角点格式的边界框张量:左上角 x、左上角 y、右下角 x、右下角 y
    bboxes_corners = tf.stack(
        [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
        axis=-1,
    )
    return bboxes_corners


# 以下两个函数灵感来自 https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
# 将边界框从中心格式转换为角点格式的统一接口函数
def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
    """
    Converts bounding boxes from center format to corners format.

    center format: contains the coordinate for the center of the box and its width, height dimensions
        (center_x, center_y, width, height)
    corners format: contains the coodinates for the top-left and bottom-right corners of the box
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
    """
    # 根据输入类型选择对应的转换函数,用于模型前向传递时的边界框格式转换,尽可能不转换为 NumPy 数组
    if is_torch_tensor(bboxes_center):  # 如果是 PyTorch 张量
        return _center_to_corners_format_torch(bboxes_center)
    elif isinstance(bboxes_center, np.ndarray):  # 如果是 NumPy 数组
        return _center_to_corners_format_numpy(bboxes_center)
    elif is_tf_tensor(bboxes_center):  # 如果是 TensorFlow 张量
        return _center_to_corners_format_tf(bboxes_center)

    # 如果输入类型不支持,则抛出异常
    raise ValueError(f"Unsupported input type {type(bboxes_center)}")


# 将角点格式的边界框转换为中心格式(使用 PyTorch 张量)
def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
    # 从角点格式的边界框张量中解绑出左上角和右下角坐标信息
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
    # 计算中心格式的边界框张量:中心 x、中心 y、宽度、高度
    b = [
        (top_left_x + bottom_right_x) / 2,  # 中心 x
        (top_left_y + bottom_right_y) / 2,  # 中心 y
        (bottom_right_x - top_left_x),      # 宽度
        (bottom_right_y - top_left_y),      # 高度
    ]
    return torch.stack(b, dim=-1)


# 将角点格式的边界框转换为中心格式(使用 NumPy 数组)
def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
    # 从角点格式的边界框数组中解绑出左上角和右下角坐标信息
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
    # 创建一个包含边界框中心坐标和宽高的数组
    bboxes_center = np.stack(
        [
            (top_left_x + bottom_right_x) / 2,  # 计算边界框的中心 x 坐标
            (top_left_y + bottom_right_y) / 2,  # 计算边界框的中心 y 坐标
            (bottom_right_x - top_left_x),      # 计算边界框的宽度
            (bottom_right_y - top_left_y),      # 计算边界框的高度
        ],
        axis=-1,  # 沿着最后一个轴(即最内层)堆叠数组
    )
    # 返回包含所有边界框中心和尺寸信息的数组
    return bboxes_center
def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
    """
    Converts bounding boxes from corners format to center format using TensorFlow operations.

    Args:
        bboxes_corners (tf.Tensor): Tensor containing bounding box coordinates in corners format
            (top_left_x, top_left_y, bottom_right_x, bottom_right_y)

    Returns:
        tf.Tensor: Tensor containing bounding box coordinates in center format
            (center_x, center_y, width, height)
    """
    # Unstack the input tensor along the last axis to get individual coordinates
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
    # Compute center coordinates, width, and height using TensorFlow operations
    bboxes_center = tf.stack(
        [
            (top_left_x + bottom_right_x) / 2,  # center x
            (top_left_y + bottom_right_y) / 2,  # center y
            (bottom_right_x - top_left_x),      # width
            (bottom_right_y - top_left_y),      # height
        ],
        axis=-1,
    )
    return bboxes_center


def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
    """
    Converts bounding boxes from corners format to center format.

    Args:
        bboxes_corners (TensorType): Tensor or array containing bounding box coordinates in corners format
            (top_left_x, top_left_y, bottom_right_x, bottom_right_y)

    Returns:
        TensorType: Tensor or array containing bounding box coordinates in center format
            (center_x, center_y, width, height)

    Raises:
        ValueError: If the input type is unsupported
    """
    # Check the type of input and call the respective conversion function
    if is_torch_tensor(bboxes_corners):
        return _corners_to_center_format_torch(bboxes_corners)
    elif isinstance(bboxes_corners, np.ndarray):
        return _corners_to_center_format_numpy(bboxes_corners)
    elif is_tf_tensor(bboxes_corners):
        return _corners_to_center_format_tf(bboxes_corners)

    # Raise an error if the input type is not recognized
    raise ValueError(f"Unsupported input type {type(bboxes_corners)}")


# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
# Copyright (c) 2018, Alexander Kirillov
# All rights reserved.
def rgb_to_id(color):
    """
    Converts RGB color to unique ID.

    Args:
        color (np.ndarray or list): RGB color values

    Returns:
        int: Unique ID corresponding to the RGB color
    """
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


def id_to_rgb(id_map):
    """
    Converts unique ID to RGB color.

    Args:
        id_map (np.ndarray or int): Unique ID or array of IDs

    Returns:
        np.ndarray or list: RGB color corresponding to the unique ID or array of RGB colors
    """
    if isinstance(id_map, np.ndarray):
        id_map_copy = id_map.copy()
        rgb_shape = tuple(list(id_map.shape) + [3])
        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
        for i in range(3):
            rgb_map[..., i] = id_map_copy % 256
            id_map_copy //= 256
        return rgb_map
    color = []
    for _ in range(3):
        color.append(id_map % 256)
        id_map //= 256
    return color


class PaddingMode(ExplicitEnum):
    """
    Enum class for the different padding modes to use when padding images.
    """

    CONSTANT = "constant"
    REFLECT = "reflect"
    REPLICATE = "replicate"
    SYMMETRIC = "symmetric"


def pad(
    image: np.ndarray,
    padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
    mode: PaddingMode = PaddingMode.CONSTANT,
    constant_values: Union[float, Iterable[float]] = 0.0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Pads an image array according to specified parameters.

    Args:
        image (np.ndarray): Image array to be padded.
        padding (int or Tuple[int, int] or Iterable[Tuple[int, int]]): Padding size or sizes in each dimension.
        mode (PaddingMode, optional): Padding mode, defaults to PaddingMode.CONSTANT.
        constant_values (float or Iterable[float], optional): Constant value(s) to pad with, defaults to 0.0.
        data_format (str or ChannelDimension, optional): Data format of the image array, defaults to None.

    Returns:
        np.ndarray: Padded image array.
    """
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
# 使用 numpy 数组作为参数的函数定义,该函数用于对图像进行填充操作。
def pad_image(image: np.ndarray,
              padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
              mode: PaddingMode,
              constant_values: Optional[Union[float, Iterable[float]]] = None,
              data_format: Optional[Union[str, ChannelDimension]] = None,
              input_data_format: Optional[Union[str, ChannelDimension]] = None) -> np.ndarray:
    """
    Pads the `image` with the specified (height, width) `padding` and `mode`.

    Args:
        image (`np.ndarray`):
            The image to pad.
        padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
            Padding to apply to the edges of the height, width axes. Can be one of three formats:
            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
            - `((before, after),)` yields same before and after pad for height and width.
            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
        mode (`PaddingMode`):
            The padding mode to use. Can be one of:
                - `"constant"`: pads with a constant value.
                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
                  vector along each axis.
                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
        constant_values (`float` or `Iterable[float]`, *optional*):
            The value to use for the padding if `mode` is `"constant"`.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.

    """
    # 如果未指定输入数据的通道格式,使用推断的通道格式
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    def _expand_for_data_format(values):
        """
        Convert values to be in the format expected by np.pad based on the data format.
        """
        # 如果values是整数或浮点数,将其转换为二维元组格式
        if isinstance(values, (int, float)):
            values = ((values, values), (values, values))
        # 如果values是长度为1的元组,将其转换为二维元组格式
        elif isinstance(values, tuple) and len(values) == 1:
            values = ((values[0], values[0]), (values[0], values[0]))
        # 如果values是长度为2的元组且第一个元素是整数,保持不变
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
            values = (values, values)
        # 如果values是长度为2的元组且第一个元素是元组,保持不变
        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
            values = values
        else:
            # 如果values不符合以上格式,抛出异常
            raise ValueError(f"Unsupported format: {values}")

        # 根据输入数据格式选择是否在通道维度前面添加0
        values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))

        # 如果图像维度为4,则在前面添加0作为批量维度
        values = (0, *values) if image.ndim == 4 else values
        return values

    # 根据数据格式扩展填充参数
    padding = _expand_for_data_format(padding)

    # 根据填充模式进行图像填充
    if mode == PaddingMode.CONSTANT:
        # 根据数据格式扩展常数填充值
        constant_values = _expand_for_data_format(constant_values)
        # 使用常数填充模式填充图像
        image = np.pad(image, padding, mode="constant", constant_values=constant_values)
    elif mode == PaddingMode.REFLECT:
        # 使用反射模式填充图像
        image = np.pad(image, padding, mode="reflect")
    elif mode == PaddingMode.REPLICATE:
        # 使用复制模式填充图像
        image = np.pad(image, padding, mode="edge")
    elif mode == PaddingMode.SYMMETRIC:
        # 使用对称模式填充图像
        image = np.pad(image, padding, mode="symmetric")
    else:
        # 如果填充模式无效,抛出异常
        raise ValueError(f"Invalid padding mode: {mode}")

    # 如果数据格式不为空,将图像转换为通道维度格式
    image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
    return image
# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
def convert_to_rgb(image: ImageInput) -> ImageInput:
    """
    Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
    as is.

    Args:
        image (Image):
            The image to convert.
    """
    # 确保当前函数所需的视觉后端已加载
    requires_backends(convert_to_rgb, ["vision"])

    # 如果传入的图像不是 PIL.Image.Image 类型,则直接返回
    if not isinstance(image, PIL.Image.Image):
        return image

    # 将图像转换为 RGB 格式
    image = image.convert("RGB")
    return image


def flip_channel_order(
    image: np.ndarray,
    data_format: Optional[ChannelDimension] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Flips the channel order of the image.

    If the image is in RGB format, it will be converted to BGR and vice versa.

    Args:
        image (`np.ndarray`):
            The image to flip.
        data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.
    """
    # 推断输入图像的通道维度格式,如果未指定则使用推断的格式
    input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format

    # 根据输入图像的通道维度格式执行通道顺序翻转操作
    if input_data_format == ChannelDimension.LAST:
        image = image[..., ::-1]  # BGR 到 RGB 或 RGB 到 BGR 的转换
    elif input_data_format == ChannelDimension.FIRST:
        image = image[::-1, ...]  # BGR 到 RGB 或 RGB 到 BGR 的转换
    else:
        raise ValueError(f"Unsupported channel dimension: {input_data_format}")

    # 如果指定了输出图像的通道维度格式,则将图像转换为该格式
    if data_format is not None:
        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
    return image
posted @ 2024-06-30 15:37  绝不原创的飞龙  阅读(97)  评论(0编辑  收藏  举报