diffusers-源码解析-四十八-

diffusers 源码解析(四十八)

.\diffusers\pipelines\stable_diffusion_3\pipeline_stable_diffusion_3_img2img.py

# 版权声明,指定版权所有者及保留权利
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# 在 Apache License, Version 2.0 下授权(“许可证”);
# 除非遵循许可证的规定,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,
# 否则根据许可证分发的软件是以“原样”基础提供的,
# 不提供任何明示或暗示的保证或条件。
# 有关许可证规定的权限和限制,请参见许可证。

# 导入 inspect 模块,用于获取对象的各种信息
import inspect
# 从 typing 模块导入所需的类型注解
from typing import Callable, Dict, List, Optional, Union

# 导入 PIL.Image 库,用于图像处理
import PIL.Image
# 导入 PyTorch 库,深度学习框架
import torch
# 从 transformers 库导入所需的模型和分词器
from transformers import (
    CLIPTextModelWithProjection,  # CLIP 文本模型
    CLIPTokenizer,                 # CLIP 分词器
    T5EncoderModel,                # T5 编码器模型
    T5TokenizerFast,               # T5 快速分词器
)

# 从本地模块中导入图像处理和模型相关类
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin  # 导入 SD3 Lora 加载器混合类
from ...models.autoencoders import AutoencoderKL  # 导入自动编码器模型
from ...models.transformers import SD3Transformer2DModel  # 导入 SD3 2D 转换模型
from ...schedulers import FlowMatchEulerDiscreteScheduler  # 导入调度器
from ...utils import (
    USE_PEFT_BACKEND,              # 导入是否使用 PEFT 后端的标识
    is_torch_xla_available,        # 导入检查是否可用 Torch XLA 的函数
    logging,                       # 导入日志模块
    replace_example_docstring,     # 导入替换示例文档字符串的函数
    scale_lora_layers,             # 导入缩放 Lora 层的函数
    unscale_lora_layers,           # 导入取消缩放 Lora 层的函数
)
from ...utils.torch_utils import randn_tensor  # 导入生成随机张量的函数
from ..pipeline_utils import DiffusionPipeline  # 导入扩散管道类
from .pipeline_output import StableDiffusion3PipelineOutput  # 导入稳定扩散 3 的管道输出类


# 检查是否可用 Torch XLA,适用于 TPU
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # 导入 XLA 核心模块

    XLA_AVAILABLE = True  # 设置 XLA 可用标志为 True
else:
    XLA_AVAILABLE = False  # 设置 XLA 可用标志为 False


# 创建日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 示例文档字符串,展示如何使用该模块
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch

        >>> from diffusers import AutoPipelineForImage2Image
        >>> from diffusers.utils import load_image

        >>> device = "cuda"  # 设置设备为 CUDA
        >>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"  # 指定模型路径
        >>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)  # 从预训练模型加载管道
        >>> pipe = pipe.to(device)  # 将管道移动到指定设备

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"  # 指定输入图像 URL
        >>> init_image = load_image(url).resize((1024, 1024))  # 加载并调整图像大小

        >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"  # 设置生成的提示语

        >>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]  # 生成图像
        ```py
"""


# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 复制的函数
def retrieve_latents(
    encoder_output: torch.Tensor,  # 输入的编码器输出,类型为张量
    generator: Optional[torch.Generator] = None,  # 随机数生成器,默认为 None
    sample_mode: str = "sample"  # 采样模式,默认为 "sample"
):
    # 如果 encoder_output 具有 latent_dist 属性且采样模式为 "sample"
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        # 从 latent_dist 中进行采样,并返回结果
        return encoder_output.latent_dist.sample(generator)
    # 检查 encoder_output 是否具有 "latent_dist" 属性,并且 sample_mode 为 "argmax"
        elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
            # 返回 encoder_output 中 latent_dist 的众数
            return encoder_output.latent_dist.mode()
        # 检查 encoder_output 是否具有 "latents" 属性
        elif hasattr(encoder_output, "latents"):
            # 返回 encoder_output 中的 latents 属性
            return encoder_output.latents
        # 如果以上条件都不满足,抛出属性错误
        else:
            raise AttributeError("Could not access latents of provided encoder_output")
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion 中复制的代码
def retrieve_timesteps(
    scheduler,  # 调度器,用于获取时间步
    num_inference_steps: Optional[int] = None,  # 推断步骤数量,默认为 None
    device: Optional[Union[str, torch.device]] = None,  # 设备类型,默认为 None
    timesteps: Optional[List[int]] = None,  # 自定义时间步,默认为 None
    sigmas: Optional[List[float]] = None,  # 自定义 sigma 值,默认为 None
    **kwargs,  # 其他关键字参数
):
    """
    调用调度器的 `set_timesteps` 方法,并在调用后从调度器中检索时间步。处理
    自定义时间步。所有关键字参数将传递给 `scheduler.set_timesteps`。

    参数:
        scheduler (`SchedulerMixin`):
            用于获取时间步的调度器。
        num_inference_steps (`int`):
            生成样本时使用的扩散步骤数。如果使用,则 `timesteps`
            必须为 `None`。
        device (`str` 或 `torch.device`, *可选*):
            要将时间步移动到的设备。如果为 `None`,则时间步不移动。
        timesteps (`List[int]`, *可选*):
            用于覆盖调度器的时间步间距策略的自定义时间步。如果传递 `timesteps`,
            `num_inference_steps` 和 `sigmas` 必须为 `None`。
        sigmas (`List[float]`, *可选*):
            用于覆盖调度器的时间步间距策略的自定义 sigma。如果传递 `sigmas`,
            `num_inference_steps` 和 `timesteps` 必须为 `None`。

    返回:
        `Tuple[torch.Tensor, int]`: 一个元组,第一个元素是来自调度器的时间步调度,
        第二个元素是推断步骤的数量。
    """
    # 检查是否同时传入了自定义时间步和 sigma,若是则抛出异常
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    
    # 如果传入了自定义时间步
    if timesteps is not None:
        # 检查调度器的 `set_timesteps` 方法是否接受自定义时间步
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不支持自定义时间步,抛出异常
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的 `set_timesteps` 方法设置自定义时间步
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        # 从调度器中获取设置后的时间步
        timesteps = scheduler.timesteps
        # 计算推断步骤数量
        num_inference_steps = len(timesteps)
    
    # 如果传入了自定义 sigma
    elif sigmas is not None:
        # 检查调度器的 `set_timesteps` 方法是否接受自定义 sigma
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不支持自定义 sigma,抛出异常
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        # 调用调度器的 `set_timesteps` 方法设置自定义 sigma
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        # 从调度器中获取设置后的时间步
        timesteps = scheduler.timesteps
        # 计算推断步骤数量
        num_inference_steps = len(timesteps)
    else:  # 如果不满足前面的条件,则执行以下代码
        # 设置调度器的时间步数,传入推理步数和设备信息,以及其他可选参数
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        # 获取调度器的时间步数并赋值给变量 timesteps
        timesteps = scheduler.timesteps
    # 返回时间步数和推理步数
    return timesteps, num_inference_steps
# 定义一个名为 StableDiffusion3Img2ImgPipeline 的类,继承自 DiffusionPipeline
class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
    r"""
    Args:
        transformer ([`SD3Transformer2DModel`]):
            Conditional Transformer (MMDiT) 结构,用于对编码后的图像潜变量进行去噪。
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            一个调度器,结合 `transformer` 用于对编码后的图像潜变量进行去噪。
        vae ([`AutoencoderKL`]):
            变分自编码器 (VAE) 模型,用于在潜在表示和图像之间进行编码和解码。
        text_encoder ([`CLIPTextModelWithProjection`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            特别是 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 变体,
            并添加了一个投影层,该层使用对角矩阵初始化,维度为 `hidden_size`。
        text_encoder_2 ([`CLIPTextModelWithProjection`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            特别是
            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
            变体。
        text_encoder_3 ([`T5EncoderModel`]):
            冻结的文本编码器。Stable Diffusion 3 使用
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel),特别是
            [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) 变体。
        tokenizer (`CLIPTokenizer`):
            类
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) 的标记器。
        tokenizer_2 (`CLIPTokenizer`):
            类
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) 的第二个标记器。
        tokenizer_3 (`T5TokenizerFast`):
            类
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer) 的标记器。
    """

    # 定义一个字符串,表示模型组件的加载顺序
    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
    # 定义一个可选组件列表,初始为空
    _optional_components = []
    # 定义一个回调张量输入列表,包含潜变量和提示嵌入等
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    # 初始化方法,定义所需的组件
    def __init__(
        self,
        # 定义 transformer 参数,类型为 SD3Transformer2DModel
        transformer: SD3Transformer2DModel,
        # 定义 scheduler 参数,类型为 FlowMatchEulerDiscreteScheduler
        scheduler: FlowMatchEulerDiscreteScheduler,
        # 定义 vae 参数,类型为 AutoencoderKL
        vae: AutoencoderKL,
        # 定义 text_encoder 参数,类型为 CLIPTextModelWithProjection
        text_encoder: CLIPTextModelWithProjection,
        # 定义 tokenizer 参数,类型为 CLIPTokenizer
        tokenizer: CLIPTokenizer,
        # 定义第二个 text_encoder 参数,类型为 CLIPTextModelWithProjection
        text_encoder_2: CLIPTextModelWithProjection,
        # 定义第二个 tokenizer 参数,类型为 CLIPTokenizer
        tokenizer_2: CLIPTokenizer,
        # 定义第三个 text_encoder 参数,类型为 T5EncoderModel
        text_encoder_3: T5EncoderModel,
        # 定义第三个 tokenizer 参数,类型为 T5TokenizerFast
        tokenizer_3: T5TokenizerFast,
    ):
        # 调用父类的构造函数进行初始化
        super().__init__()

        # 注册多个模块,方便在后续操作中使用
        self.register_modules(
            # 注册变分自编码器模块
            vae=vae,
            # 注册文本编码器模块
            text_encoder=text_encoder,
            # 注册第二个文本编码器模块
            text_encoder_2=text_encoder_2,
            # 注册第三个文本编码器模块
            text_encoder_3=text_encoder_3,
            # 注册标记器模块
            tokenizer=tokenizer,
            # 注册第二个标记器模块
            tokenizer_2=tokenizer_2,
            # 注册第三个标记器模块
            tokenizer_3=tokenizer_3,
            # 注册变换器模块
            transformer=transformer,
            # 注册调度器模块
            scheduler=scheduler,
        )
        # 计算 VAE 的缩放因子,基于块输出通道的数量
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 创建图像处理器,传入 VAE 的缩放因子和潜在通道数
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
        )
        # 获取标记器的最大长度
        self.tokenizer_max_length = self.tokenizer.model_max_length
        # 获取变换器的默认样本大小
        self.default_sample_size = self.transformer.config.sample_size

    # 从 diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds 复制的方法
    def _get_t5_prompt_embeds(
        self,
        # 输入的提示,可以是字符串或字符串列表
        prompt: Union[str, List[str]] = None,
        # 每个提示生成的图像数量
        num_images_per_prompt: int = 1,
        # 最大序列长度
        max_sequence_length: int = 256,
        # 设备类型,可选
        device: Optional[torch.device] = None,
        # 数据类型,可选
        dtype: Optional[torch.dtype] = None,
    # 方法的定义,接受多个参数
        ):
            # 如果没有指定设备,则使用类中定义的执行设备
            device = device or self._execution_device
            # 如果没有指定数据类型,则使用文本编码器的数据类型
            dtype = dtype or self.text_encoder.dtype
    
            # 如果提示为字符串,则转换为列表形式;否则保持原样
            prompt = [prompt] if isinstance(prompt, str) else prompt
            # 获取提示的批处理大小,即提示的数量
            batch_size = len(prompt)
    
            # 如果没有文本编码器 3,则返回一个全零的张量
            if self.text_encoder_3 is None:
                return torch.zeros(
                    # 返回形状为 (批处理大小 * 每个提示的图像数量, 最大序列长度, 联合注意力维度)
                    (
                        batch_size * num_images_per_prompt,
                        self.tokenizer_max_length,
                        self.transformer.config.joint_attention_dim,
                    ),
                    # 指定设备和数据类型
                    device=device,
                    dtype=dtype,
                )
    
            # 使用文本编码器 3 对提示进行编码,返回张量格式
            text_inputs = self.tokenizer_3(
                prompt,
                padding="max_length",
                max_length=max_sequence_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            # 获取输入的 ID
            text_input_ids = text_inputs.input_ids
            # 获取未截断的 ID,用于检测是否有内容被截断
            untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
    
            # 检查是否未截断 ID 的长度大于或等于输入 ID 的长度,并且两者不相等
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
                # 解码被截断的文本并发出警告
                removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to "
                    f" {max_sequence_length} tokens: {removed_text}"
                )
    
            # 获取文本输入的嵌入
            prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
    
            # 更新数据类型为文本编码器 3 的数据类型
            dtype = self.text_encoder_3.dtype
            # 将嵌入转换为指定的数据类型和设备
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    
            # 获取嵌入的形状信息
            _, seq_len, _ = prompt_embeds.shape
    
            # 为每个提示生成的图像复制文本嵌入
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            # 调整嵌入形状,以便于处理
            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    
            # 返回最终的文本嵌入
            return prompt_embeds
    
        # 从 StableDiffusion3Pipeline 类复制的方法,用于获取 CLIP 提示嵌入
        def _get_clip_prompt_embeds(
            self,
            prompt: Union[str, List[str]],
            num_images_per_prompt: int = 1,
            device: Optional[torch.device] = None,
            clip_skip: Optional[int] = None,
            clip_model_index: int = 0,
    # 设备设置,如果未指定,则使用默认执行设备
        ):
            device = device or self._execution_device
    
            # 定义 CLIP 使用的分词器
            clip_tokenizers = [self.tokenizer, self.tokenizer_2]
            # 定义 CLIP 使用的文本编码器
            clip_text_encoders = [self.text_encoder, self.text_encoder_2]
    
            # 根据给定的模型索引选择分词器
            tokenizer = clip_tokenizers[clip_model_index]
            # 根据给定的模型索引选择文本编码器
            text_encoder = clip_text_encoders[clip_model_index]
    
            # 如果 prompt 是字符串,则转为列表形式
            prompt = [prompt] if isinstance(prompt, str) else prompt
            # 获取 prompt 的批处理大小
            batch_size = len(prompt)
    
            # 使用选择的分词器对 prompt 进行编码
            text_inputs = tokenizer(
                prompt,
                padding="max_length",  # 填充到最大长度
                max_length=self.tokenizer_max_length,  # 最大长度限制
                truncation=True,  # 允许截断
                return_tensors="pt",  # 返回 PyTorch 张量
            )
    
            # 提取编码后的输入 ID
            text_input_ids = text_inputs.input_ids
            # 进行最长填充以获取未截断的 ID
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            # 检查未截断的 ID 是否比当前输入 ID 更长且不相等
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
                # 解码被截断的部分,并记录警告
                removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer_max_length} tokens: {removed_text}"
                )
            # 使用文本编码器生成 prompt 的嵌入
            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
            # 获取池化后的 prompt 嵌入
            pooled_prompt_embeds = prompt_embeds[0]
    
            # 判断是否跳过某些隐藏状态
            if clip_skip is None:
                # 使用倒数第二个隐藏状态作为嵌入
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # 根据 clip_skip 使用相应的隐藏状态
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
    
            # 将嵌入转换为所需的数据类型和设备
            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
    
            # 获取嵌入的形状
            _, seq_len, _ = prompt_embeds.shape
            # 针对每个 prompt 复制文本嵌入,使用适应 MPS 的方法
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            # 重新调整嵌入的形状
            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    
            # 复制池化后的嵌入
            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            # 重新调整池化嵌入的形状
            pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
    
            # 返回最终的 prompt 嵌入和池化后的嵌入
            return prompt_embeds, pooled_prompt_embeds
    
        # 从 diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt 复制的内容
    # 定义一个编码提示的函数,接收多个参数
        def encode_prompt(
            # 第一个提示,可以是字符串或字符串列表
            self,
            prompt: Union[str, List[str]],
            # 第二个提示,可以是字符串或字符串列表
            prompt_2: Union[str, List[str]],
            # 第三个提示,可以是字符串或字符串列表
            prompt_3: Union[str, List[str]],
            # 设备选项,默认为 None
            device: Optional[torch.device] = None,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: int = 1,
            # 是否进行分类器自由引导,默认为 True
            do_classifier_free_guidance: bool = True,
            # 负提示,可以是字符串或字符串列表,默认为 None
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 第二个负提示,可以是字符串或字符串列表,默认为 None
            negative_prompt_2: Optional[Union[str, List[str]]] = None,
            # 第三个负提示,可以是字符串或字符串列表,默认为 None
            negative_prompt_3: Optional[Union[str, List[str]]] = None,
            # 提示嵌入,默认为 None
            prompt_embeds: Optional[torch.FloatTensor] = None,
            # 负提示嵌入,默认为 None
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            # 池化后的提示嵌入,默认为 None
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            # 池化后的负提示嵌入,默认为 None
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            # 可选的跳过参数,默认为 None
            clip_skip: Optional[int] = None,
            # 最大序列长度,默认为 256
            max_sequence_length: int = 256,
            # 可选的 Lora 缩放参数,默认为 None
            lora_scale: Optional[float] = None,
        # 定义一个检查输入的函数,接收多个参数
        def check_inputs(
            # 第一个提示
            self,
            prompt,
            # 第二个提示
            prompt_2,
            # 第三个提示
            prompt_3,
            # 强度参数
            strength,
            # 负提示,默认为 None
            negative_prompt=None,
            # 第二个负提示,默认为 None
            negative_prompt_2=None,
            # 第三个负提示,默认为 None
            negative_prompt_3=None,
            # 提示嵌入,默认为 None
            prompt_embeds=None,
            # 负提示嵌入,默认为 None
            negative_prompt_embeds=None,
            # 池化后的提示嵌入,默认为 None
            pooled_prompt_embeds=None,
            # 池化后的负提示嵌入,默认为 None
            negative_pooled_prompt_embeds=None,
            # 步骤结束时的回调输入,默认为 None
            callback_on_step_end_tensor_inputs=None,
            # 最大序列长度,默认为 None
            max_sequence_length=None,
        # 定义获取时间步的函数,接收推理步骤数、强度和设备参数
        def get_timesteps(self, num_inference_steps, strength, device):
            # 计算初始化时间步的原始值
            init_timestep = min(num_inference_steps * strength, num_inference_steps)
    
            # 计算开始的时间步
            t_start = int(max(num_inference_steps - init_timestep, 0))
            # 从调度器获取时间步
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果调度器具有设置开始索引的属性,设置开始索引
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回时间步和剩余的推理步骤数
            return timesteps, num_inference_steps - t_start
    # 准备潜在向量,用于图像生成的前处理
    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        # 检查输入的图像类型是否为 torch.Tensor, PIL.Image.Image 或列表
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            # 抛出类型错误,提示用户输入的图像类型不正确
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # 将图像转换为指定设备和数据类型
        image = image.to(device=device, dtype=dtype)

        # 计算有效批次大小
        batch_size = batch_size * num_images_per_prompt
        # 如果图像的通道数与 VAE 的潜在通道数相同,则初始化潜在向量为图像
        if image.shape[1] == self.vae.config.latent_channels:
            init_latents = image

        else:
            # 如果生成器是列表且其长度与批次大小不符,抛出错误
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                # 对每个图像生成潜在向量,并将结果合并成一个张量
                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                # 在第0维上拼接所有潜在向量
                init_latents = torch.cat(init_latents, dim=0)
            else:
                # 对单个图像生成潜在向量
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            # 根据 VAE 配置调整潜在向量的值
            init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # 如果要求的批次大小大于初始化的潜在向量数量且可以整除
        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # 为批次大小扩展初始化潜在向量
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            # 通过复制初始化潜在向量来增加批次大小
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        # 如果要求的批次大小大于初始化的潜在向量数量且不能整除
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            # 抛出错误,提示无法复制图像以满足批次大小
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            # 确保潜在向量为二维,方便后续处理
            init_latents = torch.cat([init_latents], dim=0)

        # 获取潜在向量的形状
        shape = init_latents.shape
        # 生成与潜在向量形状相同的随机噪声张量
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # 获取潜在向量,通过调度器缩放噪声
        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
        # 将潜在向量转换为指定设备和数据类型
        latents = init_latents.to(device=device, dtype=dtype)

        # 返回处理后的潜在向量
        return latents

    # 返回当前的指导比例
    @property
    def guidance_scale(self):
        return self._guidance_scale

    # 返回当前的剪辑跳过值
    @property
    def clip_skip(self):
        return self._clip_skip

    # 判断是否进行无分类器引导,依据指导比例
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    # 返回当前的时间步数
    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    # 定义一个方法以返回中断状态
        def interrupt(self):
            # 返回中断标志的值
            return self._interrupt
    
        # 禁用梯度计算以节省内存和计算
        @torch.no_grad()
        # 替换文档字符串以提供示例文档
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        # 定义可调用方法,处理各种输入参数
        def __call__(
            # 主提示文本,可以是单个字符串或字符串列表
            prompt: Union[str, List[str]] = None,
            # 第二个提示文本,可选
            prompt_2: Optional[Union[str, List[str]]] = None,
            # 第三个提示文本,可选
            prompt_3: Optional[Union[str, List[str]]] = None,
            # 输入图像,类型为管道图像输入
            image: PipelineImageInput = None,
            # 强度参数,默认值为0.6
            strength: float = 0.6,
            # 推理步骤数,默认值为50
            num_inference_steps: int = 50,
            # 时间步长列表,可选
            timesteps: List[int] = None,
            # 引导比例,默认值为7.0
            guidance_scale: float = 7.0,
            # 负提示文本,可选
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 第二个负提示文本,可选
            negative_prompt_2: Optional[Union[str, List[str]]] = None,
            # 第三个负提示文本,可选
            negative_prompt_3: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为1
            num_images_per_prompt: Optional[int] = 1,
            # 随机数生成器,可选
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜在变量,类型为FloatTensor,可选
            latents: Optional[torch.FloatTensor] = None,
            # 提示嵌入,类型为FloatTensor,可选
            prompt_embeds: Optional[torch.FloatTensor] = None,
            # 负提示嵌入,类型为FloatTensor,可选
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
            # 池化后的提示嵌入,类型为FloatTensor,可选
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            # 负池化提示嵌入,类型为FloatTensor,可选
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
            # 输出类型,默认为"pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典,默认为True
            return_dict: bool = True,
            # 跳过的剪辑层数,可选
            clip_skip: Optional[int] = None,
            # 在步骤结束时的回调函数,可选
            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
            # 步骤结束时的张量输入回调列表,默认包含"latents"
            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
            # 最大序列长度,默认为256
            max_sequence_length: int = 256,

.\diffusers\pipelines\stable_diffusion_3\pipeline_stable_diffusion_3_inpaint.py

# 版权声明,包含版权持有者及其授权信息
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版进行授权
# 该文件仅可在遵循许可的情况下使用
# 许可证的副本可以在以下地址获得
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面协议另有约定,否则以 "按原样" 基础分发软件,
# 不提供任何形式的担保或条件
# 查看许可证以获取特定语言的权限和限制

import inspect  # 导入 inspect 模块以检查对象
from typing import Callable, Dict, List, Optional, Union  # 导入类型提示相关的类

import torch  # 导入 PyTorch 库
from transformers import (  # 从 transformers 库导入必要的类
    CLIPTextModelWithProjection,  # 导入 CLIP 文本模型类
    CLIPTokenizer,  # 导入 CLIP 词元化工具
    T5EncoderModel,  # 导入 T5 编码器模型
    T5TokenizerFast,  # 导入快速 T5 词元化工具
)

from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 导入回调相关类
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 导入图像处理类
from ...loaders import SD3LoraLoaderMixin  # 导入 Lora 加载器混合类
from ...models.autoencoders import AutoencoderKL  # 导入自动编码器类
from ...models.transformers import SD3Transformer2DModel  # 导入 2D 转换模型
from ...schedulers import FlowMatchEulerDiscreteScheduler  # 导入调度器类
from ...utils import (  # 导入实用工具
    USE_PEFT_BACKEND,  # 导入 PEFT 后端标志
    is_torch_xla_available,  # 导入检查 Torch XLA 可用性的函数
    logging,  # 导入日志模块
    replace_example_docstring,  # 导入替换示例文档字符串的函数
    scale_lora_layers,  # 导入缩放 Lora 层的函数
    unscale_lora_layers,  # 导入取消缩放 Lora 层的函数
)
from ...utils.torch_utils import randn_tensor  # 导入随机张量生成函数
from ..pipeline_utils import DiffusionPipeline  # 导入扩散管道类
from .pipeline_output import StableDiffusion3PipelineOutput  # 导入稳定扩散输出类


# 检查 Torch XLA 是否可用,导入相应模块
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # 导入 XLA 模型核心模块

    XLA_AVAILABLE = True  # 设置 XLA 可用标志
else:
    XLA_AVAILABLE = False  # 设置 XLA 不可用标志


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,禁用 pylint 命名警告

EXAMPLE_DOC_STRING = """  # 定义示例文档字符串
    Examples:  # 示例说明
        ```py  # 开始代码块
        >>> import torch  # 导入 PyTorch 库
        >>> from diffusers import StableDiffusion3InpaintPipeline  # 导入稳定扩散修复管道类
        >>> from diffusers.utils import load_image  # 导入加载图像的实用工具

        >>> pipe = StableDiffusion3InpaintPipeline.from_pretrained(  # 从预训练模型加载管道
        ...     "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16  # 指定模型名称及数据类型
        ... )
        >>> pipe.to("cuda")  # 将管道转移到 CUDA 设备
        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"  # 定义生成图像的提示
        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"  # 定义源图像 URL
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"  # 定义掩模图像 URL
        >>> source = load_image(img_url)  # 加载源图像
        >>> mask = load_image(mask_url)  # 加载掩模图像
        >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0]  # 生成修复后的图像
        >>> image.save("sd3_inpainting.png")  # 保存生成的图像
        ```py  # 结束代码块
"""


# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 复制的函数
def retrieve_latents(  # 定义函数以检索潜在变量
    encoder_output: torch.Tensor,  # 输入为编码器输出张量
    generator: Optional[torch.Generator] = None,  # 可选参数,指定随机数生成器
    sample_mode: str = "sample"  # 指定采样模式,默认为 "sample"
):
    # 检查 encoder_output 是否有 latent_dist 属性,并且样本模式为 "sample"
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        # 从 latent_dist 中采样,使用指定的生成器
        return encoder_output.latent_dist.sample(generator)
    # 检查 encoder_output 是否有 latent_dist 属性,并且样本模式为 "argmax"
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        # 返回 latent_dist 的众数
        return encoder_output.latent_dist.mode()
    # 检查 encoder_output 是否有 latents 属性
    elif hasattr(encoder_output, "latents"):
        # 返回 latents 属性的值
        return encoder_output.latents
    # 如果以上条件都不满足,则抛出 AttributeError
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion 中复制的代码
def retrieve_timesteps(
    # 调度器,用于获取时间步
    scheduler,
    # 推理步骤的数量,默认为 None
    num_inference_steps: Optional[int] = None,
    # 要移动到的设备,默认为 None
    device: Optional[Union[str, torch.device]] = None,
    # 自定义时间步,默认为 None
    timesteps: Optional[List[int]] = None,
    # 自定义 sigma,默认为 None
    sigmas: Optional[List[float]] = None,
    # 其他关键字参数,传递给调度器的 set_timesteps 方法
    **kwargs,
):
    """
    调用调度器的 `set_timesteps` 方法,并在调用后从调度器获取时间步。处理
    自定义时间步。任何 kwargs 将被传递给 `scheduler.set_timesteps`。

    参数:
        scheduler (`SchedulerMixin`):
            用于获取时间步的调度器。
        num_inference_steps (`int`):
            用于生成样本的扩散步骤数。如果使用,则 `timesteps` 必须为 `None`。
        device (`str` 或 `torch.device`,*可选*):
            时间步应移动到的设备。如果为 `None`,则时间步不会移动。
        timesteps (`List[int]`,*可选*):
            自定义时间步,用于覆盖调度器的时间步间隔策略。如果传递 `timesteps`,
            则 `num_inference_steps` 和 `sigmas` 必须为 `None`。
        sigmas (`List[float]`,*可选*):
            自定义 sigma,用于覆盖调度器的时间步间隔策略。如果传递 `sigmas`,
            则 `num_inference_steps` 和 `timesteps` 必须为 `None`。

    返回:
        `Tuple[torch.Tensor, int]`:一个元组,第一个元素是调度器的时间步调度,
        第二个元素是推理步骤的数量。
    """
    # 如果同时传递了 timesteps 和 sigmas,则抛出错误
    if timesteps is not None and sigmas is not None:
        raise ValueError("只能传递 `timesteps` 或 `sigmas` 之一。请选择一个设置自定义值")
    # 如果传递了 timesteps
    if timesteps is not None:
        # 检查调度器的 set_timesteps 方法是否接受 timesteps 参数
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,则抛出错误
        if not accepts_timesteps:
            raise ValueError(
                f"当前调度器类 {scheduler.__class__} 的 `set_timesteps` 不支持自定义"
                f" 时间步调度。请检查您是否使用了正确的调度器。"
            )
        # 调用调度器的 set_timesteps 方法,传递自定义时间步和设备
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        # 从调度器获取设置后的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤的数量
        num_inference_steps = len(timesteps)
    # 如果传递了 sigmas
    elif sigmas is not None:
        # 检查调度器的 set_timesteps 方法是否接受 sigmas 参数
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        # 如果不接受,则抛出错误
        if not accept_sigmas:
            raise ValueError(
                f"当前调度器类 {scheduler.__class__} 的 `set_timesteps` 不支持自定义"
                f" sigma 调度。请检查您是否使用了正确的调度器。"
            )
        # 调用调度器的 set_timesteps 方法,传递自定义 sigma 和设备
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        # 从调度器获取设置后的时间步
        timesteps = scheduler.timesteps
        # 计算推理步骤的数量
        num_inference_steps = len(timesteps)
    # 如果不是特定条件,则设置推理步骤的时间步数
        else:
            # 调用调度器设置推理步骤数,并指定设备及其他参数
            scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
            # 获取调度器的时间步数
            timesteps = scheduler.timesteps
        # 返回时间步数和推理步骤数
        return timesteps, num_inference_steps
# 定义一个名为 StableDiffusion3InpaintPipeline 的类,继承自 DiffusionPipeline
class StableDiffusion3InpaintPipeline(DiffusionPipeline):
    r"""
    Args:
        transformer ([`SD3Transformer2DModel`]):
            条件变换器(MMDiT)架构,用于对编码后的图像潜变量进行去噪。
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            用于与 `transformer` 结合使用的调度器,用于去噪编码后的图像潜变量。
        vae ([`AutoencoderKL`]):
            变分自编码器(VAE)模型,用于将图像编码为潜在表示并解码。
        text_encoder ([`CLIPTextModelWithProjection`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            特别是 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) 变体,
            具有额外的投影层,该层用具有 `hidden_size` 维度的对角矩阵初始化。
        text_encoder_2 ([`CLIPTextModelWithProjection`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
            特别是 [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) 变体。
        text_encoder_3 ([`T5EncoderModel`]):
            冻结的文本编码器。Stable Diffusion 3 使用 [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel),
            特别是 [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) 变体。
        tokenizer (`CLIPTokenizer`):
            类 [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) 的分词器。
        tokenizer_2 (`CLIPTokenizer`):
            第二个类 [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) 的分词器。
        tokenizer_3 (`T5TokenizerFast`):
            类 [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer) 的分词器。
    """
    
    # 定义模型 CPU 卸载的顺序
    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
    # 定义可选组件的空列表
    _optional_components = []
    # 定义回调张量输入的列表
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    # 初始化方法,接收多个参数
    def __init__(
        # 接收条件变换器模型
        self,
        transformer: SD3Transformer2DModel,
        # 接收调度器
        scheduler: FlowMatchEulerDiscreteScheduler,
        # 接收变分自编码器
        vae: AutoencoderKL,
        # 接收文本编码器
        text_encoder: CLIPTextModelWithProjection,
        # 接收分词器
        tokenizer: CLIPTokenizer,
        # 接收第二个文本编码器
        text_encoder_2: CLIPTextModelWithProjection,
        # 接收第二个分词器
        tokenizer_2: CLIPTokenizer,
        # 接收第三个文本编码器
        text_encoder_3: T5EncoderModel,
        # 接收第三个分词器
        tokenizer_3: T5TokenizerFast,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 注册多个模块,方便后续使用
        self.register_modules(
            # 注册变分自编码器
            vae=vae,
            # 注册文本编码器
            text_encoder=text_encoder,
            # 注册第二个文本编码器
            text_encoder_2=text_encoder_2,
            # 注册第三个文本编码器
            text_encoder_3=text_encoder_3,
            # 注册分词器
            tokenizer=tokenizer,
            # 注册第二个分词器
            tokenizer_2=tokenizer_2,
            # 注册第三个分词器
            tokenizer_3=tokenizer_3,
            # 注册转换器
            transformer=transformer,
            # 注册调度器
            scheduler=scheduler,
        )
        # 计算 VAE 的缩放因子,基于块输出通道数量
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 初始化图像处理器,使用 VAE 缩放因子和潜在通道数
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
        )
        # 初始化掩码处理器,设置参数以处理图像
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            vae_latent_channels=self.vae.config.latent_channels,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        # 获取分词器的最大长度
        self.tokenizer_max_length = self.tokenizer.model_max_length
        # 获取转换器的默认采样大小
        self.default_sample_size = self.transformer.config.sample_size

    # 从稳定扩散管道复制的方法,获取 T5 提示嵌入
    def _get_t5_prompt_embeds(
        self,
        # 输入提示,可以是字符串或字符串列表
        prompt: Union[str, List[str]] = None,
        # 每个提示生成的图像数量
        num_images_per_prompt: int = 1,
        # 最大序列长度
        max_sequence_length: int = 256,
        # 可选的设备参数
        device: Optional[torch.device] = None,
        # 可选的数据类型
        dtype: Optional[torch.dtype] = None,
    # 定义一个方法,接受多个参数,处理输入文本以生成提示嵌入
        ):
            # 如果未指定设备,则使用默认执行设备
            device = device or self._execution_device
            # 如果未指定数据类型,则使用文本编码器的数据类型
            dtype = dtype or self.text_encoder.dtype
    
            # 如果提示是字符串,则将其转换为列表;否则保持原样
            prompt = [prompt] if isinstance(prompt, str) else prompt
            # 获取提示的批处理大小
            batch_size = len(prompt)
    
            # 如果第三个文本编码器为空,则返回一个零张量
            if self.text_encoder_3 is None:
                return torch.zeros(
                    (
                        batch_size * num_images_per_prompt,
                        self.tokenizer_max_length,
                        self.transformer.config.joint_attention_dim,
                    ),
                    device=device,
                    dtype=dtype,
                )
    
            # 使用第三个文本编码器对提示进行标记化,返回张量
            text_inputs = self.tokenizer_3(
                prompt,
                padding="max_length",
                max_length=max_sequence_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            # 提取输入 ID
            text_input_ids = text_inputs.input_ids
            # 获取未截断的 ID
            untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
    
            # 检查未截断的 ID 是否大于等于输入 ID,并且不相等
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
                # 解码被截断的文本并记录警告
                removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because `max_sequence_length` is set to "
                    f" {max_sequence_length} tokens: {removed_text}"
                )
    
            # 获取文本输入的嵌入
            prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
    
            # 获取文本编码器的数据类型
            dtype = self.text_encoder_3.dtype
            # 将提示嵌入转换为指定的数据类型和设备
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
    
            # 获取嵌入的形状,提取序列长度
            _, seq_len, _ = prompt_embeds.shape
    
            # 复制文本嵌入和注意力掩码以适应每个提示的生成,使用适合 MPS 的方法
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            # 重塑嵌入以匹配批处理大小和生成的图像数量
            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    
            # 返回生成的提示嵌入
            return prompt_embeds
    
        # 从 diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds 复制的方法
        def _get_clip_prompt_embeds(
            self,
            prompt: Union[str, List[str]],
            num_images_per_prompt: int = 1,
            device: Optional[torch.device] = None,
            clip_skip: Optional[int] = None,
            clip_model_index: int = 0,
    ):
        # 如果没有指定设备,则使用当前对象的执行设备
        device = device or self._execution_device

        # 定义两个 CLIP 分词器的列表
        clip_tokenizers = [self.tokenizer, self.tokenizer_2]
        # 定义两个 CLIP 文本编码器的列表
        clip_text_encoders = [self.text_encoder, self.text_encoder_2]

        # 根据所选模型索引选择相应的分词器
        tokenizer = clip_tokenizers[clip_model_index]
        # 根据所选模型索引选择相应的文本编码器
        text_encoder = clip_text_encoders[clip_model_index]

        # 如果 prompt 是字符串,则将其转换为列表,否则保持不变
        prompt = [prompt] if isinstance(prompt, str) else prompt
        # 获取 prompt 的批量大小
        batch_size = len(prompt)

        # 使用选择的分词器对 prompt 进行编码,返回张量
        text_inputs = tokenizer(
            prompt,
            padding="max_length",  # 填充到最大长度
            max_length=self.tokenizer_max_length,  # 最大长度限制
            truncation=True,  # 如果超出最大长度则截断
            return_tensors="pt",  # 返回 PyTorch 张量
        )

        # 获取编码后的输入 ID
        text_input_ids = text_inputs.input_ids
        # 获取未截断的输入 ID
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        # 检查未截断的输入 ID 是否超过最大长度,并且与截断的 ID 是否不同
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            # 解码被截断的文本部分
            removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            # 记录警告,指出被截断的部分
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        # 使用文本编码器对输入 ID 进行编码,并输出隐藏状态
        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        # 获取池化后的提示嵌入
        pooled_prompt_embeds = prompt_embeds[0]

        # 如果没有指定跳过层,则使用倒数第二层的嵌入
        if clip_skip is None:
            prompt_embeds = prompt_embeds.hidden_states[-2]
        else:
            # 根据指定的跳过层获取相应的嵌入
            prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

        # 将提示嵌入转换为指定的数据类型和设备
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # 获取提示嵌入的形状信息
        _, seq_len, _ = prompt_embeds.shape
        # 为每个提示生成多个文本嵌入,使用适合 MPS 的方法
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # 调整形状以符合批处理要求
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        # 重复池化的提示嵌入以匹配生成数量
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # 调整形状以符合批处理要求
        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        # 返回处理后的提示嵌入和池化提示嵌入
        return prompt_embeds, pooled_prompt_embeds

    # 从 diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt 复制而来
    # 定义一个方法来编码提示信息
        def encode_prompt(
            self,  # 方法的第一个参数,指向当前实例
            prompt: Union[str, List[str]],  # 第一个提示,可以是字符串或字符串列表
            prompt_2: Union[str, List[str]],  # 第二个提示,可以是字符串或字符串列表
            prompt_3: Union[str, List[str]],  # 第三个提示,可以是字符串或字符串列表
            device: Optional[torch.device] = None,  # 可选参数,指定设备(CPU或GPU)
            num_images_per_prompt: int = 1,  # 每个提示生成的图像数量,默认为1
            do_classifier_free_guidance: bool = True,  # 是否使用无分类器引导,默认为True
            negative_prompt: Optional[Union[str, List[str]]] = None,  # 可选的负面提示,可以是字符串或字符串列表
            negative_prompt_2: Optional[Union[str, List[str]]] = None,  # 第二个负面提示
            negative_prompt_3: Optional[Union[str, List[str]]] = None,  # 第三个负面提示
            prompt_embeds: Optional[torch.FloatTensor] = None,  # 可选参数,提示的嵌入表示
            negative_prompt_embeds: Optional[torch.FloatTensor] = None,  # 可选参数,负面提示的嵌入表示
            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,  # 可选参数,池化后的提示嵌入表示
            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,  # 可选参数,池化后的负面提示嵌入表示
            clip_skip: Optional[int] = None,  # 可选参数,控制剪辑跳过的层数
            max_sequence_length: int = 256,  # 最大序列长度,默认为256
            lora_scale: Optional[float] = None,  # 可选参数,LORA缩放因子
        # 从 diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.check_inputs 拷贝而来
        def check_inputs(
            self,  # 当前实例
            prompt,  # 第一个提示
            prompt_2,  # 第二个提示
            prompt_3,  # 第三个提示
            strength,  # 强度参数
            negative_prompt=None,  # 可选的负面提示
            negative_prompt_2=None,  # 第二个负面提示
            negative_prompt_3=None,  # 第三个负面提示
            prompt_embeds=None,  # 可选的提示嵌入表示
            negative_prompt_embeds=None,  # 可选的负面提示嵌入表示
            pooled_prompt_embeds=None,  # 可选的池化提示嵌入表示
            negative_pooled_prompt_embeds=None,  # 可选的池化负面提示嵌入表示
            callback_on_step_end_tensor_inputs=None,  # 可选的步骤结束回调输入
            max_sequence_length=None,  # 可选的最大序列长度
        # 从 diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps 拷贝而来
        def get_timesteps(self, num_inference_steps, strength, device):  # 定义获取时间步的方法
            # 使用 init_timestep 获取原始时间步
            init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 计算初始化时间步
    
            t_start = int(max(num_inference_steps - init_timestep, 0))  # 计算起始时间步
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]  # 获取调度器中的时间步
            if hasattr(self.scheduler, "set_begin_index"):  # 检查调度器是否有设置开始索引的方法
                self.scheduler.set_begin_index(t_start * self.scheduler.order)  # 设置调度器的开始索引
    
            return timesteps, num_inference_steps - t_start  # 返回时间步和剩余的推理步骤
    
        def prepare_latents(  # 定义准备潜在变量的方法
            self,  # 当前实例
            batch_size,  # 批处理大小
            num_channels_latents,  # 潜在变量的通道数
            height,  # 图像高度
            width,  # 图像宽度
            dtype,  # 数据类型
            device,  # 设备
            generator,  # 随机数生成器
            latents=None,  # 可选的潜在变量
            image=None,  # 可选的输入图像
            timestep=None,  # 可选的时间步
            is_strength_max=True,  # 强度是否为最大值,默认为True
            return_noise=False,  # 是否返回噪声,默认为False
            return_image_latents=False,  # 是否返回图像潜在变量,默认为False
    ):
        # 定义输出的形状,包含批量大小、通道数、高度和宽度
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,  # 根据 VAE 的缩放因子计算高度
            int(width) // self.vae_scale_factor,    # 根据 VAE 的缩放因子计算宽度
        )
        # 检查生成器是否是列表且长度与批量大小不匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 抛出值错误,提示生成器长度与批量大小不匹配
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # 检查图像或时间步是否为 None,且强度未达到最大值
        if (image is None or timestep is None) and not is_strength_max:
            # 抛出值错误,提示必须提供图像或噪声时间步
            raise ValueError(
                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
                "However, either the image or the noise timestep has not been provided."
            )

        # 检查是否返回图像潜变量,或潜变量为 None 且强度未达到最大值
        if return_image_latents or (latents is None and not is_strength_max):
            # 将图像转换到指定设备和数据类型
            image = image.to(device=device, dtype=dtype)

            # 检查图像的通道数是否为 16
            if image.shape[1] == 16:
                # 如果是,则直接将图像潜变量设置为图像
                image_latents = image
            else:
                # 否则,通过 VAE 编码图像来获取潜变量
                image_latents = self._encode_vae_image(image=image, generator=generator)
            # 根据批量大小重复潜变量,以匹配批量尺寸
            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)

        # 检查潜变量是否为 None
        if latents is None:
            # 根据形状生成噪声张量
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # 如果强度为 1,则初始化潜变量为噪声,否则初始化为图像与噪声的组合
            latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
        else:
            # 将潜变量移动到指定设备
            noise = latents.to(device)
            # 直接将噪声赋值给潜变量
            latents = noise

        # 创建输出元组,包含潜变量
        outputs = (latents,)

        # 如果需要返回噪声,则将噪声添加到输出中
        if return_noise:
            outputs += (noise,)

        # 如果需要返回图像潜变量,则将其添加到输出中
        if return_image_latents:
            outputs += (image_latents,)

        # 返回输出元组
        return outputs

    # 定义编码 VAE 图像的函数
    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        # 检查生成器是否为列表
        if isinstance(generator, list):
            # 遍历图像,编码每个图像并提取潜变量
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            # 将潜变量沿着第一个维度拼接成一个张量
            image_latents = torch.cat(image_latents, dim=0)
        else:
            # 如果不是列表,则直接编码整个图像
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        # 对潜变量进行缩放和偏移处理
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # 返回处理后的潜变量
        return image_latents

    # 定义准备掩码潜变量的函数
    def prepare_mask_latents(
        self,
        mask,                       # 掩码张量
        masked_image,              # 被掩盖的图像
        batch_size,                # 批量大小
        num_images_per_prompt,     # 每个提示的图像数量
        height,                    # 图像高度
        width,                     # 图像宽度
        dtype,                     # 数据类型
        device,                    # 设备类型
        generator,                 # 随机生成器
        do_classifier_free_guidance,# 是否进行无分类器引导
    ):
        # 将掩码调整为与潜在向量形状相同,以便在连接掩码和潜在向量时使用
        # 在转换数据类型之前执行此操作,以避免在使用 cpu_offload 和半精度时出现问题
        mask = torch.nn.functional.interpolate(
            # 使用插值方法将掩码调整为指定的高度和宽度
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        # 将掩码移动到指定的设备并转换为指定的数据类型
        mask = mask.to(device=device, dtype=dtype)

        # 计算总的批大小,考虑每个提示生成的图像数量
        batch_size = batch_size * num_images_per_prompt

        # 将遮罩图像移动到指定的设备并转换为指定的数据类型
        masked_image = masked_image.to(device=device, dtype=dtype)

        # 如果掩码图像的形状为 16,直接赋值给潜在图像变量
        if masked_image.shape[1] == 16:
            masked_image_latents = masked_image
        else:
            # 使用 VAE 编码器检索潜在图像
            masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

        # 对潜在图像进行归一化处理,减去偏移量并乘以缩放因子
        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # 为每个提示的生成重复掩码和潜在图像,使用适合 MPS 的方法
        if mask.shape[0] < batch_size:
            # 检查掩码数量是否能整除批大小
            if not batch_size % mask.shape[0] == 0:
                raise ValueError(
                    # 如果掩码数量和批大小不匹配,抛出错误
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            # 根据批大小重复掩码
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        # 检查潜在图像数量是否能整除批大小
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    # 如果潜在图像数量和批大小不匹配,抛出错误
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            # 根据批大小重复潜在图像
            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

        # 根据是否使用无分类器自由引导选择重复掩码
        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        # 根据是否使用无分类器自由引导选择重复潜在图像
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # 将潜在图像移动到指定的设备并转换为指定的数据类型,以防拼接时出现设备错误
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        # 返回处理后的掩码和潜在图像
        return mask, masked_image_latents

    @property
    def guidance_scale(self):
        # 返回指导缩放因子
        return self._guidance_scale

    @property
    def clip_skip(self):
        # 返回跳过剪辑的参数
        return self._clip_skip

    # 此处 `guidance_scale` 类似于方程 (2) 中的指导权重 `w`
    # 来自 Imagen 论文: https://arxiv.org/pdf/2205.11487.pdf 。`guidance_scale = 1`
    # 表示不进行分类器自由引导的情况
    @property
    def do_classifier_free_guidance(self):
        # 判断引导尺度是否大于1,返回布尔值
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        # 返回时间步的数量
        return self._num_timesteps

    @property
    def interrupt(self):
        # 返回中断状态
        return self._interrupt

    # 该装饰器用于禁止梯度计算,减少内存消耗
    @torch.no_grad()
    # 替换文档字符串为示例文档字符串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        # 接收的提示文本,可以是字符串或字符串列表
        prompt: Union[str, List[str]] = None,
        # 第二个提示文本,默认为 None
        prompt_2: Optional[Union[str, List[str]]] = None,
        # 第三个提示文本,默认为 None
        prompt_3: Optional[Union[str, List[str]]] = None,
        # 输入的图像,可以是图像数据
        image: PipelineImageInput = None,
        # 掩码图像,用于特定操作
        mask_image: PipelineImageInput = None,
        # 被掩码的图像潜变量
        masked_image_latents: PipelineImageInput = None,
        # 图像高度,默认为 None
        height: int = None,
        # 图像宽度,默认为 None
        width: int = None,
        # 可选的填充掩码裁剪值,默认为 None
        padding_mask_crop: Optional[int] = None,
        # 强度参数,默认为 0.6
        strength: float = 0.6,
        # 推理步骤的数量,默认为 50
        num_inference_steps: int = 50,
        # 时间步列表,默认为 None
        timesteps: List[int] = None,
        # 引导尺度,默认为 7.0
        guidance_scale: float = 7.0,
        # 可选的负面提示文本,默认为 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 第二个负面提示文本,默认为 None
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        # 第三个负面提示文本,默认为 None
        negative_prompt_3: Optional[Union[str, List[str]]] = None,
        # 每个提示生成的图像数量,默认为 1
        num_images_per_prompt: Optional[int] = 1,
        # 随机数生成器,可以是单个生成器或生成器列表
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潜变量,默认为 None
        latents: Optional[torch.Tensor] = None,
        # 提示嵌入,默认为 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 负面提示嵌入,默认为 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 聚合的提示嵌入,默认为 None
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 负面聚合提示嵌入,默认为 None
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 输出类型,默认为 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典,默认为 True
        return_dict: bool = True,
        # 可选的跳过剪辑参数,默认为 None
        clip_skip: Optional[int] = None,
        # 每步结束时的回调函数,默认为 None
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        # 在每步结束时使用的张量输入回调,默认为 ["latents"]
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # 最大序列长度,默认为 256
        max_sequence_length: int = 256,

.\diffusers\pipelines\stable_diffusion_3\__init__.py

# 从类型检查模块导入 TYPE_CHECKING,用于静态类型检查
from typing import TYPE_CHECKING

# 从上级目录的 utils 模块导入所需的工具和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 导入标志,指示是否慢速导入
    OptionalDependencyNotAvailable,  # 导入自定义异常,用于处理可选依赖未满足的情况
    _LazyModule,  # 导入延迟加载模块的工具
    get_objects_from_module,  # 导入从模块获取对象的工具
    is_flax_available,  # 导入检查 Flax 库是否可用的工具
    is_torch_available,  # 导入检查 PyTorch 库是否可用的工具
    is_transformers_available,  # 导入检查 Transformers 库是否可用的工具
)

# 初始化一个空字典,用于存放虚拟对象
_dummy_objects = {}
# 初始化一个空字典,用于存放额外的导入对象
_additional_imports = {}
# 定义模块的导入结构,指定模块中的输出内容
_import_structure = {"pipeline_output": ["StableDiffusion3PipelineOutput"]}

# 尝试检查 Transformers 和 Torch 是否可用
try:
    if not (is_transformers_available() and is_torch_available()):  # 如果两个库都不可用
        raise OptionalDependencyNotAvailable()  # 抛出可选依赖未满足的异常
except OptionalDependencyNotAvailable:  # 捕获异常
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403  # 从 utils 导入虚拟对象的模块

    # 更新虚拟对象字典,获取从虚拟模块中提取的对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:  # 如果没有异常发生
    # 将可用的模块添加到导入结构中
    _import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"]
    _import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"]
    _import_structure["pipeline_stable_diffusion_3_inpaint"] = ["StableDiffusion3InpaintPipeline"]

# 检查是否在类型检查阶段或者是否慢速导入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):  # 如果两个库都不可用
            raise OptionalDependencyNotAvailable()  # 抛出可选依赖未满足的异常
    except OptionalDependencyNotAvailable:  # 捕获异常
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403  # 从 utils 导入虚拟对象模块
    else:  # 如果没有异常发生
        # 从稳定扩散的模块导入相应的管道类
        from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
        from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
        from .pipeline_stable_diffusion_3_inpaint import StableDiffusion3InpaintPipeline

else:  # 如果不是类型检查阶段或慢速导入
    import sys  # 导入系统模块

    # 使用延迟加载模块替换当前模块的 sys.modules 条目
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名称
        globals()["__file__"],  # 当前文件路径
        _import_structure,  # 模块的导入结构
        module_spec=__spec__,  # 模块的规格
    )

    # 将虚拟对象字典中的对象添加到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    # 将额外的导入对象添加到当前模块
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)

.\diffusers\pipelines\stable_diffusion_attend_and_excite\pipeline_stable_diffusion_attend_and_excite.py

# 版权声明,说明该文件的版权归 HuggingFace 团队所有
# 
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守该许可证,否则不得使用此文件。
# 可以在以下网址获取许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,软件
# 根据许可证分发是以“原样”基础进行的,
# 不提供任何形式的明示或暗示的担保或条件。
# 查看许可证以了解有关权限和
# 限制的具体信息。

# 导入 inspect 模块,用于获取对象的信息
import inspect
# 导入 math 模块,提供数学函数
import math
# 从 typing 模块导入各种类型注解
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# 导入 numpy 库,进行数值计算
import numpy as np
# 导入 torch 库,进行深度学习操作
import torch
# 从 torch.nn.functional 导入常用的神经网络功能
from torch.nn import functional as F
# 从 transformers 库导入 CLIP 相关的处理器和模型
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

# 从相对路径导入 VaeImageProcessor 类
from ...image_processor import VaeImageProcessor
# 从相对路径导入 Lora 和 Textual Inversion 的加载器混合类
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
# 从相对路径导入自动编码器和 U-Net 模型
from ...models import AutoencoderKL, UNet2DConditionModel
# 从相对路径导入 Attention 类
from ...models.attention_processor import Attention
# 从相对路径导入调整 LoRA 规模的函数
from ...models.lora import adjust_lora_scale_text_encoder
# 从相对路径导入 Karras Diffusion 调度器
from ...schedulers import KarrasDiffusionSchedulers
# 从相对路径导入实用工具
from ...utils import (
    USE_PEFT_BACKEND,       # 导入标志以使用 PEFT 后端
    deprecate,             # 导入用于标记弃用功能的工具
    logging,               # 导入日志记录功能
    replace_example_docstring, # 导入用于替换示例文档字符串的工具
    scale_lora_layers,     # 导入缩放 LoRA 层的函数
    unscale_lora_layers,   # 导入取消缩放 LoRA 层的函数
)
# 从相对路径导入随机张量生成工具
from ...utils.torch_utils import randn_tensor
# 从相对路径导入扩散管道和稳定扩散混合类
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
# 从相对路径导入稳定扩散的输出模型
from ..stable_diffusion import StableDiffusionPipelineOutput
# 从相对路径导入稳定扩散的安全检查器
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# 创建日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__)

# 示例文档字符串,展示如何使用此类的示例代码
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionAttendAndExcitePipeline

        >>> pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
        ...     "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
        ... ).to("cuda")

        >>> prompt = "a cat and a frog"

        >>> # 使用 get_indices 函数查找要更改的令牌的索引
        >>> pipe.get_indices(prompt)
        {0: '<|startoftext|>', 1: 'a</w>', 2: 'cat</w>', 3: 'and</w>', 4: 'a</w>', 5: 'frog</w>', 6: '<|endoftext|>'}

        >>> token_indices = [2, 5]
        >>> seed = 6141
        >>> generator = torch.Generator("cuda").manual_seed(seed)

        >>> images = pipe(
        ...     prompt=prompt,
        ...     token_indices=token_indices,
        ...     guidance_scale=7.5,
        ...     generator=generator,
        ...     num_inference_steps=50,
        ...     max_iter_to_alter=25,
        ... ).images

        >>> image = images[0]
        >>> image.save(f"../images/{prompt}_{seed}.png")
        ```py
"""

# 定义 AttentionStore 类
class AttentionStore:
    # 定义静态方法,用于获取一个空的注意力存储
    @staticmethod
    def get_empty_store():
        # 返回一个包含三个空列表的字典,分别对应不同的注意力层
        return {"down": [], "mid": [], "up": []}
    # 定义一个可调用的函数,处理注意力矩阵
        def __call__(self, attn, is_cross: bool, place_in_unet: str):
            # 如果当前注意力层索引有效且为交叉注意力
            if self.cur_att_layer >= 0 and is_cross:
                # 检查注意力矩阵的形状是否与期望的分辨率一致
                if attn.shape[1] == np.prod(self.attn_res):
                    # 将当前注意力矩阵存储到相应位置
                    self.step_store[place_in_unet].append(attn)
    
            # 更新当前注意力层索引
            self.cur_att_layer += 1
            # 如果达到最后一层,重置索引并调用间隔步骤方法
            if self.cur_att_layer == self.num_att_layers:
                self.cur_att_layer = 0
                self.between_steps()
    
        # 定义间隔步骤方法,用于更新注意力存储
        def between_steps(self):
            # 将步骤存储的注意力矩阵赋值给注意力存储
            self.attention_store = self.step_store
            # 获取一个空的步骤存储
            self.step_store = self.get_empty_store()
    
        # 获取平均注意力矩阵
        def get_average_attention(self):
            # 将注意力存储返回为平均注意力
            average_attention = self.attention_store
            return average_attention
    
        # 聚合来自不同层和头部的注意力矩阵
        def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
            """在指定的分辨率下聚合不同层和头部的注意力。"""
            out = []  # 初始化输出列表
            attention_maps = self.get_average_attention()  # 获取平均注意力
            # 遍历来源位置
            for location in from_where:
                # 遍历对应的注意力矩阵
                for item in attention_maps[location]:
                    # 重塑注意力矩阵为适当形状
                    cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                    # 将重塑的矩阵添加到输出列表
                    out.append(cross_maps)
            # 沿第0维连接所有注意力矩阵
            out = torch.cat(out, dim=0)
            # 计算所有矩阵的平均值
            out = out.sum(0) / out.shape[0]
            return out  # 返回聚合后的注意力矩阵
    
        # 重置注意力存储和索引
        def reset(self):
            self.cur_att_layer = 0  # 重置当前注意力层索引
            self.step_store = self.get_empty_store()  # 重置步骤存储
            self.attention_store = {}  # 清空注意力存储
    
        # 初始化方法,设置初始参数
        def __init__(self, attn_res):
            """
            初始化一个空的 AttentionStore :param step_index: 用于可视化扩散过程中的特定步骤
            """
            self.num_att_layers = -1  # 初始化注意力层数量
            self.cur_att_layer = 0  # 初始化当前注意力层索引
            self.step_store = self.get_empty_store()  # 初始化步骤存储
            self.attention_store = {}  # 初始化注意力存储
            self.curr_step_index = 0  # 初始化当前步骤索引
            self.attn_res = attn_res  # 设置注意力分辨率
# 定义一个 AttendExciteAttnProcessor 类
class AttendExciteAttnProcessor:
    # 初始化方法,接受注意力存储器和 UNet 中的位置
    def __init__(self, attnstore, place_in_unet):
        # 调用父类的初始化方法
        super().__init__()
        # 存储传入的注意力存储器
        self.attnstore = attnstore
        # 存储在 UNet 中的位置
        self.place_in_unet = place_in_unet

    # 定义调用方法,处理注意力计算
    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # 获取批次大小和序列长度
        batch_size, sequence_length, _ = hidden_states.shape
        # 准备注意力掩码
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # 将隐藏状态转换为查询向量
        query = attn.to_q(hidden_states)

        # 判断是否为交叉注意力
        is_cross = encoder_hidden_states is not None
        # 如果没有编码器隐藏状态,则使用隐藏状态
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        # 将编码器隐藏状态转换为键向量
        key = attn.to_k(encoder_hidden_states)
        # 将编码器隐藏状态转换为值向量
        value = attn.to_v(encoder_hidden_states)

        # 将查询、键和值转换为批次维度
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # 计算注意力分数
        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # 仅在 Attend 和 Excite 过程中存储注意力图
        if attention_probs.requires_grad:
            # 存储注意力概率
            self.attnstore(attention_probs, is_cross, self.place_in_unet)

        # 使用注意力概率和值向量计算新的隐藏状态
        hidden_states = torch.bmm(attention_probs, value)
        # 将隐藏状态转换回头维度
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # 进行线性变换
        hidden_states = attn.to_out[0](hidden_states)
        # 应用 dropout
        hidden_states = attn.to_out[1](hidden_states)

        # 返回更新后的隐藏状态
        return hidden_states


# 定义一个 StableDiffusionAttendAndExcitePipeline 类,继承多个基类
class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin):
    r"""
    使用 Stable Diffusion 和 Attend-and-Excite 进行文本到图像生成的管道。

    该模型继承自 [`DiffusionPipeline`]。请查看超类文档,以获取所有管道实现的通用方法
    (下载、保存、在特定设备上运行等)。

    该管道还继承以下加载方法:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] 用于加载文本反演嵌入
    # 定义函数参数的说明文档,描述各参数的作用
    Args:
        vae ([`AutoencoderKL`]):
            # Variational Auto-Encoder (VAE) 模型,用于将图像编码和解码为潜在表示
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            # 冻结的文本编码器,使用 CLIP 模型进行文本特征提取
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            # 用于将文本进行分词的 CLIPTokenizer
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            # UNet 模型,用于对编码后的图像潜在特征进行去噪处理
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            # 调度器,用于与 UNet 结合去噪编码后的图像潜在特征,可以是 DDIM、LMS 或 PNDM 调度器
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            # 分类模块,用于评估生成的图像是否可能被认为是冒犯性或有害的
            Classification module that estimates whether generated images could be considered offensive or harmful.
            # 参见模型卡以获取有关模型潜在危害的更多详细信息
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            # CLIP 图像处理器,用于从生成的图像中提取特征;这些特征作为输入提供给安全检查器
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    # 定义模型在 CPU 上的卸载顺序
    model_cpu_offload_seq = "text_encoder->unet->vae"
    # 定义可选组件列表
    _optional_components = ["safety_checker", "feature_extractor"]
    # 定义在 CPU 卸载时排除的组件
    _exclude_from_cpu_offload = ["safety_checker"]

    # 初始化方法定义,接收各个参数
    def __init__(
        # VAE 模型实例
        self,
        vae: AutoencoderKL,
        # 文本编码器实例
        text_encoder: CLIPTextModel,
        # 分词器实例
        tokenizer: CLIPTokenizer,
        # UNet 模型实例
        unet: UNet2DConditionModel,
        # 调度器实例
        scheduler: KarrasDiffusionSchedulers,
        # 安全检查器实例
        safety_checker: StableDiffusionSafetyChecker,
        # 特征提取器实例
        feature_extractor: CLIPImageProcessor,
        # 是否需要安全检查器的布尔标志,默认为 True
        requires_safety_checker: bool = True,
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 检查安全检查器是否未定义且需要安全检查器时发出警告
        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 检查安全检查器已定义但特征提取器未定义时引发错误
        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # 注册多个模块,包括 VAE、文本编码器、标记器等
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # 计算 VAE 的缩放因子
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 创建图像处理器,使用 VAE 缩放因子
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # 将配置中是否需要安全检查器进行注册
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 复制
    def _encode_prompt(
        # 定义编码提示所需的参数
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        # 定义弃用消息,告知用户 `_encode_prompt()` 方法将被移除,建议使用 `encode_prompt()`
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        # 调用 deprecate 函数记录弃用信息,设置标准警告为 False
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        # 调用 encode_prompt 方法生成提示嵌入元组,传入必要参数
        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # 连接嵌入元组中的两个部分以兼容以前的实现
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        # 返回最终的提示嵌入
        return prompt_embeds

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 复制的 encode_prompt 方法
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 复制的 run_safety_checker 方法
    def run_safety_checker(self, image, device, dtype):
        # 如果安全检查器未定义,则设置 has_nsfw_concept 为 None
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            # 如果输入为张量,进行后处理以转换为 PIL 格式
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                # 如果输入为 numpy 数组,转换为 PIL 格式
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            # 提取特征并将其移动到指定设备上
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            # 使用安全检查器处理图像,并返回处理后的图像和 NSFW 概念标识
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        # 返回处理后的图像和 NSFW 概念标识
        return image, has_nsfw_concept

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 复制的 decode_latents 方法
    # 解码潜在变量
    def decode_latents(self, latents):
        # 定义弃用警告信息
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        # 调用弃用函数,发出警告
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
    
        # 按照缩放因子调整潜在变量
        latents = 1 / self.vae.config.scaling_factor * latents
        # 解码潜在变量,返回第一个输出
        image = self.vae.decode(latents, return_dict=False)[0]
        # 归一化图像并限制在[0, 1]范围内
        image = (image / 2 + 0.5).clamp(0, 1)
        # 将图像转换为float32格式,适应bfloat16并避免显著开销
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        # 返回最终图像
        return image
    
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制
    def prepare_extra_step_kwargs(self, generator, eta):
        # 准备调度器步骤的额外参数,因为不是所有调度器具有相同的签名
        # eta (η) 仅用于DDIMScheduler,其他调度器将被忽略
        # eta 对应于DDIM论文中的η: https://arxiv.org/abs/2010.02502
        # 应在 [0, 1] 范围内
    
        # 检查调度器是否接受eta参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化额外步骤参数字典
        extra_step_kwargs = {}
        # 如果接受eta,则将其添加到字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
    
        # 检查调度器是否接受generator参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受generator,则将其添加到字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回额外步骤参数字典
        return extra_step_kwargs
    
    # 检查输入参数
    def check_inputs(
        self,
        prompt,
        indices,
        height,
        width,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制
        def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
            # 定义形状以匹配潜在变量的尺寸
            shape = (
                batch_size,
                num_channels_latents,
                int(height) // self.vae_scale_factor,
                int(width) // self.vae_scale_factor,
            )
            # 检查生成器的数量是否与批大小匹配
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
    
            # 如果潜在变量为空,则生成随机张量
            if latents is None:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            else:
                # 将给定的潜在变量移动到指定设备
                latents = latents.to(device)
    
            # 按调度器要求的标准差缩放初始噪声
            latents = latents * self.scheduler.init_noise_sigma
            # 返回调整后的潜在变量
            return latents
    
        @staticmethod
    # 计算每个需要修改的 token 的最大注意力值
    def _compute_max_attention_per_index(
        # 输入的注意力图张量
        attention_maps: torch.Tensor,
        # 需要关注的 token 索引列表
        indices: List[int],
    ) -> List[torch.Tensor]:
        """计算我们希望改变的每个 token 的最大注意力值。"""
        # 获取注意力图中去掉首尾 token 的部分
        attention_for_text = attention_maps[:, :, 1:-1]
        # 将注意力值放大 100 倍
        attention_for_text *= 100
        # 对注意力值进行 softmax 处理,规范化
        attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)

        # 因为去掉了第一个 token,调整索引
        indices = [index - 1 for index in indices]

        # 提取最大值的列表
        max_indices_list = []
        # 遍历每个索引
        for i in indices:
            # 获取指定索引的注意力图
            image = attention_for_text[:, :, i]
            # 创建高斯平滑对象并移动到相应设备
            smoothing = GaussianSmoothing().to(attention_maps.device)
            # 对图像进行填充以反射模式
            input = F.pad(image.unsqueeze(0).unsqueeze(0), (1, 1, 1, 1), mode="reflect")
            # 应用高斯平滑,并去掉多余的维度
            image = smoothing(input).squeeze(0).squeeze(0)
            # 将最大值添加到结果列表
            max_indices_list.append(image.max())
        # 返回最大值列表
        return max_indices_list

    # 聚合每个 token 的注意力并计算最大激活值
    def _aggregate_and_get_max_attention_per_token(
        self,
        # 需要关注的 token 索引列表
        indices: List[int],
    ):
        """聚合每个 token 的注意力,并计算每个 token 的最大激活值。"""
        # 从注意力存储中聚合注意力图
        attention_maps = self.attention_store.aggregate_attention(
            # 从不同来源获取的注意力图
            from_where=("up", "down", "mid"),
        )
        # 计算每个 token 的最大注意力值
        max_attention_per_index = self._compute_max_attention_per_index(
            # 传入注意力图和索引
            attention_maps=attention_maps,
            indices=indices,
        )
        # 返回最大注意力值
        return max_attention_per_index

    @staticmethod
    # 计算损失值
    def _compute_loss(max_attention_per_index: List[torch.Tensor]) -> torch.Tensor:
        """使用每个 token 的最大注意力值计算 attend-and-excite 损失。"""
        # 计算损失列表,确保不低于 0
        losses = [max(0, 1.0 - curr_max) for curr_max in max_attention_per_index]
        # 获取损失中的最大值
        loss = max(losses)
        # 返回最大损失
        return loss

    @staticmethod
    # 更新潜在变量
    def _update_latent(latents: torch.Tensor, loss: torch.Tensor, step_size: float) -> torch.Tensor:
        """根据计算出的损失更新潜在变量。"""
        # 计算损失对潜在变量的梯度
        grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents], retain_graph=True)[0]
        # 更新潜在变量,使用学习率乘以梯度
        latents = latents - step_size * grad_cond
        # 返回更新后的潜在变量
        return latents

    # 进行迭代细化步骤
    def _perform_iterative_refinement_step(
        # 潜在变量张量
        latents: torch.Tensor,
        # 需要关注的 token 索引列表
        indices: List[int],
        # 当前损失值
        loss: torch.Tensor,
        # 阈值,用于判断
        threshold: float,
        # 文本嵌入张量
        text_embeddings: torch.Tensor,
        # 学习率
        step_size: float,
        # 当前迭代步数
        t: int,
        # 最大细化步骤,默认 20
        max_refinement_steps: int = 20,
    ):
        """
        执行论文中引入的迭代潜在优化。我们根据损失目标持续更新潜在代码,直到所有令牌达到给定的阈值。
        """
        # 初始化迭代计数器
        iteration = 0
        # 计算目标损失值,确保不小于 0
        target_loss = max(0, 1.0 - threshold)
        # 当当前损失大于目标损失时,持续迭代
        while loss > target_loss:
            # 迭代计数加一
            iteration += 1

            # 克隆潜在变量并准备计算梯度
            latents = latents.clone().detach().requires_grad_(True)
            # 使用 UNet 模型处理潜在变量,生成样本
            self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
            # 清零 UNet 模型的梯度
            self.unet.zero_grad()

            # 获取每个主题令牌的最大激活值
            max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
                indices=indices,
            )

            # 计算当前的损失值
            loss = self._compute_loss(max_attention_per_index)

            # 如果损失不为零,更新潜在变量
            if loss != 0:
                latents = self._update_latent(latents, loss, step_size)

            # 记录当前迭代和损失信息
            logger.info(f"\t Try {iteration}. loss: {loss}")

            # 如果达到最大迭代步数,记录并退出循环
            if iteration >= max_refinement_steps:
                logger.info(f"\t Exceeded max number of iterations ({max_refinement_steps})! ")
                break

        # 再次运行但不计算梯度,也不更新潜在变量,仅计算新损失
        latents = latents.clone().detach().requires_grad_(True)
        _ = self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
        self.unet.zero_grad()

        # 获取每个主题令牌的最大激活值
        max_attention_per_index = self._aggregate_and_get_max_attention_per_token(
            indices=indices,
        )
        # 计算当前损失值
        loss = self._compute_loss(max_attention_per_index)
        # 记录最终损失信息
        logger.info(f"\t Finished with loss of: {loss}")
        # 返回损失、潜在变量和最大激活值索引
        return loss, latents, max_attention_per_index

    def register_attention_control(self):
        # 初始化注意力处理器字典
        attn_procs = {}
        # 交叉注意力计数
        cross_att_count = 0
        # 遍历 UNet 中的注意力处理器
        for name in self.unet.attn_processors.keys():
            # 根据名称确定位置
            if name.startswith("mid_block"):
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                place_in_unet = "down"
            else:
                continue

            # 交叉注意力计数加一
            cross_att_count += 1
            # 创建注意力处理器并添加到字典
            attn_procs[name] = AttendExciteAttnProcessor(attnstore=self.attention_store, place_in_unet=place_in_unet)

        # 设置 UNet 的注意力处理器
        self.unet.set_attn_processor(attn_procs)
        # 更新注意力层的数量
        self.attention_store.num_att_layers = cross_att_count

    def get_indices(self, prompt: str) -> Dict[str, int]:
        """用于列出要更改的令牌的索引的实用函数"""
        # 将提示转换为输入 ID
        ids = self.tokenizer(prompt).input_ids
        # 创建令牌到索引的映射字典
        indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))}
        # 返回索引字典
        return indices

    @torch.no_grad()
    # 替换示例文档字符串的装饰器
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义可调用的方法,接受多个参数以生成图像
        def __call__(
            # 提示信息,可以是字符串或字符串列表
            self,
            prompt: Union[str, List[str]],
            # 令牌索引,可以是整数列表或列表的列表
            token_indices: Union[List[int], List[List[int]]],
            # 可选的图像高度,默认为 None
            height: Optional[int] = None,
            # 可选的图像宽度,默认为 None
            width: Optional[int] = None,
            # 生成的推理步骤数,默认为 50
            num_inference_steps: int = 50,
            # 指导比例,默认为 7.5
            guidance_scale: float = 7.5,
            # 可选的负提示,可以是字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认为 1
            num_images_per_prompt: int = 1,
            # ETA 值,默认为 0.0
            eta: float = 0.0,
            # 可选的随机数生成器,可以是单个生成器或生成器列表
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 可选的潜在变量张量,默认为 None
            latents: Optional[torch.Tensor] = None,
            # 可选的提示嵌入张量,默认为 None
            prompt_embeds: Optional[torch.Tensor] = None,
            # 可选的负提示嵌入张量,默认为 None
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 输出类型,默认为 "pil"
            output_type: Optional[str] = "pil",
            # 是否返回字典形式的结果,默认为 True
            return_dict: bool = True,
            # 可选的回调函数,用于处理生成过程中的状态
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 回调函数调用的步骤间隔,默认为 1
            callback_steps: int = 1,
            # 可选的交叉注意力参数
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 最大迭代次数,默认为 25
            max_iter_to_alter: int = 25,
            # 阈值字典,定义不同步长下的阈值
            thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
            # 缩放因子,默认为 20
            scale_factor: int = 20,
            # 可选的注意力分辨率元组,默认为 (16, 16)
            attn_res: Optional[Tuple[int]] = (16, 16),
            # 可选的跳过剪辑次数,默认为 None
            clip_skip: Optional[int] = None,
# 定义一个继承自 PyTorch 模块的高斯平滑类
class GaussianSmoothing(torch.nn.Module):
    """
    参数:
    对 1D、2D 或 3D 张量应用高斯平滑。每个通道分别使用深度卷积进行过滤。
        channels (int, sequence): 输入张量的通道数。输出将具有相同数量的通道。
        kernel_size (int, sequence): 高斯核的大小。 sigma (float, sequence): 高斯核的标准差。
        dim (int, optional): 数据的维度数量。默认值为 2(空间)。
    """

    # channels=1, kernel_size=kernel_size, sigma=sigma, dim=2
    # 初始化方法,设置高斯平滑的参数
    def __init__(
        self,
        channels: int = 1,  # 输入通道数,默认为1
        kernel_size: int = 3,  # 高斯核的大小,默认为3
        sigma: float = 0.5,  # 高斯核的标准差,默认为0.5
        dim: int = 2,  # 数据维度,默认为2
    ):
        super().__init__()  # 调用父类的初始化方法

        # 如果 kernel_size 是一个整数,则将其转换为对应维度的列表
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size] * dim
        # 如果 sigma 是一个浮点数,则将其转换为对应维度的列表
        if isinstance(sigma, float):
            sigma = [sigma] * dim

        # 高斯核是每个维度高斯函数的乘积
        kernel = 1  # 初始化高斯核为1
        # 创建高斯核的网格,生成每个维度的网格
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
        # 遍历每个维度的大小、标准差和网格
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2  # 计算高斯分布的均值
            # 更新高斯核,计算当前维度的高斯值并与核相乘
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))

        # 确保高斯核的值之和等于1
        kernel = kernel / torch.sum(kernel)

        # 将高斯核重塑为深度卷积权重的形状
        kernel = kernel.view(1, 1, *kernel.size())  # 重塑为卷积所需的格式
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))  # 复制以适应输入通道数

        # 注册高斯核作为模块的缓冲区
        self.register_buffer("weight", kernel)
        self.groups = channels  # 设置分组卷积的通道数

        # 根据维度选择相应的卷积操作
        if dim == 1:
            self.conv = F.conv1d  # 1D 卷积
        elif dim == 2:
            self.conv = F.conv2d  # 2D 卷积
        elif dim == 3:
            self.conv = F.conv3d  # 3D 卷积
        else:
            # 如果维度不支持,则抛出运行时错误
            raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))

    # 前向传播方法,应用高斯滤波
    def forward(self, input):
        """
        参数:
        对输入应用高斯滤波。
            input (torch.Tensor): 需要应用高斯滤波的输入。
        返回:
            filtered (torch.Tensor): 滤波后的输出。
        """
        # 使用选择的卷积方法对输入进行卷积,返回滤波结果
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)

.\diffusers\pipelines\stable_diffusion_attend_and_excite\__init__.py

# 导入类型检查相关的常量
from typing import TYPE_CHECKING

# 从上层目录的 utils 模块导入所需的工具函数和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 延迟导入的标志
    OptionalDependencyNotAvailable,  # 可选依赖不可用的异常
    _LazyModule,  # 延迟加载模块的类
    get_objects_from_module,  # 从模块获取对象的函数
    is_torch_available,  # 检查 PyTorch 是否可用的函数
    is_transformers_available,  # 检查 Transformers 是否可用的函数
)

# 初始化一个空字典用于存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}

# 尝试检查可选依赖是否可用
try:
    if not (is_transformers_available() and is_torch_available()):  # 检查 Transformers 和 PyTorch 是否可用
        raise OptionalDependencyNotAvailable()  # 如果不可用,抛出异常
except OptionalDependencyNotAvailable:  # 捕获可选依赖不可用的异常
    from ...utils import dummy_torch_and_transformers_objects  # 导入虚拟对象以防可选依赖不可用 # noqa F403

    # 更新虚拟对象字典,添加虚拟对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:  # 如果没有抛出异常
    # 更新导入结构字典,添加 StableDiffusionAttendAndExcitePipeline
    _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]

# 检查是否在类型检查中或需要慢速导入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):  # 再次检查可选依赖是否可用
            raise OptionalDependencyNotAvailable()  # 如果不可用,抛出异常

    except OptionalDependencyNotAvailable:  # 捕获可选依赖不可用的异常
        from ...utils.dummy_torch_and_transformers_objects import *  # 导入虚拟对象以防可选依赖不可用
    else:  # 如果没有抛出异常
        from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline  # 导入实际的管道类

else:  # 如果不是类型检查或者不需要慢速导入
    import sys  # 导入 sys 模块

    # 将当前模块替换为延迟加载模块
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名称
        globals()["__file__"],  # 当前文件的全局变量
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规格
    )

    # 遍历虚拟对象字典,设置模块中的虚拟对象
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)  # 为当前模块添加虚拟对象

.\diffusers\pipelines\stable_diffusion_diffedit\pipeline_stable_diffusion_diffedit.py

# 版权声明,表明代码的作者及版权信息
# Copyright 2024 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版进行授权(“许可证”); 
# 除非遵守许可证,否则不得使用此文件。
# 可以在以下网址获得许可证的副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意, 
# 否则根据许可证分发的软件是在“按现状”基础上分发的, 
# 不提供任何形式的明示或暗示的担保或条件。
# 有关许可证所管辖的权限和限制的具体信息,请参阅许可证。
import inspect  # 导入 inspect 模块,用于获取有关对象的信息
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器,用于简化类定义
from typing import Any, Callable, Dict, List, Optional, Union  # 导入类型提示相关的类型

import numpy as np  # 导入 numpy 库,用于数值计算
import PIL.Image  # 导入 PIL.Image,用于图像处理
import torch  # 导入 PyTorch 库,用于深度学习
from packaging import version  # 导入 version 模块,用于版本控制
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer  # 导入 CLIP 相关模型和处理器

from ...configuration_utils import FrozenDict  # 从上层模块导入 FrozenDict,用于配置管理
from ...image_processor import VaeImageProcessor  # 从上层模块导入 VaeImageProcessor,用于图像处理
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin  # 导入混合类用于加载器
from ...models import AutoencoderKL, UNet2DConditionModel  # 导入模型类
from ...models.lora import adjust_lora_scale_text_encoder  # 导入调整 LORA 规模的函数
from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers  # 导入调度器
from ...utils import (  # 导入各种工具函数和常量
    PIL_INTERPOLATION,  # PIL 图像插值常量
    USE_PEFT_BACKEND,  # 使用 PEFT 后端的常量
    BaseOutput,  # 基础输出类
    deprecate,  # 用于标记已弃用的函数
    logging,  # 日志记录模块
    replace_example_docstring,  # 替换示例文档字符串的函数
    scale_lora_layers,  # 缩放 LORA 层的函数
    unscale_lora_layers,  # 取消缩放 LORA 层的函数
)
from ...utils.torch_utils import randn_tensor  # 从工具模块导入随机张量生成函数
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入扩散管道和混合类
from ..stable_diffusion import StableDiffusionPipelineOutput  # 导入稳定扩散管道输出类
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker  # 导入安全检查器

logger = logging.get_logger(__name__)  # 创建一个名为当前模块的日志记录器,便于调试和信息记录

@dataclass  # 使用 dataclass 装饰器定义一个数据类
class DiffEditInversionPipelineOutput(BaseOutput):  # 继承基础输出类
    """
    Stable Diffusion 管道的输出类。

    参数:
        latents (`torch.Tensor`)
            反转的潜变量张量
        images (`List[PIL.Image.Image]` 或 `np.ndarray`)
            一个 PIL 图像的列表,长度为 `num_timesteps * batch_size` 或形状为 `(num_timesteps,
            batch_size, height, width, num_channels)` 的 numpy 数组。PIL 图像或 numpy 数组表示
            扩散管道的去噪图像。
    """

    latents: torch.Tensor  # 定义潜变量属性,类型为张量
    images: Union[List[PIL.Image.Image], np.ndarray]  # 定义图像属性,类型为图像列表或 numpy 数组
# 示例文档字符串,包含代码示例
EXAMPLE_DOC_STRING = """

        ```py
        >>> import PIL  # 导入PIL库用于图像处理
        >>> import requests  # 导入requests库用于HTTP请求
        >>> import torch  # 导入PyTorch库用于深度学习
        >>> from io import BytesIO  # 从io模块导入BytesIO类用于字节流处理

        >>> from diffusers import StableDiffusionDiffEditPipeline  # 从diffusers库导入StableDiffusionDiffEditPipeline类


        >>> def download_image(url):  # 定义下载图像的函数,接受URL作为参数
        ...     response = requests.get(url)  # 使用requests库获取指定URL的响应
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")  # 将响应内容转为字节流,打开为图像并转换为RGB模式


        >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"  # 图像的URL

        >>> init_image = download_image(img_url).resize((768, 768))  # 下载图像并调整大小为768x768

        >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained(  # 从预训练模型加载StableDiffusionDiffEditPipeline
        ...     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16  # 指定模型名称和数据类型为float16
        ... )

        >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)  # 设置调度器为DDIM调度器
        >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)  # 设置逆调度器为DDIM逆调度器
        >>> pipeline.enable_model_cpu_offload()  # 启用模型的CPU卸载以节省内存

        >>> mask_prompt = "A bowl of fruits"  # 定义遮罩提示词
        >>> prompt = "A bowl of pears"  # 定义生成提示词

        >>> mask_image = pipeline.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt)  # 生成遮罩图像
        >>> image_latents = pipeline.invert(image=init_image, prompt=mask_prompt).latents  # 对初始图像进行反向处理,获取潜在图像
        >>> image = pipeline(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0]  # 生成最终图像
        ```py
"""

# 反转示例文档字符串,包含代码示例
EXAMPLE_INVERT_DOC_STRING = """
        ```py
        >>> import PIL  # 导入PIL库用于图像处理
        >>> import requests  # 导入requests库用于HTTP请求
        >>> import torch  # 导入PyTorch库用于深度学习
        >>> from io import BytesIO  # 从io模块导入BytesIO类用于字节流处理

        >>> from diffusers import StableDiffusionDiffEditPipeline  # 从diffusers库导入StableDiffusionDiffEditPipeline类


        >>> def download_image(url):  # 定义下载图像的函数,接受URL作为参数
        ...     response = requests.get(url)  # 使用requests库获取指定URL的响应
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")  # 将响应内容转为字节流,打开为图像并转换为RGB模式


        >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"  # 图像的URL

        >>> init_image = download_image(img_url).resize((768, 768))  # 下载图像并调整大小为768x768

        >>> pipeline = StableDiffusionDiffEditPipeline.from_pretrained(  # 从预训练模型加载StableDiffusionDiffEditPipeline
        ...     "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16  # 指定模型名称和数据类型为float16
        ... )

        >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)  # 设置调度器为DDIM调度器
        >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)  # 设置逆调度器为DDIM逆调度器
        >>> pipeline.enable_model_cpu_offload()  # 启用模型的CPU卸载以节省内存

        >>> prompt = "A bowl of fruits"  # 定义生成提示词

        >>> inverted_latents = pipeline.invert(image=init_image, prompt=prompt).latents  # 对初始图像进行反向处理,获取潜在图像
        ```py
"""


def auto_corr_loss(hidden_states, generator=None):  # 定义自相关损失函数,接受隐藏状态和生成器作为参数
    reg_loss = 0.0  # 初始化正则化损失为0.0
    # 遍历隐藏状态的第一个维度
        for i in range(hidden_states.shape[0]):
            # 遍历隐藏状态的第二个维度
            for j in range(hidden_states.shape[1]):
                # 提取当前隐藏状态的一个子块
                noise = hidden_states[i : i + 1, j : j + 1, :, :]
                # 进入循环以处理噪声
                while True:
                    # 随机选择一个滚动的数量
                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                    # 计算第一个方向的正则化损失并累加
                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                    # 计算第二个方向的正则化损失并累加
                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
    
                    # 如果噪声的宽度小于等于8,退出循环
                    if noise.shape[2] <= 8:
                        break
                    # 对噪声进行平均池化处理
                    noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2)
        # 返回计算的正则化损失
        return reg_loss
# 计算隐藏状态的 Kullback-Leibler 散度
def kl_divergence(hidden_states):
    # 计算隐藏状态的方差,加上均值的平方,再减去1,再减去方差加上一个小常数的对数
    return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7)


# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess 复制而来
def preprocess(image):
    # 定义弃用信息,提示用户使用新的预处理方法
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    # 调用弃用函数,输出警告信息
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
    # 如果输入是 PyTorch 张量,则直接返回
    if isinstance(image, torch.Tensor):
        return image
    # 如果输入是 PIL 图像,将其放入列表中
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    # 如果列表中的第一个元素是 PIL 图像
    if isinstance(image[0], PIL.Image.Image):
        # 获取图像的宽和高
        w, h = image[0].size
        # 将宽和高调整为8的整数倍
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        # 将每个图像调整为新大小,并转换为 NumPy 数组
        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        # 沿第0维连接所有图像数组
        image = np.concatenate(image, axis=0)
        # 将数据类型转换为 float32,并归一化到[0, 1]
        image = np.array(image).astype(np.float32) / 255.0
        # 调整数组维度顺序
        image = image.transpose(0, 3, 1, 2)
        # 将像素值缩放到[-1, 1]范围
        image = 2.0 * image - 1.0
        # 将 NumPy 数组转换为 PyTorch 张量
        image = torch.from_numpy(image)
    # 如果列表中的第一个元素是 PyTorch 张量
    elif isinstance(image[0], torch.Tensor):
        # 沿第0维连接所有张量
        image = torch.cat(image, dim=0)
    # 返回处理后的图像
    return image


def preprocess_mask(mask, batch_size: int = 1):
    # 如果输入的 mask 不是 PyTorch 张量
    if not isinstance(mask, torch.Tensor):
        # 处理 mask
        # 如果是 PIL 图像或 NumPy 数组,将其放入列表中
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        # 如果 mask 是列表
        if isinstance(mask, list):
            # 如果列表中的第一个元素是 PIL 图像
            if isinstance(mask[0], PIL.Image.Image):
                # 将每个图像转换为灰度并归一化到[0, 1]
                mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask]
            # 如果列表中的第一个元素是 NumPy 数组
            if isinstance(mask[0], np.ndarray):
                # 根据维度堆叠或连接数组
                mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0)
                # 将 NumPy 数组转换为 PyTorch 张量
                mask = torch.from_numpy(mask)
            # 如果列表中的第一个元素是 PyTorch 张量
            elif isinstance(mask[0], torch.Tensor):
                # 堆叠或连接张量
                mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0)

    # 如果 mask 是二维,添加批次和通道维度
    if mask.ndim == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)

    # 如果 mask 是三维
    if mask.ndim == 3:
        # 如果是单一批次的 mask,且没有通道维度或单一 mask 但有通道维度
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)

        # 对于没有通道维度的批次 mask
        else:
            mask = mask.unsqueeze(1)

    # 检查 mask 的形状
    if batch_size > 1:
        # 如果 mask 只有一个元素,复制以匹配 batch_size
        if mask.shape[0] == 1:
            mask = torch.cat([mask] * batch_size)
        # 如果 mask 的形状与 batch_size 不一致,则引发错误
        elif mask.shape[0] > 1 and mask.shape[0] != batch_size:
            raise ValueError(
                f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} "
                f"inferred by prompt inputs"
            )

    # 检查 mask 是否具有单通道
    if mask.shape[1] != 1:
        raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels")

    # 检查 mask 的值是否在 [0, 1] 之间
    # 检查掩码的最小值是否小于0,或最大值是否大于1
    if mask.min() < 0 or mask.max() > 1:
        # 如果条件满足,抛出值错误异常,提示掩码图像应在[0, 1]范围内
        raise ValueError("`mask_image` should be in [0, 1] range")

    # 二值化掩码,低于0.5的值设为0
    mask[mask < 0.5] = 0
    # 大于等于0.5的值设为1
    mask[mask >= 0.5] = 1

    # 返回处理后的掩码
    return mask
# 定义一个稳定扩散图像编辑管道类,继承多个混入类
class StableDiffusionDiffEditPipeline(
    # 继承扩散管道、稳定扩散和文本逆转加载的功能
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
    r"""
    <Tip warning={true}>
    # 提示用户该特性是实验性的
    This is an experimental feature!
    </Tip>

    # 使用稳定扩散和DiffEdit进行文本引导的图像修补的管道。
    Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit.

    # 该模型继承自DiffusionPipeline。检查超类文档以获取所有管道实现的通用方法(下载、保存、在特定设备上运行等)。
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    # 该管道还继承以下加载和保存方法:
        # 加载文本逆转嵌入的方法
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        # 加载和保存LoRA权重的方法
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    # 构造函数参数说明:
        vae ([`AutoencoderKL`]):
            # 变分自编码器模型,用于将图像编码和解码为潜在表示。
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            # 冻结的文本编码器,使用特定的CLIP模型。
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            # 用于将文本进行分词的CLIP分词器。
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            # 用于去噪编码后图像潜在表示的UNet模型。
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            # 与UNet结合使用的调度器,用于去噪图像潜在表示。
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        inverse_scheduler ([`DDIMInverseScheduler`]):
            # 与UNet结合使用的调度器,用于填补输入潜在表示的未掩蔽部分。
            A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            # 用于评估生成图像是否可能被视为冒犯或有害的分类模块。
            Classification module that estimates whether generated images could be considered offensive or harmful.
            # 参考模型卡以获取关于模型潜在危害的更多细节。
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            # 用于提取生成图像特征的CLIP图像处理器;作为输入传递给安全检查器。
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    # 定义模型的CPU卸载顺序
    model_cpu_offload_seq = "text_encoder->unet->vae"
    # 定义可选组件
    _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
    # 定义排除在CPU卸载之外的组件
    _exclude_from_cpu_offload = ["safety_checker"]

    # 构造函数
    def __init__(
        # 定义构造函数的参数,包括各种模型
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        inverse_scheduler: DDIMInverseScheduler,
        requires_safety_checker: bool = True,
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 复制而来
    def _encode_prompt(
            self,
            prompt,  # 输入的提示文本
            device,  # 指定的设备(如 CPU 或 GPU)
            num_images_per_prompt,  # 每个提示生成的图像数量
            do_classifier_free_guidance,  # 是否使用无分类器的引导
            negative_prompt=None,  # 可选的负面提示文本
            prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入
            lora_scale: Optional[float] = None,  # 可选的 LORA 缩放因子
            **kwargs,  # 其他任意关键字参数
    ):
            # 生成弃用消息,提示用户使用新的 encode_prompt() 函数
            deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
            # 调用弃用函数警告
            deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
    
            # 调用 encode_prompt() 以获取提示嵌入的元组
            prompt_embeds_tuple = self.encode_prompt(
                prompt=prompt,  # 传递提示文本
                device=device,  # 传递设备
                num_images_per_prompt=num_images_per_prompt,  # 传递每个提示的图像数量
                do_classifier_free_guidance=do_classifier_free_guidance,  # 传递无分类器引导参数
                negative_prompt=negative_prompt,  # 传递负面提示文本
                prompt_embeds=prompt_embeds,  # 传递提示嵌入
                negative_prompt_embeds=negative_prompt_embeds,  # 传递负面提示嵌入
                lora_scale=lora_scale,  # 传递 LORA 缩放因子
                **kwargs,  # 传递其他参数
            )
    
            # 将提示嵌入元组中的两个部分连接为一个张量
            prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
    
            # 返回合并后的提示嵌入
            return prompt_embeds
    
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt 复制而来
    def encode_prompt(
            self,
            prompt,  # 输入的提示文本
            device,  # 指定的设备(如 CPU 或 GPU)
            num_images_per_prompt,  # 每个提示生成的图像数量
            do_classifier_free_guidance,  # 是否使用无分类器的引导
            negative_prompt=None,  # 可选的负面提示文本
            prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入
            negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入
            lora_scale: Optional[float] = None,  # 可选的 LORA 缩放因子
            clip_skip: Optional[int] = None,  # 可选的剪辑跳过参数
    ):
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 复制而来
        def run_safety_checker(self, image, device, dtype):  # 定义运行安全检查器的函数
            # 如果没有安全检查器,则将 NSFW 概念标记为 None
            if self.safety_checker is None:
                has_nsfw_concept = None
            else:
                # 如果输入图像是张量,进行后处理为 PIL 格式
                if torch.is_tensor(image):
                    feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
                else:
                    # 如果不是张量,则将 NumPy 数组转换为 PIL 格式
                    feature_extractor_input = self.image_processor.numpy_to_pil(image)
                # 获取安全检查器的输入,将其转换为张量并移到指定设备
                safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
                # 运行安全检查器,检查图像的 NSFW 概念
                image, has_nsfw_concept = self.safety_checker(
                    images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
                )
            # 返回处理后的图像和 NSFW 概念标记
            return image, has_nsfw_concept
    
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制而来
    # 定义一个方法来准备额外的参数,用于调度器的步骤
    def prepare_extra_step_kwargs(self, generator, eta):
        # 准备调度器步骤的额外关键字参数,因为并非所有调度器的签名相同
        # eta (η) 仅在 DDIMScheduler 中使用,对于其他调度器将被忽略
        # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
        # 其值应在 [0, 1] 之间

        # 检查调度器的步骤函数是否接受 eta 参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 初始化一个空字典以存储额外的步骤参数
        extra_step_kwargs = {}
        # 如果调度器接受 eta 参数,则将其添加到字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 检查调度器的步骤函数是否接受 generator 参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果调度器接受 generator 参数,则将其添加到字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回包含额外参数的字典
        return extra_step_kwargs

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline 复制的方法,用于解码潜在变量
    def decode_latents(self, latents):
        # 定义弃用消息,说明 decode_latents 方法将于 1.0.0 版本中移除
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        # 调用弃用函数,发出警告
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        # 使用 VAE 配置的缩放因子来调整潜在变量
        latents = 1 / self.vae.config.scaling_factor * latents
        # 解码潜在变量,返回图像数据
        image = self.vae.decode(latents, return_dict=False)[0]
        # 将图像值从 [-1, 1] 范围转换到 [0, 1] 范围,并限制其范围
        image = (image / 2 + 0.5).clamp(0, 1)
        # 始终转换为 float32,因为这不会造成显著开销且与 bfloat16 兼容
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        # 返回解码后的图像
        return image

    # 定义一个方法检查输入参数
    def check_inputs(
        self,
        prompt,  # 输入提示
        strength,  # 强度参数
        callback_steps,  # 回调步骤数
        negative_prompt=None,  # 可选的负面提示
        prompt_embeds=None,  # 可选的提示嵌入
        negative_prompt_embeds=None,  # 可选的负面提示嵌入
    ):
        # 检查 strength 参数是否为 None 或者不在 [0, 1] 范围内
        if (strength is None) or (strength is not None and (strength < 0 or strength > 1)):
            # 如果不符合条件,抛出 ValueError 异常,并给出详细错误信息
            raise ValueError(
                f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}."
            )

        # 检查 callback_steps 参数是否为 None 或者不符合条件(非正整数)
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            # 如果不符合条件,抛出 ValueError 异常,并给出详细错误信息
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        # 检查同时传入 prompt 和 prompt_embeds 是否为 None
        if prompt is not None and prompt_embeds is not None:
            # 如果同时传入,抛出 ValueError 异常
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        # 检查 prompt 和 prompt_embeds 是否均为 None
        elif prompt is None and prompt_embeds is None:
            # 如果均为 None,抛出 ValueError 异常
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        # 检查 prompt 类型是否为 str 或 list
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            # 如果不符合条件,抛出 ValueError 异常
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 检查同时传入 negative_prompt 和 negative_prompt_embeds 是否为 None
        if negative_prompt is not None and negative_prompt_embeds is not None:
            # 如果同时传入,抛出 ValueError 异常
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # 检查 prompt_embeds 和 negative_prompt_embeds 是否同时不为 None
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            # 检查它们的形状是否相同
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                # 如果形状不同,抛出 ValueError 异常,并给出详细错误信息
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    # 定义 check_source_inputs 方法,检查源输入的有效性
    def check_source_inputs(
        self,
        # 定义 source_prompt 参数,默认为 None
        source_prompt=None,
        # 定义 source_negative_prompt 参数,默认为 None
        source_negative_prompt=None,
        # 定义 source_prompt_embeds 参数,默认为 None
        source_prompt_embeds=None,
        # 定义 source_negative_prompt_embeds 参数,默认为 None
        source_negative_prompt_embeds=None,
    ):
        # 检查 source_prompt 和 source_prompt_embeds 是否同时提供
        if source_prompt is not None and source_prompt_embeds is not None:
            # 抛出错误,提示不能同时传递这两个参数
            raise ValueError(
                f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}."
                "  Please make sure to only forward one of the two."
            )
        # 检查是否同时未提供 source_prompt 和 source_prompt_embeds
        elif source_prompt is None and source_prompt_embeds is None:
            # 抛出错误,提示至少要提供一个参数
            raise ValueError(
                "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined."
            )
        # 检查 source_prompt 是否不是字符串或列表
        elif source_prompt is not None and (
            not isinstance(source_prompt, str) and not isinstance(source_prompt, list)
        ):
            # 抛出错误,提示 source_prompt 类型错误
            raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}")

        # 检查 source_negative_prompt 和 source_negative_prompt_embeds 是否同时提供
        if source_negative_prompt is not None and source_negative_prompt_embeds is not None:
            # 抛出错误,提示不能同时传递这两个参数
            raise ValueError(
                f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:"
                f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        # 检查 source_prompt_embeds 和 source_negative_prompt_embeds 是否同时提供且形状不同
        if source_prompt_embeds is not None and source_negative_prompt_embeds is not None:
            if source_prompt_embeds.shape != source_negative_prompt_embeds.shape:
                # 抛出错误,提示两个参数的形状不匹配
                raise ValueError(
                    "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed"
                    f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !="
                    f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}."
                )

    def get_timesteps(self, num_inference_steps, strength, device):
        # 计算初始时间步长,使用 num_inference_steps 和 strength
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        # 计算开始时间步长,确保不小于零
        t_start = max(num_inference_steps - init_timestep, 0)
        # 获取调度器的时间步长切片
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        # 返回时间步长和剩余的推理步骤
        return timesteps, num_inference_steps - t_start

    def get_inverse_timesteps(self, num_inference_steps, strength, device):
        # 计算初始时间步长,使用 num_inference_steps 和 strength
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        # 计算开始时间步长,确保不小于零
        t_start = max(num_inference_steps - init_timestep, 0)

        # 安全检查以防止 t_start 溢出,避免空切片
        if t_start == 0:
            return self.inverse_scheduler.timesteps, num_inference_steps
        # 获取逆调度器的时间步长切片
        timesteps = self.inverse_scheduler.timesteps[:-t_start]

        # 返回时间步长和剩余的推理步骤
        return timesteps, num_inference_steps - t_start

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 中复制的
    # 准备潜在变量,返回调整后的张量
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 定义潜在变量的形状,根据输入参数计算
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 检查生成器是否为列表且长度与批量大小匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 若不匹配,抛出值错误
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
    
        # 如果潜在变量为 None,则生成随机张量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 如果潜在变量不为 None,则将其移动到指定设备
            latents = latents.to(device)
    
        # 将初始噪声按调度器所需的标准差进行缩放
        latents = latents * self.scheduler.init_noise_sigma
        # 返回处理后的潜在变量
        return latents
    # 准备图像的潜在表示,接受图像、批次大小、数据类型、设备和可选生成器
    def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
        # 检查输入的 image 是否为有效类型(torch.Tensor、PIL.Image.Image 或 list)
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                # 如果类型不匹配,则抛出错误并提示实际类型
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        # 将图像数据转移到指定的设备上,并转换为指定的数据类型
        image = image.to(device=device, dtype=dtype)

        # 如果图像的通道数为4,直接使用该图像作为潜在表示
        if image.shape[1] == 4:
            latents = image

        else:
            # 如果生成器是列表且其长度与批次大小不匹配,抛出错误
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    # 抛出错误,提示生成器列表的长度与请求的批次大小不匹配
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            # 如果生成器是列表,则逐个处理图像
            if isinstance(generator, list):
                # 对于每个图像,编码并从对应生成器中采样潜在表示
                latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
                ]
                # 将潜在表示沿第0维度拼接成一个张量
                latents = torch.cat(latents, dim=0)
            else:
                # 如果生成器不是列表,直接编码图像并采样潜在表示
                latents = self.vae.encode(image).latent_dist.sample(generator)

            # 根据配置的缩放因子调整潜在表示
            latents = self.vae.config.scaling_factor * latents

        # 检查生成的潜在表示与请求的批次大小是否匹配
        if batch_size != latents.shape[0]:
            # 如果请求的批次大小可以整除当前潜在表示的大小
            if batch_size % latents.shape[0] == 0:
                # 扩展潜在表示以匹配批次大小
                deprecation_message = (
                    # 构造警告消息,提示用户图像数量与文本提示数量不匹配
                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                    " your script to pass as many initial images as text prompts to suppress this warning."
                )
                # 发出过时警告
                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
                # 计算每个图像需要复制的次数
                additional_latents_per_image = batch_size // latents.shape[0]
                # 复制潜在表示以满足批次大小
                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
            else:
                # 如果无法按请求大小复制潜在表示,抛出错误
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
                )
        else:
            # 如果匹配,则将潜在表示转换为张量形式
            latents = torch.cat([latents], dim=0)

        # 返回处理后的潜在表示
        return latents
    # 根据模型输出、样本和时间步长获取 epsilon 值
        def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
            # 获取预测类型配置
            pred_type = self.inverse_scheduler.config.prediction_type
            # 获取时间步长对应的 alpha 乘积值
            alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
    
            # 计算 beta 乘积值
            beta_prod_t = 1 - alpha_prod_t
    
            # 根据预测类型返回不同的计算结果
            if pred_type == "epsilon":
                return model_output
            elif pred_type == "sample":
                # 根据样本和模型输出计算并返回生成样本
                return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
            elif pred_type == "v_prediction":
                # 根据 alpha 和 beta 乘积值返回加权模型输出和样本
                return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
            else:
                # 抛出错误,指明无效的预测类型
                raise ValueError(
                    f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
                )
    
        # 不计算梯度
        @torch.no_grad()
        # 替换示例文档字符串
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        # 生成掩码的函数定义
        def generate_mask(
            # 输入图像,可以是张量或PIL图像
            image: Union[torch.Tensor, PIL.Image.Image] = None,
            # 目标提示,单个字符串或字符串列表
            target_prompt: Optional[Union[str, List[str]]] = None,
            # 目标负提示,单个字符串或字符串列表
            target_negative_prompt: Optional[Union[str, List[str]]] = None,
            # 目标提示的嵌入表示,张量形式
            target_prompt_embeds: Optional[torch.Tensor] = None,
            # 目标负提示的嵌入表示,张量形式
            target_negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 源提示,单个字符串或字符串列表
            source_prompt: Optional[Union[str, List[str]]] = None,
            # 源负提示,单个字符串或字符串列表
            source_negative_prompt: Optional[Union[str, List[str]]] = None,
            # 源提示的嵌入表示,张量形式
            source_prompt_embeds: Optional[torch.Tensor] = None,
            # 源负提示的嵌入表示,张量形式
            source_negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 每个掩码生成的映射数量
            num_maps_per_mask: Optional[int] = 10,
            # 掩码编码强度
            mask_encode_strength: Optional[float] = 0.5,
            # 掩码阈值比例
            mask_thresholding_ratio: Optional[float] = 3.0,
            # 推理步骤数
            num_inference_steps: int = 50,
            # 引导比例
            guidance_scale: float = 7.5,
            # 随机数生成器
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 输出类型,默认是numpy格式
            output_type: Optional[str] = "np",
            # 跨注意力参数字典
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 不计算梯度
        @torch.no_grad()
        # 替换示例文档字符串
        @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
        # 反转操作的函数定义
        def invert(
            # 提示内容,单个字符串或字符串列表
            prompt: Optional[Union[str, List[str]]] = None,
            # 输入图像,可以是张量或PIL图像
            image: Union[torch.Tensor, PIL.Image.Image] = None,
            # 推理步骤数
            num_inference_steps: int = 50,
            # 反向处理强度
            inpaint_strength: float = 0.8,
            # 引导比例
            guidance_scale: float = 7.5,
            # 负提示,单个字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 随机数生成器
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 提示嵌入表示,张量形式
            prompt_embeds: Optional[torch.Tensor] = None,
            # 负提示嵌入表示,张量形式
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 是否解码潜变量
            decode_latents: bool = False,
            # 输出类型,默认是PIL格式
            output_type: Optional[str] = "pil",
            # 是否返回字典格式
            return_dict: bool = True,
            # 回调函数,用于每个步骤的反馈
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 回调步骤间隔
            callback_steps: Optional[int] = 1,
            # 跨注意力参数字典
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 自动相关的惩罚系数
            lambda_auto_corr: float = 20.0,
            # KL散度的惩罚系数
            lambda_kl: float = 20.0,
            # 正则化步骤数
            num_reg_steps: int = 0,
            # 自动相关的滚动次数
            num_auto_corr_rolls: int = 5,
        # 不计算梯度
    # 使用装饰器替换示例文档字符串
        @replace_example_docstring(EXAMPLE_DOC_STRING)
        # 定义可调用方法,接受多个参数以生成图像
        def __call__(
            # 提示文本,可以是字符串或字符串列表
            self,
            prompt: Optional[Union[str, List[str]]] = None,
            # 待处理的掩码图像,可以是张量或PIL图像
            mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
            # 图像潜变量,可以是张量或PIL图像
            image_latents: Union[torch.Tensor, PIL.Image.Image] = None,
            # 图像修补强度,默认值为0.8
            inpaint_strength: Optional[float] = 0.8,
            # 推理步骤数量,默认值为50
            num_inference_steps: int = 50,
            # 指导缩放因子,默认值为7.5
            guidance_scale: float = 7.5,
            # 负提示文本,可以是字符串或字符串列表
            negative_prompt: Optional[Union[str, List[str]]] = None,
            # 每个提示生成的图像数量,默认值为1
            num_images_per_prompt: Optional[int] = 1,
            # 噪声系数,默认值为0.0
            eta: float = 0.0,
            # 随机数生成器,可以是单个生成器或生成器列表
            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
            # 潜变量,可以是张量
            latents: Optional[torch.Tensor] = None,
            # 提示嵌入,张量类型
            prompt_embeds: Optional[torch.Tensor] = None,
            # 负提示嵌入,张量类型
            negative_prompt_embeds: Optional[torch.Tensor] = None,
            # 输出类型,默认为“pil”
            output_type: Optional[str] = "pil",
            # 是否返回字典格式的结果,默认为True
            return_dict: bool = True,
            # 回调函数,接受步骤和张量
            callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
            # 回调调用的步骤间隔,默认为1
            callback_steps: int = 1,
            # 跨注意力的关键字参数,可选字典
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 跳过的剪辑数量,默认为None
            clip_skip: int = None,
posted @ 2024-10-22 12:33  绝不原创的飞龙  阅读(64)  评论(0编辑  收藏  举报