diffusers-源码解析-三十四-

diffusers 源码解析(三十四)

.\diffusers\pipelines\kandinsky3\__init__.py

# 导入类型检查相关的模块
from typing import TYPE_CHECKING

# 从 utils 模块导入所需的工具函数和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 导入标记常量,用于慢导入判断
    OptionalDependencyNotAvailable,  # 导入自定义异常,用于处理依赖不可用情况
    _LazyModule,  # 导入懒加载模块的类
    get_objects_from_module,  # 导入从模块中获取对象的函数
    is_torch_available,  # 导入检查 PyTorch 是否可用的函数
    is_transformers_available,  # 导入检查 Transformers 是否可用的函数
)

# 初始化一个空字典用于存储假对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}

# 尝试检查依赖的可用性
try:
    # 如果 Transformers 和 PyTorch 不可用,则抛出异常
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# 捕获依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 模块导入假对象,避免依赖问题
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新假对象字典,添加从 dummy 模块中获取的对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依赖可用,则更新导入结构
else:
    _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]  # 添加 Kandinsky3Pipeline 到导入结构
    _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]  # 添加 Kandinsky3Img2ImgPipeline 到导入结构

# 如果是类型检查或慢导入模式,则进行依赖检查
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 如果 Transformers 和 PyTorch 不可用,则抛出异常
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    # 捕获依赖不可用的异常
    except OptionalDependencyNotAvailable:
        # 从 dummy 模块导入所有假对象
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 从 pipeline_kandinsky3 模块导入 Kandinsky3Pipeline
        from .pipeline_kandinsky3 import Kandinsky3Pipeline
        # 从 pipeline_kandinsky3_img2img 模块导入 Kandinsky3Img2ImgPipeline
        from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
# 如果不是类型检查或慢导入模式,则使用懒加载
else:
    import sys

    # 将当前模块替换为一个懒加载模块
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],  # 获取当前文件的全局变量
        _import_structure,  # 使用之前定义的导入结构
        module_spec=__spec__,  # 使用模块的规格
    )

    # 将假对象字典中的对象设置到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)  # 为当前模块设置假对象

.\diffusers\pipelines\kolors\pipeline_kolors.py

# 版权声明,表示代码的版权所有者和保留权利
# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 只有在遵守许可证的情况下才能使用此文件
# you may not use this file except in compliance with the License.
# 可以在以下网址获取许可证的副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非根据适用法律或书面协议另有规定,否则按“原样”提供软件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 参见许可证以获取特定的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect  # 导入inspect模块以便进行对象检查
from typing import Any, Callable, Dict, List, Optional, Tuple, Union  # 导入类型提示工具

import torch  # 导入PyTorch库以进行张量计算
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection  # 导入transformers库中的图像处理器和模型

from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 导入回调相关的模块
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 导入图像处理相关的模块
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin  # 导入加载器相关的混合类
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel  # 导入模型相关的类
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor  # 导入注意力处理器
from ...schedulers import KarrasDiffusionSchedulers  # 导入调度器
from ...utils import is_torch_xla_available, logging, replace_example_docstring  # 导入工具函数
from ...utils.torch_utils import randn_tensor  # 导入随机张量生成函数
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入管道相关的类
from .pipeline_output import KolorsPipelineOutput  # 导入管道输出相关的类
from .text_encoder import ChatGLMModel  # 导入文本编码模型
from .tokenizer import ChatGLMTokenizer  # 导入聊天GLM的分词器


if is_torch_xla_available():  # 检查是否可用torch_xla
    import torch_xla.core.xla_model as xm  # 导入torch_xla相关模块

    XLA_AVAILABLE = True  # 如果可用,设置标志为True
else:
    XLA_AVAILABLE = False  # 否则设置为False


logger = logging.get_logger(__name__)  # 创建当前模块的日志记录器,禁用pylint警告


EXAMPLE_DOC_STRING = """  # 示例文档字符串,用于展示如何使用代码
    Examples:
        ```py
        >>> import torch  # 导入torch库
        >>> from diffusers import KolorsPipeline  # 从diffusers导入Kolors管道

        >>> pipe = KolorsPipeline.from_pretrained(  # 从预训练模型创建Kolors管道
        ...     "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")  # 将管道移动到GPU设备

        >>> prompt = (  # 定义生成图像的提示
        ...     "A photo of a ladybug, macro, zoom, high quality, film, holding a wooden sign with the text 'KOLORS'"
        ... )
        >>> image = pipe(prompt).images[0]  # 生成图像并获取第一张图像
        ```py
"""


# 从stable_diffusion管道复制的函数,用于检索时间步
def retrieve_timesteps(
    scheduler,  # 调度器对象
    num_inference_steps: Optional[int] = None,  # 可选的推理步骤数量
    device: Optional[Union[str, torch.device]] = None,  # 可选的设备信息
    timesteps: Optional[List[int]] = None,  # 可选的时间步列表
    sigmas: Optional[List[float]] = None,  # 可选的sigma值列表
    **kwargs,  # 其他关键字参数
):
    """
    调用调度器的`set_timesteps`方法并在调用后从调度器获取时间步。处理
    自定义时间步。任何关键字参数将传递给`scheduler.set_timesteps`。
    # 参数说明
    Args:
        scheduler (`SchedulerMixin`):  # 调度器,用于获取时间步
            The scheduler to get timesteps from.  # 从调度器中获取时间步
        num_inference_steps (`int`):  # 生成样本时使用的扩散步骤数
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.  # 如果使用此参数,`timesteps`必须为`None`
        device (`str` or `torch.device`, *optional*):  # 指定时间步移动到的设备
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.  # 如果为`None`,则不移动时间步
        timesteps (`List[int]`, *optional*):  # 自定义时间步,用于覆盖调度器的时间步间隔策略
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.  # 如果传递了`timesteps`,则`num_inference_steps`和`sigmas`必须为`None`
        sigmas (`List[float]`, *optional*):  # 自定义sigmas,用于覆盖调度器的时间步间隔策略
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.  # 如果传递了`sigmas`,则`num_inference_steps`和`timesteps`必须为`None`

    Returns:
        `Tuple[torch.Tensor, int]`:  # 返回一个元组
        A tuple where the first element is the timestep schedule from the scheduler and the  # 第一个元素是来自调度器的时间步安排
        second element is the number of inference steps.  # 第二个元素是推理步骤的数量
    """
    # 检查`timesteps`和`sigmas`是否同时存在
    if timesteps is not None and sigmas is not None:
        # 抛出错误,提示只能传递`timesteps`或`sigmas`中的一个
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    # 如果`timesteps`存在
    if timesteps is not None:
        # 检查调度器的`set_timesteps`方法是否接受`timesteps`参数
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,抛出错误
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的`set_timesteps`方法,设置时间步
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        # 获取设置后的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤的数量
        num_inference_steps = len(timesteps)
    # 如果`sigmas`存在
    elif sigmas is not None:
        # 检查调度器的`set_timesteps`方法是否接受`sigmas`参数
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,抛出错误
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的`set_timesteps`方法,设置sigmas
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        # 获取设置后的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤的数量
        num_inference_steps = len(timesteps)
    # 如果`timesteps`和`sigmas`都不存在
    else:
        # 调用调度器的`set_timesteps`方法,使用默认的推理步骤数量
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        # 获取设置后的时间步
        timesteps = scheduler.timesteps
    # 返回时间步和推理步骤数量
    return timesteps, num_inference_steps
# 定义 KolorsPipeline 类,继承自多个父类以实现文本到图像生成
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
    # 文档字符串,说明该管道的用途和继承关系
    r"""
    Pipeline for text-to-image generation using Kolors.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`ChatGLMModel`]):
            Frozen text-encoder. Kolors uses [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b).
        tokenizer (`ChatGLMTokenizer`):
            Tokenizer of class
            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"False"`):
            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
            `Kwai-Kolors/Kolors-diffusers`.
    """

    # 定义模型在 CPU 上卸载的顺序
    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    # 定义可选组件的列表
    _optional_components = [
        "image_encoder",
        "feature_extractor",
    ]
    # 定义回调张量输入的列表
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    # 初始化方法,定义类的属性和参数
    def __init__(
        # 定义变分自编码器模型
        self,
        vae: AutoencoderKL,
        # 定义冻结的文本编码器
        text_encoder: ChatGLMModel,
        # 定义分词器
        tokenizer: ChatGLMTokenizer,
        # 定义条件 U-Net 模型
        unet: UNet2DConditionModel,
        # 定义调度器,用于图像去噪
        scheduler: KarrasDiffusionSchedulers,
        # 可选的图像编码器
        image_encoder: CLIPVisionModelWithProjection = None,
        # 可选的特征提取器
        feature_extractor: CLIPImageProcessor = None,
        # 是否强制将空提示的负向嵌入设置为零
        force_zeros_for_empty_prompt: bool = False,
    # 初始化父类
        ):
            super().__init__()
    
            # 注册模块,包括 VAE、文本编码器、分词器等
            self.register_modules(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                scheduler=scheduler,
                image_encoder=image_encoder,
                feature_extractor=feature_extractor,
            )
            # 将配置中的强制零填充选项注册到配置中
            self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
            # 计算 VAE 的缩放因子,如果存在 VAE 则取其通道数的块数
            self.vae_scale_factor = (
                2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
            )
            # 初始化图像处理器,使用 VAE 缩放因子
            self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    
            # 设置默认样本大小,从 UNet 的配置中获取
            self.default_sample_size = self.unet.config.sample_size
    
        # 编码提示的函数,处理各种输入参数
        def encode_prompt(
            self,
            prompt,
            device: Optional[torch.device] = None,
            num_images_per_prompt: int = 1,
            do_classifier_free_guidance: bool = True,
            negative_prompt=None,
            prompt_embeds: Optional[torch.FloatTensor] = None,
            pooled_prompt_embeds: Optional[torch.Tensor] = None,
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
            max_sequence_length: int = 256,
        # 从 diffusers 库中复制的编码图像的函数
        def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
            # 获取图像编码器参数的数据类型
            dtype = next(self.image_encoder.parameters()).dtype
    
            # 如果输入的图像不是张量,则进行特征提取
            if not isinstance(image, torch.Tensor):
                image = self.feature_extractor(image, return_tensors="pt").pixel_values
    
            # 将图像数据移动到指定设备,并设置数据类型
            image = image.to(device=device, dtype=dtype)
            # 如果需要输出隐藏状态,进行相应的编码
            if output_hidden_states:
                image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
                # 重复编码的隐藏状态以匹配每个提示的图像数量
                image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
                # 编码未条件化图像的隐藏状态
                uncond_image_enc_hidden_states = self.image_encoder(
                    torch.zeros_like(image), output_hidden_states=True
                ).hidden_states[-2]
                # 重复未条件化图像的隐藏状态
                uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0
                )
                # 返回编码后的图像和未条件化图像的隐藏状态
                return image_enc_hidden_states, uncond_image_enc_hidden_states
            else:
                # 编码图像并获取图像嵌入
                image_embeds = self.image_encoder(image).image_embeds
                # 重复图像嵌入以匹配每个提示的图像数量
                image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
                # 创建未条件化的图像嵌入
                uncond_image_embeds = torch.zeros_like(image_embeds)
    
                # 返回编码后的图像嵌入和未条件化的图像嵌入
                return image_embeds, uncond_image_embeds
    
        # 从 diffusers 库中复制的准备 IP 适配器图像嵌入的函数
        def prepare_ip_adapter_image_embeds(
            self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        # 初始化一个空列表,用于存储图像嵌入
        image_embeds = []
        # 如果启用无分类器自由引导,则初始化一个空列表,用于存储负图像嵌入
        if do_classifier_free_guidance:
            negative_image_embeds = []
        # 如果输入适配器的图像嵌入为 None
        if ip_adapter_image_embeds is None:
            # 检查 ip_adapter_image 是否为列表,如果不是,则将其转换为列表
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # 检查 ip_adapter_image 的长度是否与 IP 适配器的数量相同
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                # 如果长度不匹配,抛出 ValueError
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            # 遍历每个适配器图像及其对应的图像投影层
            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # 检查当前图像投影层是否为 ImageProjection 的实例,决定是否输出隐藏状态
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                # 对单个适配器图像进行编码,返回嵌入
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # 将当前图像嵌入添加到图像嵌入列表中
                image_embeds.append(single_image_embeds[None, :])
                # 如果启用无分类器自由引导,则将负图像嵌入添加到负图像嵌入列表中
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            # 如果输入适配器的图像嵌入不为 None,则遍历这些嵌入
            for single_image_embeds in ip_adapter_image_embeds:
                # 如果启用无分类器自由引导,则将当前图像嵌入拆分为负图像嵌入和正图像嵌入
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    # 将负图像嵌入添加到负图像嵌入列表中
                    negative_image_embeds.append(single_negative_image_embeds)
                # 将正图像嵌入添加到图像嵌入列表中
                image_embeds.append(single_image_embeds)

        # 初始化一个空列表,用于存储最终的适配器图像嵌入
        ip_adapter_image_embeds = []
        # 遍历图像嵌入及其索引
        for i, single_image_embeds in enumerate(image_embeds):
            # 将每个图像嵌入复制 num_images_per_prompt 次
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            # 如果启用无分类器自由引导,则复制负图像嵌入
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                # 将负图像嵌入与正图像嵌入拼接
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            # 将图像嵌入移动到指定的设备上
            single_image_embeds = single_image_embeds.to(device=device)
            # 将处理后的图像嵌入添加到最终列表中
            ip_adapter_image_embeds.append(single_image_embeds)

        # 返回最终的适配器图像嵌入列表
        return ip_adapter_image_embeds

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制
    # 准备额外参数用于调度器步骤,因为并非所有调度器具有相同的参数签名
        def prepare_extra_step_kwargs(self, generator, eta):
            # eta (η) 仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
            # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
            # 值应在 [0, 1] 之间
    
            # 检查调度器步骤是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 初始化额外步骤参数的字典
            extra_step_kwargs = {}
            # 如果调度器接受 eta,添加其值到字典
            if accepts_eta:
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器步骤是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 如果调度器接受 generator,添加其值到字典
            if accepts_generator:
                extra_step_kwargs["generator"] = generator
            # 返回包含额外参数的字典
            return extra_step_kwargs
    
        # 检查输入参数的有效性
        def check_inputs(
            self,
            prompt,
            num_inference_steps,
            height,
            width,
            negative_prompt=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            ip_adapter_image=None,
            ip_adapter_image_embeds=None,
            callback_on_step_end_tensor_inputs=None,
            max_sequence_length=None,
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制
        def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
            # 计算生成的张量形状
            shape = (
                batch_size,
                num_channels_latents,
                int(height) // self.vae_scale_factor,
                int(width) // self.vae_scale_factor,
            )
            # 检查传入的 generator 列表长度是否与 batch_size 匹配
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            # 如果未提供 latents,生成随机张量
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                # 如果提供了 latents,将其移动到指定设备
                latents = latents.to(device)
    
            # 将初始噪声按调度器所需的标准差进行缩放
            latents = latents * self.scheduler.init_noise_sigma
            # 返回处理后的噪声张量
            return latents
    
        # 从 diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids 复制
        def _get_add_time_ids(
            self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        # 将原始尺寸、裁剪坐标和目标尺寸合并为一个列表
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        # 计算通过添加时间嵌入维度和文本编码器投影维度得到的总嵌入维度
        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        # 获取模型期望的添加时间嵌入的输入特征维度
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        # 检查计算得到的维度与期望的维度是否相等
        if expected_add_embed_dim != passed_add_embed_dim:
            # 如果不相等,抛出错误提示
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        # 将添加的时间 ID 转换为张量,指定数据类型
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        # 返回添加的时间 ID 张量
        return add_time_ids

    # 从 diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae 复制而来
    def upcast_vae(self):
        # 获取 VAE 模型的数据类型
        dtype = self.vae.dtype
        # 将 VAE 模型转换为 float32 类型
        self.vae.to(dtype=torch.float32)
        # 检查是否使用 torch 2.0 或 xformers
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                FusedAttnProcessor2_0,
            ),
        )
        # 如果使用 xformers 或 torch 2.0,则注意力模块不需要在 float32 中,从而节省大量内存
        if use_torch_2_0_or_xformers:
            # 将后量化卷积层转换为指定数据类型
            self.vae.post_quant_conv.to(dtype)
            # 将解码器输入卷积层转换为指定数据类型
            self.vae.decoder.conv_in.to(dtype)
            # 将解码器中间块转换为指定数据类型
            self.vae.decoder.mid_block.to(dtype)

    # 从 diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding 复制而来
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:  # 定义函数返回类型为 torch.Tensor
        """  # 文档字符串开始,描述函数的功能
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298  # 参考链接
        
        Args:  # 参数说明开始
            w (`torch.Tensor`):  # 输入的张量 w,用于生成嵌入向量
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.  # 生成嵌入向量以丰富时间步嵌入
            embedding_dim (`int`, *optional*, defaults to 512):  # 嵌入维度,可选,默认为512
                Dimension of the embeddings to generate.  # 生成的嵌入维度
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):  # 数据类型,可选,默认为 torch.float32
                Data type of the generated embeddings.  # 生成嵌入的数值类型
        
        Returns:  # 返回值说明开始
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.  # 返回形状为 (len(w), embedding_dim) 的嵌入张量
        """  # 文档字符串结束
        assert len(w.shape) == 1  # 断言 w 的形状是一维的
        w = w * 1000.0  # 将 w 的值放大1000倍

        half_dim = embedding_dim // 2  # 计算嵌入维度的一半
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)  # 计算对数并归一化
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)  # 生成指数衰减的嵌入
        emb = w.to(dtype)[:, None] * emb[None, :]  # 将 w 转换为目标数据类型并进行广播乘法
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # 将正弦和余弦嵌入在维度1上拼接
        if embedding_dim % 2 == 1:  # 如果嵌入维度为奇数
            emb = torch.nn.functional.pad(emb, (0, 1))  # 在最后一维进行零填充
        assert emb.shape == (w.shape[0], embedding_dim)  # 断言嵌入的形状符合预期
        return emb  # 返回生成的嵌入张量

    @property  # 将以下方法声明为属性
    def guidance_scale(self):  # 定义 guidance_scale 属性
        return self._guidance_scale  # 返回内部存储的引导比例

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)  # 解释 guidance_scale 的定义
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`  # 说明相关文献
    # corresponds to doing no classifier free guidance.  # 指出值为1时不进行无分类器引导
    @property  # 将以下方法声明为属性
    def do_classifier_free_guidance(self):  # 定义 do_classifier_free_guidance 属性
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None  # 返回是否启用无分类器引导的布尔值

    @property  # 将以下方法声明为属性
    def cross_attention_kwargs(self):  # 定义 cross_attention_kwargs 属性
        return self._cross_attention_kwargs  # 返回交叉注意力参数

    @property  # 将以下方法声明为属性
    def denoising_end(self):  # 定义 denoising_end 属性
        return self._denoising_end  # 返回去噪结束位置

    @property  # 将以下方法声明为属性
    def num_timesteps(self):  # 定义 num_timesteps 属性
        return self._num_timesteps  # 返回时间步数

    @property  # 将以下方法声明为属性
    def interrupt(self):  # 定义 interrupt 属性
        return self._interrupt  # 返回中断状态

    @torch.no_grad()  # 禁用梯度计算,提升推理性能
    @replace_example_docstring(EXAMPLE_DOC_STRING)  # 用示例文档字符串替换当前文档字符串
    # 定义可调用对象的方法,允许使用一系列参数进行处理
        def __call__(
            # 提示文本,可以是字符串或字符串列表,默认为 None
            self,
            prompt: Union[str, List[str]] = None,
            # 输出图像的高度,默认为 None
            height: Optional[int] = None,
            # 输出图像的宽度,默认为 None
            width: Optional[int] = None,
            # 推理步骤的数量,默认为 50
            num_inference_steps: int = 50,
            # 时间步列表,默认为 None
            timesteps: List[int] = None,
            # 噪声级别的列表,默认为 None
            sigmas: List[float] = None,
            # 去噪结束的阈值,默认为 None
            denoising_end: Optional[float] = None,
            # 引导尺度,默认为 5.0
            guidance_scale: float = 5.0,
            # 负提示文本,可以是字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: Optional[int] = 1,
            # 用于控制生成过程的参数,默认为 0.0
            eta: float = 0.0,
            # 随机数生成器,可以是单个或多个,默认为 None
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量的张量,默认为 None
            latents: Optional[torch.Tensor] = None,
            # 提示的嵌入张量,默认为 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 经过池化的提示嵌入张量,默认为 None
            pooled_prompt_embeds: Optional[torch.Tensor] = None,
            # 负提示的嵌入张量,默认为 None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 负提示的池化嵌入张量,默认为 None
            negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
            # 图像适配器输入,默认为 None
            ip_adapter_image: Optional[PipelineImageInput] = None,
            # 图像适配器的嵌入张量列表,默认为 None
            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
            # 输出类型,默认为 "pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典,默认为 True
            return_dict: bool = True,
            # 跨注意力的参数,默认为 None
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 原始图像的尺寸,默认为 None
            original_size: Optional[Tuple[int, int]] = None,
            # 裁剪坐标的左上角,默认为 (0, 0)
            crops_coords_top_left: Tuple[int, int] = (0, 0),
            # 目标尺寸,默认为 None
            target_size: Optional[Tuple[int, int]] = None,
            # 负样本的原始尺寸,默认为 None
            negative_original_size: Optional[Tuple[int, int]] = None,
            # 负样本的裁剪坐标的左上角,默认为 (0, 0)
            negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
            # 负样本的目标尺寸,默认为 None
            negative_target_size: Optional[Tuple[int, int]] = None,
            # 步骤结束时的回调函数,默认为 None
            callback_on_step_end: Optional[
                Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
            ] = None,
            # 用于步骤结束回调的张量输入列表,默认为 ["latents"]
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            # 最大序列长度,默认为 256
            max_sequence_length: int = 256,

.\diffusers\pipelines\kolors\pipeline_kolors_img2img.py

# 版权信息,说明文件的所有权及使用许可
# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版(“许可证”)许可使用此文件;
# 除非遵循许可证,否则不得使用此文件。
# 可在以下网址获取许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是在“按现状”基础上分发的,
# 不附带任何种类的保证或条件。
# 有关许可证的具体条款和条件,请参阅许可证。
import inspect  # 导入 inspect 模块,用于获取对象的信息
from typing import Any, Callable, Dict, List, Optional, Tuple, Union  # 导入类型提示,定义可用的类型

import PIL.Image  # 导入 PIL.Image,用于图像处理
import torch  # 导入 PyTorch 库,进行深度学习相关操作
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection  # 导入 CLIP 相关的图像处理和模型

from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 从回调模块导入多管道回调和管道回调类
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 导入图像处理相关的类
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin  # 导入加载器混合类
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel  # 导入模型类
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor  # 导入注意力处理器
from ...schedulers import KarrasDiffusionSchedulers  # 导入调度器类
from ...utils import is_torch_xla_available, logging, replace_example_docstring  # 导入工具函数
from ...utils.torch_utils import randn_tensor  # 导入生成随机张量的函数
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入扩散管道和混合类
from .pipeline_output import KolorsPipelineOutput  # 导入管道输出类
from .text_encoder import ChatGLMModel  # 导入文本编码器模型
from .tokenizer import ChatGLMTokenizer  # 导入文本分词器

if is_torch_xla_available():  # 检查是否可用 Torch XLA
    import torch_xla.core.xla_model as xm  # 导入 XLA 模块

    XLA_AVAILABLE = True  # 设置 XLA 可用标志为 True
else:
    XLA_AVAILABLE = False  # 如果不支持,则设置为 False

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器

EXAMPLE_DOC_STRING = """  # 示例文档字符串
    Examples:
        ```py
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import KolorsImg2ImgPipeline  # 从 diffusers 导入图像到图像管道
        >>> from diffusers.utils import load_image  # 从工具模块导入加载图像的函数

        >>> pipe = KolorsImg2ImgPipeline.from_pretrained(  # 从预训练模型创建管道实例
        ...     "Kwai-Kolors/Kolors-diffusers", variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")  # 将管道移动到 GPU
        >>> url = (  # 定义图像的 URL
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/bunny_source.png"
        ... )

        >>> init_image = load_image(url)  # 从 URL 加载初始图像
        >>> prompt = "high quality image of a capybara wearing sunglasses. In the background of the image there are trees, poles, grass and other objects. At the bottom of the object there is the road., 8k, highly detailed."  # 定义生成图像的提示
        >>> image = pipe(prompt, image=init_image).images[0]  # 使用提示和初始图像生成新图像
        ```py
"""

# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 复制的函数
def retrieve_latents(  # 定义函数以检索潜在变量
    encoder_output: torch.Tensor,  # 输入的编码器输出,类型为张量
    generator: Optional[torch.Generator] = None,  # 可选的随机数生成器
    sample_mode: str = "sample"  # 采样模式,默认为“sample”
):
    # 检查 encoder_output 是否具有 latent_dist 属性,并且样本模式为 "sample"
        if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
            # 从 latent_dist 中获取样本
            return encoder_output.latent_dist.sample(generator)
        # 检查 encoder_output 是否具有 latent_dist 属性,并且样本模式为 "argmax"
        elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
            # 返回 latent_dist 的众数
            return encoder_output.latent_dist.mode()
        # 检查 encoder_output 是否具有 latents 属性
        elif hasattr(encoder_output, "latents"):
            # 返回 latents 属性的值
            return encoder_output.latents
        # 如果以上条件都不满足,抛出属性错误
        else:
            raise AttributeError("Could not access latents of provided encoder_output")
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion 中复制的
def retrieve_timesteps(
    # 调度器对象,负责时间步的管理
    scheduler,
    # 用于生成样本的推理步骤数量,默认为 None
    num_inference_steps: Optional[int] = None,
    # 指定设备类型,默认为 None
    device: Optional[Union[str, torch.device]] = None,
    # 自定义时间步列表,默认为 None
    timesteps: Optional[List[int]] = None,
    # 自定义 sigma 列表,默认为 None
    sigmas: Optional[List[float]] = None,
    # 额外的关键字参数
    **kwargs,
):
    """
    调用调度器的 `set_timesteps` 方法,并在调用后从调度器获取时间步。处理自定义时间步。
    任何关键字参数将传递给 `scheduler.set_timesteps`。

    参数:
        scheduler (`SchedulerMixin`):
            获取时间步的调度器。
        num_inference_steps (`int`):
            生成样本时使用的扩散步骤数量。如果使用,则 `timesteps` 必须为 `None`。
        device (`str` 或 `torch.device`, *可选*):
            时间步要移动到的设备。如果为 `None`,则不移动时间步。
        timesteps (`List[int]`, *可选*):
            自定义时间步,用于覆盖调度器的时间步间隔策略。如果传递了 `timesteps`,
            `num_inference_steps` 和 `sigmas` 必须为 `None`。
        sigmas (`List[float]`, *可选*):
            自定义 sigma,用于覆盖调度器的时间步间隔策略。如果传递了 `sigmas`,
            `num_inference_steps` 和 `timesteps` 必须为 `None`。

    返回:
        `Tuple[torch.Tensor, int]`: 一个元组,元组的第一个元素是调度器的时间步调度,
        第二个元素是推理步骤的数量。
    """
    # 检查是否同时传递了时间步和 sigma,若是则抛出异常
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    # 如果传递了时间步
    if timesteps is not None:
        # 检查调度器的 `set_timesteps` 方法是否接受自定义时间步
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,则抛出异常
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的 `set_timesteps` 方法设置时间步
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        # 从调度器获取当前的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤的数量
        num_inference_steps = len(timesteps)
    # 如果传递了 sigma
    elif sigmas is not None:
        # 检查调度器的 `set_timesteps` 方法是否接受自定义 sigma
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,则抛出异常
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的 `set_timesteps` 方法设置 sigma
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        # 从调度器获取当前的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤的数量
        num_inference_steps = len(timesteps)
    else:  # 如果不满足前面的条件,则执行此代码块
        # 设置调度器的时间步数,使用指定的设备和其他参数
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        # 获取调度器的时间步数
        timesteps = scheduler.timesteps
    # 返回时间步数和推理步数
    return timesteps, num_inference_steps
# 定义 KolorsImg2ImgPipeline 类,继承多个父类以实现文本到图像生成的功能
class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
    r"""
    使用 Kolors 进行文本到图像生成的管道。

    该模型继承自 [`DiffusionPipeline`]。请查阅父类文档以获取库为所有管道实现的通用方法(如下载、保存、在特定设备上运行等)。

    该管道还继承以下加载方法:
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] 用于加载 LoRA 权重
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] 用于保存 LoRA 权重
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] 用于加载 IP 适配器

    参数:
        vae ([`AutoencoderKL`]):
            用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。
        text_encoder ([`ChatGLMModel`]):
            冻结的文本编码器。Kolors 使用 [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b)。
        tokenizer (`ChatGLMTokenizer`):
            [ChatGLMTokenizer](https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py) 类的分词器。
        unet ([`UNet2DConditionModel`]): 条件 U-Net 架构,用于去噪编码后的图像潜在表示。
        scheduler ([`SchedulerMixin`]):
            与 `unet` 结合使用的调度器,用于去噪编码的图像潜在表示。可以是
            [`DDIMScheduler`], [`LMSDiscreteScheduler`] 或 [`PNDMScheduler`]。
        force_zeros_for_empty_prompt (`bool`, *可选*, 默认为 `"False"`):
            是否始终将负提示嵌入强制设置为 0。另请参见 `Kwai-Kolors/Kolors-diffusers` 的配置。
    """

    # 定义模型在 CPU 上的卸载顺序
    model_cpu_offload_seq = "text_encoder->image_encoder-unet->vae"
    # 定义可选组件列表
    _optional_components = [
        "image_encoder",
        "feature_extractor",
    ]
    # 定义回调张量输入列表
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds",
        "negative_prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "negative_pooled_prompt_embeds",
        "negative_add_time_ids",
    ]

    # 初始化方法,设置管道的参数
    def __init__(
        self,
        vae: AutoencoderKL,  # VAE 模型,用于图像的编码和解码
        text_encoder: ChatGLMModel,  # 文本编码器,负责处理输入文本
        tokenizer: ChatGLMTokenizer,  # 分词器,用于将文本转换为模型可处理的格式
        unet: UNet2DConditionModel,  # U-Net 模型,用于图像去噪
        scheduler: KarrasDiffusionSchedulers,  # 调度器,控制去噪过程
        image_encoder: CLIPVisionModelWithProjection = None,  # 可选的图像编码器
        feature_extractor: CLIPImageProcessor = None,  # 可选的特征提取器
        force_zeros_for_empty_prompt: bool = False,  # 是否将空提示的负嵌入强制为 0
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 注册各个模块,包括 VAE、文本编码器、分词器、UNet、调度器、图像编码器和特征提取器
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        # 将配置参数注册到对象中,强制为空提示时使用零
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        # 计算 VAE 的缩放因子,根据 VAE 配置的块输出通道数决定值,若无 VAE 则默认为 8
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
        # 创建图像处理器,使用之前计算的 VAE 缩放因子
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        # 获取 UNet 配置中的默认采样大小
        self.default_sample_size = self.unet.config.sample_size

    # 从 diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline 复制的 encode_prompt 方法
    def encode_prompt(
        self,
        # 提示文本输入
        prompt,
        # 可选的设备参数
        device: Optional[torch.device] = None,
        # 每个提示生成的图像数量
        num_images_per_prompt: int = 1,
        # 是否使用无分类器自由引导
        do_classifier_free_guidance: bool = True,
        # 可选的负提示文本
        negative_prompt=None,
        # 可选的提示嵌入
        prompt_embeds: Optional[torch.FloatTensor] = None,
        # 可选的池化提示嵌入
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 可选的负提示嵌入
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        # 可选的负池化提示嵌入
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 提示的最大序列长度
        max_sequence_length: int = 256,
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 复制的 encode_image 方法
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        # 获取图像编码器参数的数据类型
        dtype = next(self.image_encoder.parameters()).dtype

        # 如果输入的图像不是张量,使用特征提取器将其转换为张量格式
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        # 将图像转移到指定设备并转换为指定数据类型
        image = image.to(device=device, dtype=dtype)
        # 如果需要输出隐藏状态
        if output_hidden_states:
            # 获取图像编码器的隐藏状态
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            # 根据每个提示的图像数量重复隐藏状态
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            # 对无条件图像进行编码以获取隐藏状态
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            # 根据每个提示的图像数量重复无条件图像的隐藏状态
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            # 返回图像编码的隐藏状态和无条件图像的隐藏状态
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            # 编码图像以获取图像嵌入
            image_embeds = self.image_encoder(image).image_embeds
            # 根据每个提示的图像数量重复图像嵌入
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # 创建与图像嵌入同样形状的全零张量作为无条件图像嵌入
            uncond_image_embeds = torch.zeros_like(image_embeds)

            # 返回图像嵌入和无条件图像嵌入
            return image_embeds, uncond_image_embeds

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 复制的 prepare_ip_adapter_image_embeds 方法
    # 准备 IP 适配器图像嵌入的函数
        def prepare_ip_adapter_image_embeds(
            # 输入参数:IP 适配器图像、图像嵌入、设备、每个提示的图像数量、是否进行无分类器自由引导
            self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
        ):
            # 初始化图像嵌入列表
            image_embeds = []
            # 如果进行无分类器自由引导,则初始化负图像嵌入列表
            if do_classifier_free_guidance:
                negative_image_embeds = []
            # 如果图像嵌入为 None
            if ip_adapter_image_embeds is None:
                # 如果输入的图像不是列表,则将其转换为列表
                if not isinstance(ip_adapter_image, list):
                    ip_adapter_image = [ip_adapter_image]
    
                # 检查输入图像的数量与 IP 适配器数量是否一致
                if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                    raise ValueError(
                        # 抛出值错误,说明输入图像数量与 IP 适配器数量不匹配
                        f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                    )
    
                # 遍历每个单一 IP 适配器图像及其对应的图像投影层
                for single_ip_adapter_image, image_proj_layer in zip(
                    ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
                ):
                    # 判断是否输出隐藏状态
                    output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                    # 编码图像,获取单一图像嵌入和单一负图像嵌入
                    single_image_embeds, single_negative_image_embeds = self.encode_image(
                        single_ip_adapter_image, device, 1, output_hidden_state
                    )
    
                    # 将单一图像嵌入添加到嵌入列表中
                    image_embeds.append(single_image_embeds[None, :])
                    # 如果进行无分类器自由引导,则添加单一负图像嵌入
                    if do_classifier_free_guidance:
                        negative_image_embeds.append(single_negative_image_embeds[None, :])
            else:
                # 遍历已提供的图像嵌入
                for single_image_embeds in ip_adapter_image_embeds:
                    # 如果进行无分类器自由引导,则拆分负图像嵌入和正图像嵌入
                    if do_classifier_free_guidance:
                        single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                        # 将负图像嵌入添加到列表中
                        negative_image_embeds.append(single_negative_image_embeds)
                    # 将图像嵌入添加到列表中
                    image_embeds.append(single_image_embeds)
    
            # 初始化最终的 IP 适配器图像嵌入列表
            ip_adapter_image_embeds = []
            # 遍历每个图像嵌入
            for i, single_image_embeds in enumerate(image_embeds):
                # 将单一图像嵌入重复 num_images_per_prompt 次
                single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
                # 如果进行无分类器自由引导,则重复负图像嵌入
                if do_classifier_free_guidance:
                    single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                    # 连接负图像嵌入和正图像嵌入
                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
    
                # 将图像嵌入转移到指定设备
                single_image_embeds = single_image_embeds.to(device=device)
                # 添加处理后的图像嵌入到最终列表中
                ip_adapter_image_embeds.append(single_image_embeds)
    
            # 返回 IP 适配器图像嵌入列表
            return ip_adapter_image_embeds
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的代码
    # 准备额外参数用于调度器步骤,因为并非所有调度器具有相同的参数签名
    def prepare_extra_step_kwargs(self, generator, eta):
        # eta (η) 仅在 DDIMScheduler 中使用,其他调度器将忽略该参数
        # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
        # 应该在 [0, 1] 之间
    
        # 检查调度器步骤是否接受 eta 参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化额外步骤参数字典
        extra_step_kwargs = {}
        # 如果调度器接受 eta,添加到额外步骤参数中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
    
        # 检查调度器是否接受 generator 参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果调度器接受 generator,添加到额外步骤参数中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回包含额外参数的字典
        return extra_step_kwargs
    
    # 检查输入参数的有效性
    def check_inputs(
        self,
        prompt,
        strength,
        num_inference_steps,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
        # 从 diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps 复制
    # 获取时间步长,参数包括推理步骤数、强度、设备和去噪起始时间
        def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
            # 获取原始时间步,使用 init_timestep 计算
            if denoising_start is None:
                # 计算初始时间步,确保不超过总推理步骤数
                init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
                # 计算起始时间步,确保不小于0
                t_start = max(num_inference_steps - init_timestep, 0)
            else:
                # 如果有去噪起始时间,则从0开始
                t_start = 0
    
            # 从调度器中获取时间步,从计算的起始位置开始
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
    
            # 如果指定了去噪起始时间,则强度不再相关
            if denoising_start is not None:
                # 计算离散时间步截止值,基于去噪起始时间
                discrete_timestep_cutoff = int(
                    round(
                        self.scheduler.config.num_train_timesteps
                        - (denoising_start * self.scheduler.config.num_train_timesteps)
                    )
                )
    
                # 计算有效的推理步骤数
                num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
                # 如果调度器为二阶调度器,检查推理步骤是否为偶数
                if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
                    # 处理偶数步骤的情况,确保去噪过程结束在正确的导数步骤
                    num_inference_steps = num_inference_steps + 1
    
                # 从最后开始切片时间步
                timesteps = timesteps[-num_inference_steps:]
                # 返回时间步和有效推理步骤数
                return timesteps, num_inference_steps
    
            # 如果没有去噪起始时间,返回时间步和推理步骤数减去起始步
            return timesteps, num_inference_steps - t_start
    
        # 从 diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents 复制
        def prepare_latents(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
        # 从 diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids 复制
        def _get_add_time_ids(
            self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    # 追加时间 ID 列表,合并原始大小、裁剪坐标和目标大小
        ):
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
    
            # 计算通过模型配置生成的嵌入维度
            passed_add_embed_dim = (
                self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
            )
            # 获取模型期望的嵌入维度
            expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
    
            # 检查生成的嵌入维度与期望值是否匹配
            if expected_add_embed_dim != passed_add_embed_dim:
                # 如果不匹配,抛出值错误
                raise ValueError(
                    f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
                )
    
            # 将追加时间 ID 转换为张量,并指定数据类型
            add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
            # 返回处理后的追加时间 ID
            return add_time_ids
    
        # 从 diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae 复制的函数
        def upcast_vae(self):
            # 获取 VAE 的数据类型
            dtype = self.vae.dtype
            # 将 VAE 转换为 float32 数据类型
            self.vae.to(dtype=torch.float32)
            # 检查是否使用了 Torch 2.0 或 Xformers 处理器
            use_torch_2_0_or_xformers = isinstance(
                self.vae.decoder.mid_block.attentions[0].processor,
                (
                    AttnProcessor2_0,
                    XFormersAttnProcessor,
                    FusedAttnProcessor2_0,
                ),
            )
            # 如果使用了 Xformers 或 Torch 2.0,注意力块不需要为 float32,从而节省内存
            if use_torch_2_0_or_xformers:
                # 将后量化卷积层转换为指定数据类型
                self.vae.post_quant_conv.to(dtype)
                # 将输入卷积层转换为指定数据类型
                self.vae.decoder.conv_in.to(dtype)
                # 将中间块转换为指定数据类型
                self.vae.decoder.mid_block.to(dtype)
    
        # 从 diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding 复制的函数
        def get_guidance_scale_embedding(
            # 输入张量和嵌入维度,默认为 512,数据类型默认为 float32
            self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    # 此函数返回生成的嵌入向量,具有指定的引导尺度
        ) -> torch.Tensor:
            """
            参见指定链接以获取详细文档
    
            参数:
                w (`torch.Tensor`):
                    使用指定的引导尺度生成嵌入向量,以后用于丰富时间步嵌入。
                embedding_dim (`int`, *可选*, 默认值为 512):
                    要生成的嵌入的维度。
                dtype (`torch.dtype`, *可选*, 默认值为 `torch.float32`):
                    生成的嵌入的数据类型。
    
            返回:
                `torch.Tensor`: 形状为 `(len(w), embedding_dim)` 的嵌入向量。
            """
            # 确保输入张量 w 的形状是一维的
            assert len(w.shape) == 1
            # 将 w 的值放大 1000 倍
            w = w * 1000.0
    
            # 计算嵌入维度的一半
            half_dim = embedding_dim // 2
            # 计算用于缩放的对数值
            emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
            # 生成半个维度的嵌入值并取指数
            emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
            # 根据输入 w 转换数据类型并扩展维度
            emb = w.to(dtype)[:, None] * emb[None, :]
            # 将正弦和余弦值串联在一起
            emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
            # 如果嵌入维度是奇数,则填充零
            if embedding_dim % 2 == 1:  # zero pad
                emb = torch.nn.functional.pad(emb, (0, 1))
            # 确保最终嵌入的形状符合预期
            assert emb.shape == (w.shape[0], embedding_dim)
            # 返回生成的嵌入
            return emb
    
        # 返回当前的引导尺度
        @property
        def guidance_scale(self):
            return self._guidance_scale
    
        # 此属性定义与 Imagen 论文中的引导权重 w 类似的 `guidance_scale`
        # `guidance_scale = 1` 表示不进行无分类器引导
        @property
        def do_classifier_free_guidance(self):
            return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
    
        # 返回当前的交叉注意力参数
        @property
        def cross_attention_kwargs(self):
            return self._cross_attention_kwargs
    
        # 返回去噪的起始点
        @property
        def denoising_start(self):
            return self._denoising_start
    
        # 返回去噪的结束点
        @property
        def denoising_end(self):
            return self._denoising_end
    
        # 返回时间步的数量
        @property
        def num_timesteps(self):
            return self._num_timesteps
    
        # 返回中断标志
        @property
        def interrupt(self):
            return self._interrupt
    
        # 禁用梯度计算并替换示例文档字符串
        @torch.no_grad()
        @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义可调用的方法,接受多个参数用于处理图像生成任务
        def __call__(
            # 提示文本,可以是单个字符串或字符串列表
            self,
            prompt: Union[str, List[str]] = None,
            # 输入图像,用于生成的图像输入
            image: PipelineImageInput = None,
            # 控制生成强度的参数,默认为0.3
            strength: float = 0.3,
            # 输出图像的高度,可选
            height: Optional[int] = None,
            # 输出图像的宽度,可选
            width: Optional[int] = None,
            # 推理步骤的数量,默认为50
            num_inference_steps: int = 50,
            # 指定时间步列表,可选
            timesteps: List[int] = None,
            # 指定 sigma 值的列表,可选
            sigmas: List[float] = None,
            # 去噪开始的值,可选
            denoising_start: Optional[float] = None,
            # 去噪结束的值,可选
            denoising_end: Optional[float] = None,
            # 指导比例,默认为5.0
            guidance_scale: float = 5.0,
            # 负面提示文本,可选,可以是字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为1
            num_images_per_prompt: Optional[int] = 1,
            # 影响生成的 eta 值,默认为0.0
            eta: float = 0.0,
            # 随机数生成器,可选,可以是单个生成器或生成器列表
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量,可选,形状为张量
            latents: Optional[torch.Tensor] = None,
            # 提示嵌入,可选,形状为张量
            prompt_embeds: Optional[torch.Tensor] = None,
            # 池化的提示嵌入,可选,形状为张量
            pooled_prompt_embeds: Optional[torch.Tensor] = None,
            # 负面提示嵌入,可选,形状为张量
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 负面池化的提示嵌入,可选,形状为张量
            negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
            # 输入适配器图像,可选
            ip_adapter_image: Optional[PipelineImageInput] = None,
            # 输入适配器图像嵌入,可选,张量列表
            ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
            # 输出类型,可选,默认为“pil”
            output_type: Optional[str] = "pil",
            # 是否返回字典,默认为True
            return_dict: bool = True,
            # 交叉注意力参数,可选,字典类型
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 原始图像的尺寸,可选,元组类型
            original_size: Optional[Tuple[int, int]] = None,
            # 图像裁剪的左上角坐标,默认为(0, 0)
            crops_coords_top_left: Tuple[int, int] = (0, 0),
            # 目标尺寸,可选,元组类型
            target_size: Optional[Tuple[int, int]] = None,
            # 负面原始图像的尺寸,可选,元组类型
            negative_original_size: Optional[Tuple[int, int]] = None,
            # 负面图像裁剪的左上角坐标,默认为(0, 0)
            negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
            # 负面目标尺寸,可选,元组类型
            negative_target_size: Optional[Tuple[int, int]] = None,
            # 步骤结束时的回调函数,可选,支持多种类型
            callback_on_step_end: Optional[
                Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
            ] = None,
            # 结束步骤的张量输入列表,默认为["latents"]
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            # 最大序列长度,默认为256
            max_sequence_length: int = 256,

.\diffusers\pipelines\kolors\pipeline_output.py

# 从 dataclasses 模块导入 dataclass 装饰器,用于简化类的定义
from dataclasses import dataclass
# 从 typing 模块导入 List 和 Union 类型,用于类型注解
from typing import List, Union

# 导入 numpy 库,通常用于数组操作
import numpy as np
# 导入 PIL.Image 模块,用于处理图像
import PIL.Image

# 从上级模块导入 BaseOutput 类,可能是用于输出的基类
from ...utils import BaseOutput

# 定义 KolorsPipelineOutput 类,继承自 BaseOutput
@dataclass
class KolorsPipelineOutput(BaseOutput):
    """
    Kolors 管道输出类。

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            图像列表,包含去噪后的 PIL 图像,长度为 `batch_size` 或形状为 `(batch_size, height, width,
            num_channels)` 的 numpy 数组。PIL 图像或 numpy 数组表示扩散管道的去噪图像。
    """

    # 定义 images 属性,可以是 PIL 图像列表或 numpy 数组
    images: Union[List[PIL.Image.Image], np.ndarray]

.\diffusers\pipelines\kolors\text_encoder.py

# 版权声明,说明该代码的版权归属
# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 只能在遵守许可证的情况下使用此文件
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证的副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面协议另有约定,软件按“现状”分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何形式的明示或暗示的担保或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取特定的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入数学模块,提供数学函数支持
import math
# 导入类型注解支持
from typing import List, Optional, Tuple

# 导入 PyTorch 库,用于深度学习
import torch
# 导入 PyTorch 的功能模块,包含常用的激活函数和损失函数
import torch.nn.functional as F
# 导入 PyTorch 的神经网络模块
from torch import nn
# 导入 LayerNorm 归一化层
from torch.nn import LayerNorm
# 导入 PyTorch 的 skip_init 功能,用于初始化神经网络参数
from torch.nn.utils import skip_init
# 导入预训练模型配置和模型类
from transformers import PretrainedConfig, PreTrainedModel
# 导入包含模型输出的类
from transformers.modeling_outputs import BaseModelOutputWithPast

# 从 utils 模块中导入日志记录功能
from ...utils import logging

# 创建日志记录器实例,使用当前模块名称
logger = logging.get_logger(__name__)

# 定义 ChatGLMConfig 类,继承自 PretrainedConfig
class ChatGLMConfig(PretrainedConfig):
    # 指定模型类型为 "chatglm"
    model_type = "chatglm"

    # 初始化函数,定义模型的各种超参数
    def __init__(
        # 定义默认层数为 28
        num_layers=28,
        # 定义填充词汇表的大小
        padded_vocab_size=65024,
        # 定义隐藏层的大小
        hidden_size=4096,
        # 定义前馈网络的隐藏层大小
        ffn_hidden_size=13696,
        # 定义键值通道的数量
        kv_channels=128,
        # 定义注意力头的数量
        num_attention_heads=32,
        # 定义序列长度
        seq_length=2048,
        # 定义隐藏层的丢弃率
        hidden_dropout=0.0,
        # 定义分类器的丢弃率(可选)
        classifier_dropout=None,
        # 定义注意力的丢弃率
        attention_dropout=0.0,
        # 定义 LayerNorm 的 epsilon 值
        layernorm_epsilon=1e-5,
        # 定义是否使用 RMSNorm
        rmsnorm=True,
        # 定义残差连接是否在 LayerNorm 后应用
        apply_residual_connection_post_layernorm=False,
        # 定义是否使用后层归一化
        post_layer_norm=True,
        # 定义是否添加线性层的偏置
        add_bias_linear=False,
        # 定义是否添加 QKV 的偏置
        add_qkv_bias=False,
        # 定义是否融合偏置和丢弃操作
        bias_dropout_fusion=True,
        # 定义是否使用多查询注意力
        multi_query_attention=False,
        # 定义多查询的组数量
        multi_query_group_num=1,
        # 定义是否应用查询键层缩放
        apply_query_key_layer_scaling=True,
        # 定义是否在 FP32 中执行 softmax 操作
        attention_softmax_in_fp32=True,
        # 定义是否在残差连接中使用 FP32
        fp32_residual_connection=False,
        # 定义量化位数
        quantization_bit=0,
        # 定义预序列长度(可选)
        pre_seq_len=None,
        # 定义是否进行前缀投影
        prefix_projection=False,
        # 接受其他关键字参数
        **kwargs,
    ):
        # 设置网络层的数量
        self.num_layers = num_layers
        # 设置词汇表的大小
        self.vocab_size = padded_vocab_size
        # 设置填充后的词汇表大小
        self.padded_vocab_size = padded_vocab_size
        # 设置隐藏层的大小
        self.hidden_size = hidden_size
        # 设置前馈网络的隐藏层大小
        self.ffn_hidden_size = ffn_hidden_size
        # 设置键值通道的数量
        self.kv_channels = kv_channels
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置序列的长度
        self.seq_length = seq_length
        # 设置隐藏层的 dropout 概率
        self.hidden_dropout = hidden_dropout
        # 设置分类器的 dropout 概率
        self.classifier_dropout = classifier_dropout
        # 设置注意力层的 dropout 概率
        self.attention_dropout = attention_dropout
        # 设置层归一化的 epsilon 值
        self.layernorm_epsilon = layernorm_epsilon
        # 是否使用 RMSNorm 进行归一化
        self.rmsnorm = rmsnorm
        # 是否在层归一化后应用残差连接
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        # 是否使用后层归一化
        self.post_layer_norm = post_layer_norm
        # 是否在线性变换中添加偏置项
        self.add_bias_linear = add_bias_linear
        # 是否在 Q、K、V 中添加偏置项
        self.add_qkv_bias = add_qkv_bias
        # 是否融合偏置和 dropout 操作
        self.bias_dropout_fusion = bias_dropout_fusion
        # 是否使用多查询注意力机制
        self.multi_query_attention = multi_query_attention
        # 设置多查询组的数量
        self.multi_query_group_num = multi_query_group_num
        # 是否在查询和键的层之间应用缩放
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        # 是否在计算 softmax 时使用 FP32 精度
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        # 是否在残差连接中使用 FP32 精度
        self.fp32_residual_connection = fp32_residual_connection
        # 设置量化位数
        self.quantization_bit = quantization_bit
        # 设置前序列的长度
        self.pre_seq_len = pre_seq_len
        # 是否使用前缀投影
        self.prefix_projection = prefix_projection
        # 调用父类的构造函数
        super().__init__(**kwargs)
# RMSNorm 类,继承自 PyTorch 的 Module 类
class RMSNorm(torch.nn.Module):
    # 初始化方法,接收标准化形状、epsilon、设备、数据类型及其他参数
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        # 调用父类的初始化方法
        super().__init__()
        # 创建可学习的权重参数,初始化为空的张量
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        # 设置 epsilon 值
        self.eps = eps

    # 前向传播方法,接收隐藏状态张量
    def forward(self, hidden_states: torch.Tensor):
        # 获取输入张量的数据类型
        input_dtype = hidden_states.dtype
        # 计算输入张量的方差
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # 归一化隐藏状态张量
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        # 返回加权后的隐藏状态,转换回原始数据类型
        return (self.weight * hidden_states).to(input_dtype)


# 将配置对象转换为关键字参数的辅助函数
def _config_to_kwargs(args):
    # 创建包含数据类型的通用关键字参数字典
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    # 返回关键字参数字典
    return common_kwargs


# CoreAttention 类,继承自 PyTorch 的 Module 类
class CoreAttention(torch.nn.Module):
    # 初始化方法,接收配置和层编号
    def __init__(self, config: ChatGLMConfig, layer_number):
        # 调用父类的初始化方法
        super(CoreAttention, self).__init__()

        # 从配置中获取查询-键层的缩放应用标志
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # 获取软最大值是否在 FP32 中的配置
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # 如果应用查询-键层缩放,强制将软最大值设置为 FP32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # 确保层编号至少为 1
        self.layer_number = max(1, layer_number)

        # 计算投影大小
        projection_size = config.kv_channels * config.num_attention_heads

        # 每个注意力头和每个分区的值
        self.hidden_size_per_partition = projection_size
        # 每个注意力头的隐藏大小
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        # 每个分区的注意力头数量
        self.num_attention_heads_per_partition = config.num_attention_heads

        # 初始化系数为 None
        coeff = None
        # 计算归一化因子
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        # 如果应用查询-键层缩放,更新系数和归一化因子
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        # 保存系数
        self.coeff = coeff

        # 初始化注意力 dropout
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

# 按最后一个维度拆分张量的函数
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """拆分张量的最后一个维度。

    参数:
        tensor: 输入张量。
        num_partitions: 拆分张量的分区数量
        contiguous_split_chunks: 如果为 True,使每个块在内存中连续。

    返回:
        张量列表
    """
    # 获取张量的最后一维索引
    last_dim = tensor.dim() - 1
    # 计算每个分区的最后一维大小
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # 拆分张量
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # 注意:torch.split 默认不会创建连续的张量。
    if contiguous_split_chunks:
        # 返回每个块的连续张量
        return tuple(chunk.contiguous() for chunk in tensor_list)

    # 返回拆分后的张量列表
    return tensor_list


# 应用旋转位置嵌入的 JIT 编译函数
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    # 获取输入张量的尺寸
    sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3)
    # 计算旋转维度
    rot_dim = rope_cache.shape[-2] * 2
    # 拆分输入张量,保留旋转维度部分和其余部分
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # 截断以支持可变大小
        rope_cache = rope_cache[:sq]
        # 重塑 x 为指定形状,-1 表示自动推断维度
        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
        # 将 rope_cache 视图转换为新的形状,以便与 xshaped 对齐
        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
        # 计算输出 x_out2,应用旋转公式
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        # 将 x_out2 在维度 3 上展平
        x_out2 = x_out2.flatten(3)
        # 将 x_out2 和 x_pass 在最后一个维度上连接
        return torch.cat((x_out2, x_pass), dim=-1)
# 自注意力层抽象类,继承自 PyTorch 的模块
class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h] and returns output of the same size.
    """

    # 初始化方法,接受配置、层数和设备参数
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        # 调用父类构造函数
        super(SelfAttention, self).__init__()
        # 确保层数至少为 1
        self.layer_number = max(1, layer_number)

        # 计算投影大小
        self.projection_size = config.kv_channels * config.num_attention_heads

        # 每个注意力头和每个分区的值
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        # 是否使用多查询注意力
        self.multi_query_attention = config.multi_query_attention
        # QKV 隐藏层大小
        self.qkv_hidden_size = 3 * self.projection_size
        # 如果使用多查询注意力,调整 QKV 隐藏层大小
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        # 定义线性层以获取查询、键和值
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            **_config_to_kwargs(config),
        )

        # 核心注意力模块
        self.core_attention = CoreAttention(config, self.layer_number)

        # 输出线性层
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            **_config_to_kwargs(config),
        )

    # 分配内存的方法
    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        # 根据是否使用多查询注意力确定注意力头数量
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        # 返回一个空的张量以存储注意力结果
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

# 多层感知机类,继承自 PyTorch 的模块
class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation,
    and project the state back into h hidden dimension.
    """
    # 初始化 MLP 类,接收配置和设备参数
        def __init__(self, config: ChatGLMConfig, device=None):
            # 调用父类的初始化方法
            super(MLP, self).__init__()
    
            # 设置是否在线性层中添加偏置
            self.add_bias = config.add_bias_linear
    
            # 创建一个线性层,将输入维度投影到 4h 维度,如果使用 swiglu,输出宽度翻倍
            self.dense_h_to_4h = nn.Linear(
                config.hidden_size,
                config.ffn_hidden_size * 2,
                bias=self.add_bias,
                device=device,
                **_config_to_kwargs(config),
            )
    
            # 定义 swiglu 激活函数
            def swiglu(x):
                # 将输入张量分为两部分
                x = torch.chunk(x, 2, dim=-1)
                # 返回第一部分的 silu 激活值乘以第二部分
                return F.silu(x[0]) * x[1]
    
            # 设置激活函数为 swiglu
            self.activation_func = swiglu
    
            # 创建另一个线性层,将 4h 维度投影回原始 h 维度
            self.dense_4h_to_h = nn.Linear(
                config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
            )
    
        # 前向传播方法,接收隐藏状态作为输入
        def forward(self, hidden_states):
            # 通过第一层线性层处理输入,得到中间结果
            intermediate_parallel = self.dense_h_to_4h(hidden_states)
            # 应用激活函数于中间结果
            intermediate_parallel = self.activation_func(intermediate_parallel)
            # 通过第二层线性层得到最终输出
            output = self.dense_4h_to_h(intermediate_parallel)
            # 返回输出结果
            return output
# 定义单个变换器层的类
class GLMBlock(torch.nn.Module):
    """单个变换器层。

    变换器层接受大小为 [s, b, h] 的输入并返回相同大小的输出。
    """

    # 初始化方法,接收配置、层编号和设备参数
    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        # 调用父类构造函数
        super(GLMBlock, self).__init__()
        # 设置当前层的编号
        self.layer_number = layer_number

        # 是否在层归一化后应用残差连接
        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        # 是否使用 FP32 残差连接
        self.fp32_residual_connection = config.fp32_residual_connection

        # 根据配置选择归一化函数
        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # 对输入数据进行层归一化
        self.input_layernorm = LayerNormFunc(
            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
        )

        # 自注意力机制
        self.self_attention = SelfAttention(config, layer_number, device=device)
        # 隐藏层的 dropout 概率
        self.hidden_dropout = config.hidden_dropout

        # 对注意力输出进行层归一化
        self.post_attention_layernorm = LayerNormFunc(
            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
        )

        # 多层感知机
        self.mlp = MLP(config, device=device)

    # 前向传播方法
    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # 在变换器层开始进行层归一化
        layernorm_output = self.input_layernorm(hidden_states)
        # 自注意力计算
        attention_output, kv_cache = self.self_attention(
            layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache
        )

        # 残差连接
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # 对注意力输出进行 dropout,并准备进行层归一化输入
        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

        # 在自注意力后进行层归一化
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # 多层感知机计算
        mlp_output = self.mlp(layernorm_output)

        # 第二次残差连接
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # 对多层感知机输出进行 dropout,并完成最终输出
        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        # 返回输出和键值缓存
        return output, kv_cache


# 定义变换器类
class GLMTransformer(torch.nn.Module):
    """变换器类。"""
    # 初始化方法,接受配置和设备参数
        def __init__(self, config: ChatGLMConfig, device=None):
            # 调用父类初始化方法
            super(GLMTransformer, self).__init__()
    
            # 设置浮点32位残差连接选项
            self.fp32_residual_connection = config.fp32_residual_connection
            # 设置后层归一化选项
            self.post_layer_norm = config.post_layer_norm
    
            # 设置层数
            self.num_layers = config.num_layers
    
            # 定义构建层的方法
            def build_layer(layer_number):
                # 创建并返回 GLMBlock 层实例
                return GLMBlock(config, layer_number, device=device)
    
            # 生成指定数量的层并加入模块列表
            self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
    
            # 如果启用后层归一化
            if self.post_layer_norm:
                # 选择归一化方法
                LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
                # 创建最终的层归一化层
                self.final_layernorm = LayerNormFunc(
                    config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
                )
    
            # 初始化梯度检查点开关
            self.gradient_checkpointing = False
    
        # 获取指定层的方法
        def _get_layer(self, layer_number):
            # 返回指定层的实例
            return self.layers[layer_number]
    
        # 前向传播方法
        def forward(
            self,
            hidden_states,
            attention_mask,
            rotary_pos_emb,
            kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
        ):
            # 如果未提供 kv_caches,初始化为 None 列表
            if not kv_caches:
                kv_caches = [None for _ in range(self.num_layers)]
            # 根据 use_cache 设置 presents 为元组或 None
            presents = () if use_cache else None
            # 如果启用梯度检查点且处于训练模式
            if self.gradient_checkpointing and self.training:
                # 如果使用缓存,记录警告并禁用缓存
                if use_cache:
                    logger.warning_once(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
    
            # 初始化存储自注意力和隐藏状态的变量
            all_self_attentions = None
            all_hidden_states = () if output_hidden_states else None
            # 遍历每一层
            for index in range(self.num_layers):
                # 如果输出隐藏状态,记录当前隐藏状态
                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)
    
                # 获取当前层
                layer = self._get_layer(index)
                # 如果启用梯度检查点且处于训练模式
                if self.gradient_checkpointing and self.training:
                    # 使用检查点计算层的输出
                    layer_ret = torch.utils.checkpoint.checkpoint(
                        layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                    )
                else:
                    # 正常计算层的输出
                    layer_ret = layer(
                        hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache
                    )
                # 解包层输出
                hidden_states, kv_cache = layer_ret
                # 如果使用缓存,记录 kv_cache
                if use_cache:
                    presents = presents + (kv_cache,)
    
            # 如果输出隐藏状态,记录最后的隐藏状态
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
    
            # 如果启用后层归一化
            if self.post_layer_norm:
                # 应用最终的层归一化
                hidden_states = self.final_layernorm(hidden_states)
    
            # 返回最终的隐藏状态、缓存和所有隐藏状态及自注意力
            return hidden_states, presents, all_hidden_states, all_self_attentions
# 定义一个抽象类,用于处理权重初始化和下载、加载预训练模型的接口
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # 设置是否可并行化
    is_parallelizable = False
    # 支持梯度检查点
    supports_gradient_checkpointing = True
    # 配置类
    config_class = ChatGLMConfig
    # 基础模型前缀
    base_model_prefix = "transformer"
    # 不可拆分的模块列表
    _no_split_modules = ["GLMBlock"]

    # 初始化权重的方法
    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    # 获取掩码的方法
    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        # 获取输入的批次大小和序列长度
        batch_size, seq_length = input_ids.shape
        # 创建一个全为1的注意力掩码
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        # 只保留下三角部分的注意力掩码
        full_attention_mask.tril_()
        past_length = 0
        # 如果有过去的键值,获取其长度
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        # 如果有过去的长度,拼接全为1的掩码
        if past_length:
            full_attention_mask = torch.cat(
                (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1
            )
        # 如果提供了填充掩码,更新注意力掩码
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        # 如果没有过去的长度并且有填充掩码,调整掩码
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        # 将掩码转换为布尔值
        full_attention_mask = (full_attention_mask < 0.5).bool()
        # 增加维度
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    # 获取位置ID的方法
    def get_position_ids(self, input_ids, device):
        # 获取批次大小和序列长度
        batch_size, seq_length = input_ids.shape
        # 创建位置ID
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    # 设置梯度检查点的方法
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模块是GLMTransformer,设置其梯度检查点
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value


# 默认初始化方法
def default_init(cls, *args, **kwargs):
    # 创建类的实例
    return cls(*args, **kwargs)


# 定义一个嵌入类
class Embedding(torch.nn.Module):
    """Language model embeddings."""

    # 初始化方法
    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        # 获取隐藏层大小
        self.hidden_size = config.hidden_size
        # 创建词嵌入层
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
        )
        # 获取是否使用fp32残差连接的配置
        self.fp32_residual_connection = config.fp32_residual_connection

    # 前向传播方法
    def forward(self, input_ids):
        # 获取词嵌入
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # 转置数据格式以避免显式转置
        embeddings = embeddings.transpose(0, 1).contiguous()
        # 如果启用fp32残差连接,转换为浮点数
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings


# 定义一个旋转嵌入类
class RotaryEmbedding(nn.Module):
    # 初始化方法,用于设置类的初始状态
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        # 调用父类的初始化方法
        super().__init__()
        # 计算逆频率,公式为 1/(10000^(2i/d)),用于位置编码
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        # 将计算得到的逆频率注册为缓冲区,以便在模型保存和加载时保持
        self.register_buffer("inv_freq", inv_freq)
        # 存储维度参数
        self.dim = dim
        # 存储原始实现的标志
        self.original_impl = original_impl

    # 前向传播实现方法,计算位置编码
    def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
        """增强型变换器,带有旋转位置嵌入。

        来源: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT 许可证:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # 计算位置编码的 theta 值,公式为 1/(base^(2(i-1)/n_elem))
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # 创建位置索引,范围为 [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # 计算位置索引与 theta 的外积,生成位置编码
        idx_theta = torch.outer(seq_idx, theta).float()

        # 计算缓存,将余弦和正弦值按最后一个维度堆叠
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # 为了模拟 complex32 的行为,避免不同结果,进行数据类型转换
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            # 如果 dtype 为 bfloat16,则转换缓存为 bfloat16;否则转换为 half
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        # 返回计算得到的缓存
        return cache

    # 前向传播方法,接收最大序列长度和偏移量
    def forward(self, max_seq_len, offset=0):
        # 调用 forward_impl 方法,传入相应参数
        return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
# 定义一个前缀编码器类,继承自 PyTorch 的 nn.Module
class PrefixEncoder(torch.nn.Module):
    """
    前缀编码的 PyTorch nn 模型 输入形状: (batch-size, prefix-length) 输出形状: (batch-size,
    prefix-length, 2*layers*hidden)
    """

    # 初始化函数,接受配置对象
    def __init__(self, config: ChatGLMConfig):
        # 调用父类初始化
        super().__init__()
        # 获取前缀投影的配置
        self.prefix_projection = config.prefix_projection
        # 如果启用了前缀投影
        if self.prefix_projection:
            # 使用两层 MLP 编码前缀
            kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            # 创建嵌入层,输出维度为 kv_size
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            # 创建一个顺序的网络结构,包括两层线性变换和一个 Tanh 激活函数
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),  # 第一层线性变换
                torch.nn.Tanh(),  # Tanh 激活函数
                torch.nn.Linear(config.hidden_size, kv_size),  # 第二层线性变换
            )
        else:
            # 如果没有前缀投影,直接创建嵌入层,输出维度为 num_layers * kv_channels * multi_query_group_num * 2
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2
            )

    # 前向传播函数,接受前缀张量
    def forward(self, prefix: torch.Tensor):
        # 如果启用了前缀投影
        if self.prefix_projection:
            # 使用嵌入层对前缀进行编码
            prefix_tokens = self.embedding(prefix)
            # 通过转换网络得到过去的键值对
            past_key_values = self.trans(prefix_tokens)
        else:
            # 直接通过嵌入层得到过去的键值对
            past_key_values = self.embedding(prefix)
        # 返回过去的键值对
        return past_key_values


# 定义 ChatGLM 模型类,继承自 ChatGLMPreTrainedModel
class ChatGLMModel(ChatGLMPreTrainedModel):
    # 初始化函数,接受配置对象、设备和空初始化标志
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        # 调用父类初始化
        super().__init__(config)
        # 根据空初始化标志选择初始化方法
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        # 如果指定了设备,将设备信息添加到初始化参数
        if device is not None:
            init_kwargs["device"] = device
        # 初始化嵌入层
        self.embedding = init_method(Embedding, config, **init_kwargs)
        # 保存层数、查询组数和键值通道数
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # 旋转位置嵌入的序列长度
        self.seq_length = config.seq_length
        # 计算旋转嵌入的维度
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        # 创建旋转嵌入对象
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
        )
        # 初始化 GLMTransformer 编码器
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        # 初始化输出层,线性变换
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )
        # 获取前缀序列长度
        self.pre_seq_len = config.pre_seq_len
        # 获取前缀投影的配置
        self.prefix_projection = config.prefix_projection
        # 如果前缀序列长度不为空
        if self.pre_seq_len is not None:
            # 将所有参数的梯度计算标志设置为 False
            for param in self.parameters():
                param.requires_grad = False
            # 创建前缀 token 的张量
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            # 创建前缀编码器对象
            self.prefix_encoder = PrefixEncoder(config)
            # 创建 dropout 层,丢弃率为 0.1
            self.dropout = torch.nn.Dropout(0.1)
    # 获取输入的嵌入层
        def get_input_embeddings(self):
            # 返回嵌入层中的单词嵌入
            return self.embedding.word_embeddings
    
        # 获取提示信息,供模型使用
        def get_prompt(self, batch_size, device, dtype=torch.half):
            # 扩展前缀标记以匹配批量大小,并移动到指定设备
            prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
            # 通过前缀编码器处理前缀标记并转换为指定数据类型
            past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
            # 将处理后的数据重塑为特定形状以符合模型需求
            past_key_values = past_key_values.view(
                batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels
            )
            # 应用丢弃层以防止过拟合
            past_key_values = self.dropout(past_key_values)
            # 调整维度顺序并分割为多个张量以供后续使用
            past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
            # 返回处理后的过去键值对
            return past_key_values
    
        # 定义前向传播方法
        def forward(
            self,
            input_ids,
            # 可选的位置 ID,用于编码输入的位置信息
            position_ids: Optional[torch.Tensor] = None,
            # 可选的注意力掩码,用于屏蔽输入中的无效位置
            attention_mask: Optional[torch.BoolTensor] = None,
            # 完整注意力掩码,用于更复杂的注意力机制
            full_attention_mask: Optional[torch.BoolTensor] = None,
            # 可选的过去键值对,用于缓存上一次的计算结果
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            # 可选的输入嵌入,替代 input_ids 使用
            inputs_embeds: Optional[torch.Tensor] = None,
            # 可选的缓存使用标志
            use_cache: Optional[bool] = None,
            # 可选的隐藏状态输出标志
            output_hidden_states: Optional[bool] = None,
            # 可选的返回字典的标志
            return_dict: Optional[bool] = None,
    ):
        # 如果 output_hidden_states 没有指定,则使用配置中的默认值
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # 如果 use_cache 没有指定,则使用配置中的默认值
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        # 如果 return_dict 没有指定,则使用配置中的默认值
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 获取输入的 batch_size 和 seq_length
        batch_size, seq_length = input_ids.shape

        # 如果没有输入的嵌入,则使用输入的 ID 生成嵌入
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # 检查预序列长度
        if self.pre_seq_len is not None:
            # 如果过去的键值对为空,则获取提示信息
            if past_key_values is None:
                past_key_values = self.get_prompt(
                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
                )
            # 如果存在注意力掩码,则将预序列的掩码添加到前面
            if attention_mask is not None:
                attention_mask = torch.cat(
                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
                )

        # 如果全注意力掩码为空
        if full_attention_mask is None:
            # 如果存在注意力掩码且不是全为 1,或者过去的键值对存在且序列长度不为 1,则获取掩码
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # 计算旋转位置嵌入
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        # 如果指定了位置 ID,则根据位置 ID 索引旋转嵌入
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            # 如果没有位置 ID,则使用序列长度生成旋转嵌入
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        # 转置并确保连续性
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # 运行编码器
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        # 如果不返回字典,则返回非 None 的元组
        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        # 返回包含隐藏状态、过去的键值、所有隐藏状态和注意力的自定义输出对象
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

.\diffusers\pipelines\kolors\tokenizer.py

# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# 许可信息,声明版权和许可证条款
# Licensed under the Apache License, Version 2.0 (the "License");
# 在遵守许可证的前提下才能使用此文件
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 在没有适用的法律或书面协议情况下,软件以“按现状”方式分发
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以获取具体的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入所需的库
import json
import os
import re
from typing import Dict, List, Optional, Union

# 从 SentencePiece 导入处理器
from sentencepiece import SentencePieceProcessor
# 从 transformers 导入预训练的 tokenizer
from transformers import PreTrainedTokenizer
# 导入批处理编码和编码输入的工具
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
# 导入填充策略
from transformers.utils import PaddingStrategy

# 定义 SPTokenizer 类
class SPTokenizer:
    # 初始化函数,接收模型路径
    def __init__(self, model_path: str):
        # 断言模型文件存在
        assert os.path.isfile(model_path), model_path
        # 通过模型文件加载 SentencePiece 处理器
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # 获取 BOS / EOS token 的 ID
        self.n_words: int = self.sp_model.vocab_size()  # 词汇表大小
        self.bos_id: int = self.sp_model.bos_id()      # BOS token ID
        self.eos_id: int = self.sp_model.eos_id()      # EOS token ID
        self.pad_id: int = self.sp_model.unk_id()      # PAD token ID
        # 确保词汇表大小与片段大小相同
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        # 定义角色特定的特殊 tokens
        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        # 定义其他特殊 tokens
        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
        # 初始化特殊 tokens 的字典和索引
        self.special_tokens = {}
        self.index_special_tokens = {}
        # 为每个特殊 token 分配唯一的 ID
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1
        # 将角色特殊 tokens 组成正则表达式
        self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])

    # 对输入字符串进行分词
    def tokenize(self, s: str, encode_special_tokens=False):
        # 如果需要编码特殊 tokens
        if encode_special_tokens:
            last_index = 0
            t = []
            # 查找匹配的角色特殊 tokens
            for match in re.finditer(self.role_special_token_expression, s):
                # 如果有普通文本,先编码它
                if last_index < match.start():
                    t.extend(self.sp_model.EncodeAsPieces(s[last_index : match.start()]))
                # 添加匹配的特殊 token
                t.append(s[match.start() : match.end()])
                last_index = match.end()
            # 编码最后一段普通文本
            if last_index < len(s):
                t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
            return t
        else:
            # 如果不需要编码特殊 tokens,直接编码整个字符串
            return self.sp_model.EncodeAsPieces(s)

    # 对输入字符串进行编码,返回 token ID 列表
    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        # 断言输入为字符串
        assert isinstance(s, str)
        # 使用 SentencePiece 编码字符串
        t = self.sp_model.encode(s)
        # 如果需要,添加 BOS token ID
        if bos:
            t = [self.bos_id] + t
        # 如果需要,添加 EOS token ID
        if eos:
            t = t + [self.eos_id]
        return t
    # 定义解码函数,将一组整数标记转换为字符串
    def decode(self, t: List[int]) -> str:
        # 初始化解码后的文本和一个缓冲区列表
        text, buffer = "", []
        # 遍历每个标记
        for token in t:
            # 检查当前标记是否为特殊标记
            if token in self.index_special_tokens:
                # 如果缓冲区不为空,解码缓冲区中的标记并添加到文本中
                if buffer:
                    text += self.sp_model.decode(buffer)
                    # 清空缓冲区
                    buffer = []
                # 将特殊标记对应的文本添加到解码文本中
                text += self.index_special_tokens[token]
            else:
                # 将普通标记添加到缓冲区
                buffer.append(token)
        # 如果缓冲区仍然有标记,解码缓冲区中的标记并添加到文本中
        if buffer:
            text += self.sp_model.decode(buffer)
        # 返回解码后的文本
        return text
    
    # 定义将标记列表解码为字符串的函数
    def decode_tokens(self, tokens: List[str]) -> str:
        # 使用 sp_model 解码标记列表,返回解码结果
        text = self.sp_model.DecodePieces(tokens)
        # 返回解码后的文本
        return text
    
    # 定义将标记(字符串)转换为 ID 的函数
    def convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # 如果标记是特殊标记,则返回其对应的 ID
        if token in self.special_tokens:
            return self.special_tokens[token]
        # 否则,使用 sp_model 将标记转换为 ID
        return self.sp_model.PieceToId(token)
    
    # 定义将索引(整数)转换为标记(字符串)的函数
    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # 如果索引是特殊标记的索引,返回对应的标记
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        # 如果索引是结束标记、开始标记、填充标记,或小于 0,返回空字符串
        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        # 否则,使用 sp_model 将索引转换为标记
        return self.sp_model.IdToPiece(index)
# 定义一个名为 ChatGLMTokenizer 的类,继承自 PreTrainedTokenizer
class ChatGLMTokenizer(PreTrainedTokenizer):
    # 定义词汇文件名称,指定 tokenizer.model 为 vocab_file
    vocab_files_names = {"vocab_file": "tokenizer.model"}

    # 定义模型输入的名称,包括输入ID、注意力掩码和位置ID
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    # 初始化方法,接收词汇文件及其他可选参数
    def __init__(
        self,
        vocab_file,
        padding_side="left",  # 默认填充方向为左侧
        clean_up_tokenization_spaces=False,  # 是否清理标记化空间的选项
        encode_special_tokens=False,  # 是否编码特殊标记的选项
        **kwargs,  # 其他额外的关键字参数
    ):
        # 设置 tokenizer 的名称
        self.name = "GLMTokenizer"

        # 保存词汇文件的路径
        self.vocab_file = vocab_file
        # 使用词汇文件初始化 SPTokenizer
        self.tokenizer = SPTokenizer(vocab_file)
        # 定义特殊标记及其对应的ID
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,  # 句首标记
            "<eos>": self.tokenizer.eos_id,  # 句尾标记
            "<pad>": self.tokenizer.pad_id,  # 填充标记
        }
        # 保存是否编码特殊标记的选项
        self.encode_special_tokens = encode_special_tokens
        # 调用父类的初始化方法
        super().__init__(
            padding_side=padding_side,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            encode_special_tokens=encode_special_tokens,
            **kwargs,
        )

    # 根据传入的标记获取相应的命令ID
    def get_command(self, token):
        # 如果标记在特殊标记字典中,返回对应的ID
        if token in self.special_tokens:
            return self.special_tokens[token]
        # 确保传入的标记是有效的特殊标记
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        # 返回 tokenizer 中对应的特殊标记ID
        return self.tokenizer.special_tokens[token]

    # 属性,返回未知标记的字符串
    @property
    def unk_token(self) -> str:
        return "<unk>"

    # 设置未知标记的字符串
    @unk_token.setter
    def unk_token(self, value: str):
        self._unk_token = value

    # 属性,返回填充标记的字符串
    @property
    def pad_token(self) -> str:
        return "<unk>"

    # 设置填充标记的字符串
    @pad_token.setter
    def pad_token(self, value: str):
        self._pad_token = value

    # 属性,返回填充标记的ID
    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    # 属性,返回结束标记的字符串
    @property
    def eos_token(self) -> str:
        return "</s>"

    # 设置结束标记的字符串
    @eos_token.setter
    def eos_token(self, value: str):
        self._eos_token = value

    # 属性,返回结束标记的ID
    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    # 属性,返回词汇表的大小
    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    # 获取词汇表并返回为字典
    def get_vocab(self):
        """Returns vocab as a dict"""
        # 创建一个字典,将词汇ID映射到对应的标记
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        # 更新字典,包含添加的标记
        vocab.update(self.added_tokens_encoder)
        return vocab

    # 对输入文本进行标记化
    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

    # 将标记字符串转换为对应的ID
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.tokenizer.convert_token_to_id(token)

    # 将ID转换为对应的标记字符串
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    # 将标记列表转换为字符串
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)
    # 定义保存词汇和特殊标记文件的方法
    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        保存词汇和特殊标记文件到指定目录。

        参数:
            save_directory (`str`):
                要保存词汇的目录。
            filename_prefix (`str`, *可选*):
                保存文件名时添加的可选前缀。

        返回:
            `Tuple(str)`: 保存的文件路径。
        """
        # 检查保存目录是否存在
        if os.path.isdir(save_directory):
            # 如果目录存在,构建词汇文件的完整路径
            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
        else:
            # 如果目录不存在,使用提供的保存目录作为词汇文件路径
            vocab_file = save_directory

        # 以二进制读取模式打开当前的词汇文件
        with open(self.vocab_file, "rb") as fin:
            # 读取文件内容并存储为字节串
            proto_str = fin.read()

        # 以二进制写入模式打开目标词汇文件
        with open(vocab_file, "wb") as writer:
            # 将读取的内容写入到目标词汇文件
            writer.write(proto_str)

        # 返回保存的词汇文件路径
        return (vocab_file,)

    # 定义获取前缀标记的方法
    def get_prefix_tokens(self):
        # 获取特殊前缀标记
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        # 返回前缀标记列表
        return prefix_tokens

    # 定义构建单个消息的方法
    def build_single_message(self, role, metadata, message):
        # 确保角色是有效的选项之一
        assert role in ["system", "user", "assistant", "observation"], role
        # 根据角色构建角色标记和元数据的编码
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        # 编码消息内容
        message_tokens = self.tokenizer.encode(message)
        # 合并角色标记和消息标记
        tokens = role_tokens + message_tokens
        # 返回合并后的标记
        return tokens

    # 定义构建聊天输入的方法
    def build_chat_input(self, query, history=None, role="user"):
        # 如果历史记录为空,初始化为空列表
        if history is None:
            history = []
        # 初始化输入标识符列表
        input_ids = []
        # 遍历历史记录
        for item in history:
            # 获取内容
            content = item["content"]
            # 如果角色是系统并且有工具信息,将其添加到内容中
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            # 将构建的单个消息标记扩展到输入标识符列表中
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        # 将当前查询的消息标记添加到输入标识符列表中
        input_ids.extend(self.build_single_message(role, "", query))
        # 添加结束标记
        input_ids.extend([self.get_command("<|assistant|>")])
        # 返回经过批量编码后的输入标识符
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    # 定义构建带特殊标记的输入的方法
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    # 返回一个整数列表,构建序列分类任务的模型输入
    ) -> List[int]:
        # 文档字符串,说明该函数的作用和输入输出格式
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
    
        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`
    
        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
    
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        # 获取前缀特殊令牌
        prefix_tokens = self.get_prefix_tokens()
        # 将前缀令牌添加到第一个序列的 ID 列表中
        token_ids_0 = prefix_tokens + token_ids_0
        # 如果第二个序列存在,则将其添加到第一个序列中
        if token_ids_1 is not None:
            # 合并两个序列,并添加结束符令牌
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        # 返回包含特殊令牌的 ID 列表
        return token_ids_0
    
        # 定义一个私有函数,用于填充编码后的输入
        def _pad(
            self,
            # 编码输入的字典或批处理编码
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            # 最大长度,默认值为 None
            max_length: Optional[int] = None,
            # 填充策略,默认不填充
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            # 填充到的倍数,默认值为 None
            pad_to_multiple_of: Optional[int] = None,
            # 是否返回注意力掩码,默认值为 None
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        对编码后的输入进行填充(左右填充以及根据预定义长度或批次中的最大长度进行填充)

        参数:
            encoded_inputs:
                标记化输入的字典(`List[int]`)或标记化输入的批次(`List[List[int]]`)。
            max_length: 返回列表的最大长度以及可选的填充长度(见下文)。
                将通过考虑特殊标记来截断。
            padding_strategy: 填充策略,用于填充。

                - PaddingStrategy.LONGEST 填充到批次中最长的序列
                - PaddingStrategy.MAX_LENGTH: 填充到最大长度(默认)
                - PaddingStrategy.DO_NOT_PAD: 不进行填充
                标记器的填充方向由 self.padding_side 定义:

                    - 'left': 在序列的左侧进行填充
                    - 'right': 在序列的右侧进行填充
            pad_to_multiple_of: (可选)如果设置,将序列填充到提供值的倍数。
                这在启用 NVIDIA 硬件的 Tensor Core 使用时尤其有用,计算能力 `>= 7.5`(Volta)。
            return_attention_mask:
                (可选)设置为 False 以避免返回注意力掩码(默认值:根据模型具体情况设置)
        """
        # 从模型默认值加载
        assert self.padding_side == "left"  # 确保填充方向为左侧

        required_input = encoded_inputs[self.model_input_names[0]]  # 获取所需的输入数据
        seq_length = len(required_input)  # 计算输入序列的长度

        if padding_strategy == PaddingStrategy.LONGEST:  # 如果填充策略为最长
            max_length = len(required_input)  # 设置最大长度为输入的长度

        # 如果 max_length 和 pad_to_multiple_of 都被定义且 max_length 不是 pad_to_multiple_of 的倍数
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            # 将 max_length 调整为 pad_to_multiple_of 的下一个倍数
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        # 判断是否需要填充:填充策略不为不填充且输入长度不等于最大长度
        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # 如果没有注意力掩码,则初始化注意力掩码
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length  # 填充为1,表示有效的输入

        # 如果没有位置 ID,则初始化位置 ID
        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))  # 填充为从0到序列长度的范围

        # 如果需要填充
        if needs_to_be_padded:
            difference = max_length - len(required_input)  # 计算需要填充的长度

            # 如果存在注意力掩码,则在前面填充0
            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            # 如果存在位置 ID,则在前面填充0
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            # 在输入数据前面填充 pad_token_id
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs  # 返回填充后的输入数据
posted @ 2024-10-22 12:33  绝不原创的飞龙  阅读(12)  评论(0编辑  收藏  举报