diffusers-源码解析-二十三-

diffusers 源码解析(二十三)

.\diffusers\pipelines\controlnet\pipeline_controlnet_sd_xl_img2img.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件
# 在“按原样”基础上分发,不提供任何形式的保证或条件,
# 无论是明示或暗示的。
# 请参阅许可证以了解管理权限的具体语言和
# 限制条款。


import inspect  # 导入 inspect 模块,用于获取对象的摘要信息
from typing import Any, Callable, Dict, List, Optional, Tuple, Union  # 导入类型注解模块

import numpy as np  # 导入 numpy,用于数组和矩阵计算
import PIL.Image  # 导入 PIL.Image,用于处理图像
import torch  # 导入 PyTorch,用于深度学习
import torch.nn.functional as F  # 导入 PyTorch 的函数式 API
from transformers import (  # 从 transformers 导入模型和处理器
    CLIPImageProcessor,  # 导入 CLIP 图像处理器
    CLIPTextModel,  # 导入 CLIP 文本模型
    CLIPTextModelWithProjection,  # 导入带投影的 CLIP 文本模型
    CLIPTokenizer,  # 导入 CLIP 分词器
    CLIPVisionModelWithProjection,  # 导入带投影的 CLIP 视觉模型
)

from diffusers.utils.import_utils import is_invisible_watermark_available  # 导入检查是否可用的隐形水印功能

from ...callbacks import MultiPipelineCallbacks, PipelineCallback  # 导入多管道回调和管道回调类
from ...image_processor import PipelineImageInput, VaeImageProcessor  # 导入图像处理相关类
from ...loaders import (  # 导入加载器相关类
    FromSingleFileMixin,  # 从单文件加载的混合类
    IPAdapterMixin,  # 图像处理适配器混合类
    StableDiffusionXLLoraLoaderMixin,  # StableDiffusionXL Lora 加载混合类
    TextualInversionLoaderMixin,  # 文本反转加载混合类
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel  # 导入不同模型
from ...models.attention_processor import (  # 导入注意力处理器
    AttnProcessor2_0,  # 注意力处理器版本 2.0
    XFormersAttnProcessor,  # XFormers 注意力处理器
)
from ...models.lora import adjust_lora_scale_text_encoder  # 导入调整 Lora 标度文本编码器的函数
from ...schedulers import KarrasDiffusionSchedulers  # 导入 Karras 扩散调度器
from ...utils import (  # 导入常用工具
    USE_PEFT_BACKEND,  # 指示是否使用 PEFT 后端的常量
    deprecate,  # 导入弃用装饰器
    logging,  # 导入日志记录模块
    replace_example_docstring,  # 导入替换示例文档字符串的工具
    scale_lora_layers,  # 导入缩放 Lora 层的工具
    unscale_lora_layers,  # 导入反缩放 Lora 层的工具
)
from ...utils.torch_utils import is_compiled_module, randn_tensor  # 导入与 PyTorch 相关的工具
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin  # 导入扩散管道和稳定扩散混合类
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput  # 导入稳定扩散 XL 管道输出类


if is_invisible_watermark_available():  # 如果隐形水印功能可用
    from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker  # 导入稳定扩散 XL 水印类

from .multicontrolnet import MultiControlNetModel  # 导入多控制网模型


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,禁止 pylint 检查


EXAMPLE_DOC_STRING = """  # 示例文档字符串的空模板
"""


# 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 复制的函数
def retrieve_latents(  # 定义函数以检索潜在变量
    encoder_output: torch.Tensor,  # 输入为编码器输出的张量
    generator: Optional[torch.Generator] = None,  # 可选的随机数生成器
    sample_mode: str = "sample"  # 采样模式,默认为“sample”
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":  # 如果编码器输出有潜在分布并且模式为采样
        return encoder_output.latent_dist.sample(generator)  # 从潜在分布中采样并返回
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":  # 如果编码器输出有潜在分布并且模式为“argmax”
        return encoder_output.latent_dist.mode()  # 返回潜在分布的众数
    elif hasattr(encoder_output, "latents"):  # 如果编码器输出有潜在变量
        return encoder_output.latents  # 直接返回潜在变量
    else:  # 如果以上条件都不满足
        raise AttributeError("Could not access latents of provided encoder_output")  # 抛出属性错误,说明无法访问潜在变量


class StableDiffusionXLControlNetImg2ImgPipeline(  # 定义 StableDiffusionXL 控制网络图像到图像的管道类
    DiffusionPipeline,  # 继承自扩散管道
    # 继承稳定扩散模型的混合类
        StableDiffusionMixin,
        # 继承文本反转加载器的混合类
        TextualInversionLoaderMixin,
        # 继承稳定扩散 XL Lora 加载器的混合类
        StableDiffusionXLLoraLoaderMixin,
        # 继承单文件加载器的混合类
        FromSingleFileMixin,
        # 继承 IP 适配器的混合类
        IPAdapterMixin,
# 文档字符串,描述使用 ControlNet 指导的图像生成管道
    r"""
    Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    """

    # 定义模型在 CPU 上卸载的顺序
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    # 定义可选组件的列表,用于管道的初始化
    _optional_components = [
        "tokenizer",  # 词汇表,用于文本编码
        "tokenizer_2",  # 第二个词汇表,用于文本编码
        "text_encoder",  # 文本编码器,用于生成文本嵌入
        "text_encoder_2",  # 第二个文本编码器,可能有不同的功能
        "feature_extractor",  # 特征提取器,用于图像特征的提取
        "image_encoder",  # 图像编码器,将图像转换为嵌入
    ]
    # 定义回调张量输入的列表,用于处理管道中的输入
    _callback_tensor_inputs = [
        "latents",  # 潜在变量,用于生成模型的输入
        "prompt_embeds",  # 正向提示的嵌入表示
        "negative_prompt_embeds",  # 负向提示的嵌入表示
        "add_text_embeds",  # 额外文本嵌入,用于补充输入
        "add_time_ids",  # 附加的时间标识符,用于时间相关的处理
        "negative_pooled_prompt_embeds",  # 负向池化提示的嵌入表示
        "add_neg_time_ids",  # 附加的负向时间标识符
    ]

    # 构造函数,初始化管道所需的组件
    def __init__(
        self,  # 构造函数的第一个参数,指向类的实例
        vae: AutoencoderKL,  # 变分自编码器,用于图像的重建
        text_encoder: CLIPTextModel,  # 文本编码器,使用 CLIP 模型
        text_encoder_2: CLIPTextModelWithProjection,  # 第二个文本编码器,带投影功能的 CLIP 模型
        tokenizer: CLIPTokenizer,  # 第一个 CLIP 词汇表
        tokenizer_2: CLIPTokenizer,  # 第二个 CLIP 词汇表
        unet: UNet2DConditionModel,  # U-Net 模型,用于生成图像
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],  # 控制网络模型,用于引导生成
        scheduler: KarrasDiffusionSchedulers,  # 调度器,控制扩散过程
        requires_aesthetics_score: bool = False,  # 是否需要美学评分,默认为 False
        force_zeros_for_empty_prompt: bool = True,  # 对于空提示强制使用零值,默认为 True
        add_watermarker: Optional[bool] = None,  # 是否添加水印,默认为 None
        feature_extractor: CLIPImageProcessor = None,  # 特征提取器,默认为 None
        image_encoder: CLIPVisionModelWithProjection = None,  # 图像编码器,默认为 None
    ):
        # 调用父类的构造函数进行初始化
        super().__init__()

        # 检查 controlnet 是否为列表或元组,如果是则将其封装为 MultiControlNetModel 对象
        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        # 注册多个模块,包括 VAE、文本编码器、tokenizer、UNet 等
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        # 计算 VAE 的缩放因子,通常用于图像尺寸调整
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # 创建 VAE 图像处理器,设置缩放因子并开启 RGB 转换
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        # 创建控制图像处理器,设置缩放因子,开启 RGB 转换,但不进行标准化
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
        # 根据输入参数或默认值确定是否添加水印
        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        # 如果需要水印,则初始化水印对象
        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            # 否则将水印设置为 None
            self.watermark = None

        # 注册配置,强制空提示使用零值
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        # 注册配置,标记是否需要美学评分
        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)

    # 从 StableDiffusionXLPipeline 复制的 encode_prompt 方法
    def encode_prompt(
        self,
        # 定义 prompt 字符串及其相关参数
        prompt: str,
        prompt_2: Optional[str] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        negative_prompt_2: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    # 从 StableDiffusionPipeline 复制的 encode_image 方法
    # 定义一个方法来编码图像,参数包括图像、设备、每个提示的图像数量和可选的隐藏状态输出
        def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
            # 获取图像编码器参数的数据类型
            dtype = next(self.image_encoder.parameters()).dtype
    
            # 检查输入的图像是否为张量类型
            if not isinstance(image, torch.Tensor):
                # 如果不是,将其转换为张量,并提取像素值
                image = self.feature_extractor(image, return_tensors="pt").pixel_values
    
            # 将图像移动到指定设备并转换为相应的数据类型
            image = image.to(device=device, dtype=dtype)
            # 检查是否需要输出隐藏状态
            if output_hidden_states:
                # 获取图像编码器的隐藏状态,选择倒数第二个隐藏层
                image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
                # 将隐藏状态按每个提示的图像数量重复
                image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
                # 获取无条件图像编码的隐藏状态,使用全零张量作为输入
                uncond_image_enc_hidden_states = self.image_encoder(
                    torch.zeros_like(image), output_hidden_states=True
                ).hidden_states[-2]
                # 将无条件隐藏状态按每个提示的图像数量重复
                uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                    num_images_per_prompt, dim=0
                )
                # 返回图像编码的隐藏状态和无条件图像编码的隐藏状态
                return image_enc_hidden_states, uncond_image_enc_hidden_states
            else:
                # 获取图像编码的嵌入表示
                image_embeds = self.image_encoder(image).image_embeds
                # 将嵌入表示按每个提示的图像数量重复
                image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
                # 创建与图像嵌入同样形状的全零张量作为无条件嵌入
                uncond_image_embeds = torch.zeros_like(image_embeds)
    
                # 返回图像嵌入和无条件图像嵌入
                return image_embeds, uncond_image_embeds
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds 复制的方法
        def prepare_ip_adapter_image_embeds(
            # 定义方法的参数,包括 IP 适配器图像、图像嵌入、设备、每个提示的图像数量和分类器自由引导的标志
            self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        # 初始化一个空列表,用于存储图像嵌入
        image_embeds = []
        # 如果启用了无分类器自由引导,则初始化负图像嵌入列表
        if do_classifier_free_guidance:
            negative_image_embeds = []
        # 如果输入适配器图像嵌入为 None
        if ip_adapter_image_embeds is None:
            # 检查输入适配器图像是否为列表类型,如果不是,则转换为列表
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            # 检查输入适配器图像的长度是否与 IP 适配器数量相等
            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                # 如果不相等,抛出值错误
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            # 遍历输入适配器图像和相应的图像投影层
            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                # 确定是否输出隐藏状态,依据图像投影层的类型
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                # 编码单个图像,获取嵌入和负嵌入
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                # 将图像嵌入添加到列表中,增加一个维度
                image_embeds.append(single_image_embeds[None, :])
                # 如果启用了无分类器自由引导,则将负图像嵌入添加到列表中
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            # 如果输入适配器图像嵌入已存在
            for single_image_embeds in ip_adapter_image_embeds:
                # 如果启用了无分类器自由引导,将嵌入分成负嵌入和正嵌入
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    # 添加负图像嵌入到列表中
                    negative_image_embeds.append(single_negative_image_embeds)
                # 添加正图像嵌入到列表中
                image_embeds.append(single_image_embeds)

        # 初始化一个空列表,用于存储处理后的输入适配器图像嵌入
        ip_adapter_image_embeds = []
        # 遍历图像嵌入,执行重复操作以匹配每个提示的图像数量
        for i, single_image_embeds in enumerate(image_embeds):
            # 将单个图像嵌入沿着维度 0 重复指定次数
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            # 如果启用了无分类器自由引导,处理负嵌入
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                # 将负嵌入与正嵌入合并
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            # 将嵌入移动到指定的设备
            single_image_embeds = single_image_embeds.to(device=device)
            # 将处理后的嵌入添加到列表中
            ip_adapter_image_embeds.append(single_image_embeds)

        # 返回处理后的输入适配器图像嵌入列表
        return ip_adapter_image_embeds

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制的内容
    # 准备额外的参数用于调度器步骤,因为并非所有调度器具有相同的参数签名
        def prepare_extra_step_kwargs(self, generator, eta):
            # eta (η) 仅在 DDIMScheduler 中使用,其他调度器会忽略此参数
            # eta 对应于 DDIM 论文中的 η: https://arxiv.org/abs/2010.02502
            # 该值应在 [0, 1] 之间
    
            # 检查调度器的步骤方法是否接受 eta 参数
            accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 创建一个空字典用于存放额外参数
            extra_step_kwargs = {}
            # 如果调度器接受 eta,则将其添加到额外参数中
            if accepts_eta:
                extra_step_kwargs["eta"] = eta
    
            # 检查调度器的步骤方法是否接受 generator 参数
            accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
            # 如果调度器接受 generator,则将其添加到额外参数中
            if accepts_generator:
                extra_step_kwargs["generator"] = generator
            # 返回包含额外参数的字典
            return extra_step_kwargs
    
        # 检查输入参数的有效性
        def check_inputs(
            self,
            prompt,
            prompt_2,
            image,
            strength,
            num_inference_steps,
            callback_steps,
            negative_prompt=None,
            negative_prompt_2=None,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            ip_adapter_image=None,
            ip_adapter_image_embeds=None,
            controlnet_conditioning_scale=1.0,
            control_guidance_start=0.0,
            control_guidance_end=1.0,
            callback_on_step_end_tensor_inputs=None,
        # 从 diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image 复制的参数
    # 检查输入图像的类型和形状,确保与提示的批量大小一致
    def check_image(self, image, prompt, prompt_embeds):
        # 判断输入是否为 PIL 图像
        image_is_pil = isinstance(image, PIL.Image.Image)
        # 判断输入是否为 PyTorch 张量
        image_is_tensor = isinstance(image, torch.Tensor)
        # 判断输入是否为 NumPy 数组
        image_is_np = isinstance(image, np.ndarray)
        # 判断输入是否为 PIL 图像列表
        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
        # 判断输入是否为 PyTorch 张量列表
        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
        # 判断输入是否为 NumPy 数组列表
        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
    
        # 如果输入不符合任何类型,抛出类型错误
        if (
            not image_is_pil
            and not image_is_tensor
            and not image_is_np
            and not image_is_pil_list
            and not image_is_tensor_list
            and not image_is_np_list
        ):
            raise TypeError(
                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
            )
    
        # 如果输入为 PIL 图像,设置批量大小为 1
        if image_is_pil:
            image_batch_size = 1
        else:
            # 否则,根据输入的长度确定批量大小
            image_batch_size = len(image)
    
        # 如果提示不为 None 且为字符串,设置提示批量大小为 1
        if prompt is not None and isinstance(prompt, str):
            prompt_batch_size = 1
        # 如果提示为列表,根据列表长度设置批量大小
        elif prompt is not None and isinstance(prompt, list):
            prompt_batch_size = len(prompt)
        # 如果提示嵌入不为 None,使用其第一维的大小作为批量大小
        elif prompt_embeds is not None:
            prompt_batch_size = prompt_embeds.shape[0]
    
        # 如果图像批量大小不为 1,且与提示批量大小不一致,抛出值错误
        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
            raise ValueError(
                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
            )
    
        # 从 diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl 导入的 prepare_image 方法
        def prepare_control_image(
            self,
            image,
            width,
            height,
            batch_size,
            num_images_per_prompt,
            device,
            dtype,
            do_classifier_free_guidance=False,
            guess_mode=False,
        ):
            # 预处理输入图像并转换为指定的数据类型
            image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
            # 获取图像批量大小
            image_batch_size = image.shape[0]
    
            # 如果图像批量大小为 1,重复次数设置为 batch_size
            if image_batch_size == 1:
                repeat_by = batch_size
            else:
                # 如果图像批量大小与提示批量大小相同,设置重复次数为每个提示的图像数量
                repeat_by = num_images_per_prompt
    
            # 重复图像以匹配所需的批量大小
            image = image.repeat_interleave(repeat_by, dim=0)
    
            # 将图像转移到指定设备和数据类型
            image = image.to(device=device, dtype=dtype)
    
            # 如果启用分类器自由引导并且不在猜测模式下,复制图像以增加维度
            if do_classifier_free_guidance and not guess_mode:
                image = torch.cat([image] * 2)
    
            # 返回处理后的图像
            return image
    
        # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img 导入的 get_timesteps 方法
    # 获取时间步的函数,接收推理步骤数、强度和设备参数
        def get_timesteps(self, num_inference_steps, strength, device):
            # 计算原始时间步,使用 init_timestep,确保不超过推理步骤数
            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    
            # 计算开始时间步,确保不小于零
            t_start = max(num_inference_steps - init_timestep, 0)
            # 从调度器获取时间步,截取从 t_start 开始的所有时间步
            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
            # 如果调度器具有设置开始索引的方法,则调用该方法
            if hasattr(self.scheduler, "set_begin_index"):
                self.scheduler.set_begin_index(t_start * self.scheduler.order)
    
            # 返回时间步和剩余的推理步骤数
            return timesteps, num_inference_steps - t_start
    
        # 从 StableDiffusionXLImg2ImgPipeline 复制的准备潜在变量的函数
        def prepare_latents(
            self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
        # 从 StableDiffusionXLImg2ImgPipeline 复制的获取附加时间 ID 的函数
        def _get_add_time_ids(
            self,
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype,
            text_encoder_projection_dim=None,
    ):
        # 检查配置是否需要美学评分
        if self.config.requires_aesthetics_score:
            # 创建包含原始大小、裁剪坐标及美学评分的列表
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            # 创建包含负样本原始大小、裁剪坐标及负美学评分的列表
            add_neg_time_ids = list(
                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
            )
        else:
            # 创建包含原始大小、裁剪坐标和目标大小的列表
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            # 创建包含负样本原始大小、裁剪坐标及负目标大小的列表
            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)

        # 计算通过添加时间嵌入维度和文本编码器投影维度得到的通过嵌入维度
        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        # 获取模型期望的添加嵌入维度
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        # 检查期望的嵌入维度是否大于传递的嵌入维度,并符合特定条件
        if (
            expected_add_embed_dim > passed_add_embed_dim
            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
        ):
            # 抛出值错误,说明创建的嵌入维度不符合预期
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
            )
        # 检查期望的嵌入维度是否小于传递的嵌入维度,并符合特定条件
        elif (
            expected_add_embed_dim < passed_add_embed_dim
            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
        ):
            # 抛出值错误,说明创建的嵌入维度不符合预期
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
            )
        # 检查期望的嵌入维度是否与传递的嵌入维度不相等
        elif expected_add_embed_dim != passed_add_embed_dim:
            # 抛出值错误,说明模型配置不正确
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        # 将添加的时间 ID 转换为张量,并指定数据类型
        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        # 将添加的负时间 ID 转换为张量,并指定数据类型
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        # 返回添加的时间 ID 和添加的负时间 ID
        return add_time_ids, add_neg_time_ids

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae 复制而来
    # 定义一个方法,用于将 VAE 模型的参数类型提升
    def upcast_vae(self):
        # 获取当前 VAE 模型的数据类型
        dtype = self.vae.dtype
        # 将 VAE 模型转换为 float32 数据类型
        self.vae.to(dtype=torch.float32)
        # 检查 VAE 解码器中第一个注意力处理器的类型,以确定是否使用了特定版本的处理器
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
            ),
        )
        # 如果使用了 xformers 或 torch_2_0,注意力块不需要为 float32 类型,从而节省大量内存
        if use_torch_2_0_or_xformers:
            # 将后量化卷积层转换为原始数据类型
            self.vae.post_quant_conv.to(dtype)
            # 将解码器输入卷积层转换为原始数据类型
            self.vae.decoder.conv_in.to(dtype)
            # 将解码器中间块转换为原始数据类型
            self.vae.decoder.mid_block.to(dtype)

    # 定义一个属性,返回当前的引导缩放比例
    @property
    def guidance_scale(self):
        # 返回内部存储的引导缩放比例
        return self._guidance_scale

    # 定义一个属性,返回当前的剪辑跳过值
    @property
    def clip_skip(self):
        # 返回内部存储的剪辑跳过值
        return self._clip_skip

    # 定义一个属性,用于判断是否进行无分类器引导,依据是引导缩放比例是否大于 1
    # 此属性的定义参考了 Imagen 论文中的方程 (2)
    # 当 `guidance_scale = 1` 时,相当于不进行无分类器引导
    @property
    def do_classifier_free_guidance(self):
        # 如果引导缩放比例大于 1,返回 True,否则返回 False
        return self._guidance_scale > 1

    # 定义一个属性,返回当前的交叉注意力参数
    @property
    def cross_attention_kwargs(self):
        # 返回内部存储的交叉注意力参数
        return self._cross_attention_kwargs

    # 定义一个属性,返回当前的时间步数
    @property
    def num_timesteps(self):
        # 返回内部存储的时间步数
        return self._num_timesteps

    # 装饰器,表示在执行下面的方法时不计算梯度
    @torch.no_grad()
    # 装饰器,用于替换示例文档字符串
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义一个可调用的类方法,接受多个参数用于处理图像生成
    def __call__(
        # 主提示字符串或字符串列表,默认为 None
        self,
        prompt: Union[str, List[str]] = None,
        # 第二个提示字符串或字符串列表,默认为 None
        prompt_2: Optional[Union[str, List[str]]] = None,
        # 输入图像,用于图像生成的基础,默认为 None
        image: PipelineImageInput = None,
        # 控制图像,用于影响生成的图像,默认为 None
        control_image: PipelineImageInput = None,
        # 输出图像的高度,默认为 None
        height: Optional[int] = None,
        # 输出图像的宽度,默认为 None
        width: Optional[int] = None,
        # 图像生成的强度,默认为 0.8
        strength: float = 0.8,
        # 进行推理的步数,默认为 50
        num_inference_steps: int = 50,
        # 引导尺度,控制图像生成的引导程度,默认为 5.0
        guidance_scale: float = 5.0,
        # 负面提示字符串或字符串列表,默认为 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 第二个负面提示字符串或字符串列表,默认为 None
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        # 每个提示生成的图像数量,默认为 1
        num_images_per_prompt: Optional[int] = 1,
        # 采样的 eta 值,默认为 0.0
        eta: float = 0.0,
        # 随机数生成器,可选,默认为 None
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潜在变量,默认为 None
        latents: Optional[torch.Tensor] = None,
        # 提示的嵌入向量,默认为 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 负面提示的嵌入向量,默认为 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 聚合的提示嵌入向量,默认为 None
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 负面聚合提示嵌入向量,默认为 None
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        # 输入适配器图像,默认为 None
        ip_adapter_image: Optional[PipelineImageInput] = None,
        # 输入适配器图像的嵌入向量,默认为 None
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        # 输出类型,默认为 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典,默认为 True
        return_dict: bool = True,
        # 交叉注意力参数,默认为 None
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 控制网络的条件缩放,默认为 0.8
        controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
        # 猜测模式,默认为 False
        guess_mode: bool = False,
        # 控制引导的开始位置,默认为 0.0
        control_guidance_start: Union[float, List[float]] = 0.0,
        # 控制引导的结束位置,默认为 1.0
        control_guidance_end: Union[float, List[float]] = 1.0,
        # 原始图像的尺寸,默认为 None
        original_size: Tuple[int, int] = None,
        # 裁剪坐标的左上角,默认为 (0, 0)
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        # 目标尺寸,默认为 None
        target_size: Tuple[int, int] = None,
        # 负面原始图像的尺寸,默认为 None
        negative_original_size: Optional[Tuple[int, int]] = None,
        # 负面裁剪坐标的左上角,默认为 (0, 0)
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        # 负目标尺寸,默认为 None
        negative_target_size: Optional[Tuple[int, int]] = None,
        # 审美分数,默认为 6.0
        aesthetic_score: float = 6.0,
        # 负面审美分数,默认为 2.5
        negative_aesthetic_score: float = 2.5,
        # 跳过的剪辑层数,默认为 None
        clip_skip: Optional[int] = None,
        # 步骤结束时的回调函数,可选,默认为 None
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        # 结束步骤时的张量输入回调,默认为 ["latents"]
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # 其他额外参数,默认为空
        **kwargs,

.\diffusers\pipelines\controlnet\pipeline_flax_controlnet.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证,版本 2.0(“许可证”)授权;
# 除非遵守许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有规定,
# 否则根据许可证分发的软件是“按原样”提供的,
# 不提供任何形式的担保或条件,无论是明示或暗示。
# 有关许可证下的特定语言的权限和限制,请参见许可证。

import warnings  # 导入警告模块,用于处理警告信息
from functools import partial  # 从 functools 导入 partial,用于部分函数应用
from typing import Dict, List, Optional, Union  # 导入类型提示,方便函数参数和返回值的类型注释

import jax  # 导入 JAX,用于高性能数值计算
import jax.numpy as jnp  # 导入 JAX 的 NumPy 接口,提供数组操作功能
import numpy as np  # 导入 NumPy,提供数值计算功能
from flax.core.frozen_dict import FrozenDict  # 从 flax 导入 FrozenDict,用于不可变字典
from flax.jax_utils import unreplicate  # 从 flax 导入 unreplicate,用于在 JAX 中处理设备数据
from flax.training.common_utils import shard  # 从 flax 导入 shard,用于数据并行
from PIL import Image  # 从 PIL 导入 Image,用于图像处理
from transformers import CLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel  # 导入 CLIP 相关模块,处理图像和文本

from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel  # 导入模型定义
from ...schedulers import (  # 导入调度器,用于训练过程中的控制
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring  # 导入工具函数和常量
from ..pipeline_flax_utils import FlaxDiffusionPipeline  # 导入扩散管道
from ..stable_diffusion import FlaxStableDiffusionPipelineOutput  # 导入稳定扩散管道输出
from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker  # 导入安全检查器

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,方便调试和信息输出

# 设置为 True 以使用 Python 循环而不是 jax.fori_loop,以便于调试
DEBUG = False  # 调试模式标志,默认为关闭状态

EXAMPLE_DOC_STRING = """  # 示例文档字符串,可能用于文档生成或示例展示


```  # 示例结束标志
    Examples:
        ```py
        >>> import jax  # 导入 JAX 库,用于高性能数值计算
        >>> import numpy as np  # 导入 NumPy 库,支持数组操作
        >>> import jax.numpy as jnp  # 导入 JAX 的 NumPy,支持自动微分和GPU加速
        >>> from flax.jax_utils import replicate  # 从 Flax 导入 replicate 函数,用于参数复制
        >>> from flax.training.common_utils import shard  # 从 Flax 导入 shard 函数,用于数据分片
        >>> from diffusers.utils import load_image, make_image_grid  # 从 diffusers 导入图像加载和网格生成工具
        >>> from PIL import Image  # 导入 PIL 库,用于图像处理
        >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel  # 导入用于稳定扩散模型和控制网的类

        >>> def create_key(seed=0):  # 定义函数创建随机数生成器的密钥
        ...     return jax.random.PRNGKey(seed)  # 返回一个以 seed 为种子的 PRNG 密钥

        >>> rng = create_key(0)  # 创建随机数生成器的密钥,种子为 0

        >>> # get canny image  # 获取 Canny 边缘检测图像
        >>> canny_image = load_image(  # 使用 load_image 函数加载图像
        ...     "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"  # 指定图像的 URL
        ... )

        >>> prompts = "best quality, extremely detailed"  # 定义用于生成图像的正向提示
        >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality"  # 定义生成图像时要避免的负向提示

        >>> # load control net and stable diffusion v1-5  # 加载控制网络和稳定扩散模型 v1-5
        >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(  # 从预训练模型加载控制网络及其参数
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32  # 指定模型名称、来源及数据类型
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(  # 从预训练模型加载稳定扩散管道及其参数
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32  # 指定模型名称、控制网、版本和数据类型
        ... )
        >>> params["controlnet"] = controlnet_params  # 将控制网参数存入管道参数中

        >>> num_samples = jax.device_count()  # 获取当前设备的数量,设置样本数量
        >>> rng = jax.random.split(rng, jax.device_count())  # 将随机数生成器的密钥根据设备数量进行分割

        >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)  # 准备正向提示的输入,针对每个样本复制
        >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)  # 准备负向提示的输入,针对每个样本复制
        >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)  # 准备处理后的图像输入,针对每个样本复制

        >>> p_params = replicate(params)  # 复制参数以便在多个设备上使用
        >>> prompt_ids = shard(prompt_ids)  # 将正向提示的输入数据进行分片
        >>> negative_prompt_ids = shard(negative_prompt_ids)  # 将负向提示的输入数据进行分片
        >>> processed_image = shard(processed_image)  # 将处理后的图像输入数据进行分片

        >>> output = pipe(  # 调用管道生成输出
        ...     prompt_ids=prompt_ids,  # 传入正向提示 ID
        ...     image=processed_image,  # 传入处理后的图像
        ...     params=p_params,  # 传入复制的参数
        ...     prng_seed=rng,  # 传入随机数生成器的密钥
        ...     num_inference_steps=50,  # 设置推理的步骤数
        ...     neg_prompt_ids=negative_prompt_ids,  # 传入负向提示 ID
        ...     jit=True,  # 启用 JIT 编译
        ... ).images  # 获取生成的图像

        >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))  # 将输出图像转换为 PIL 格式
        >>> output_images = make_image_grid(output_images, num_samples // 4, 4)  # 将图像生成网格格式,指定每行显示的图像数量
        >>> output_images.save("generated_image.png")  # 保存生成的图像为 PNG 文件
        ``` 
# 定义一个类,基于 Flax 实现 Stable Diffusion 的控制网文本到图像生成管道
class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
    r"""
    基于 Flax 的管道,用于使用 Stable Diffusion 和 ControlNet 指导进行文本到图像生成。

    此模型继承自 [`FlaxDiffusionPipeline`]。有关所有管道实现的通用方法(下载、保存、在特定设备上运行等),请查看超类文档。

    参数:
        vae ([`FlaxAutoencoderKL`]):
            用于将图像编码和解码为潜在表示的变分自编码器(VAE)模型。
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            冻结的文本编码器([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14))。
        tokenizer ([`~transformers.CLIPTokenizer`]):
            用于对文本进行分词的 `CLIPTokenizer`。
        unet ([`FlaxUNet2DConditionModel`]):
            一个 `FlaxUNet2DConditionModel`,用于去噪编码后的图像潜在表示。
        controlnet ([`FlaxControlNetModel`]):
            在去噪过程中为 `unet` 提供额外的条件信息。
        scheduler ([`SchedulerMixin`]):
            用于与 `unet` 结合使用的调度器,以去噪编码的图像潜在表示。可以是
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`] 或
            [`FlaxDPMSolverMultistepScheduler`] 中的一个。
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            分类模块,评估生成的图像是否可能被视为冒犯或有害。
            有关模型潜在危害的更多细节,请参阅 [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5)。
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            一个 `CLIPImageProcessor`,用于提取生成图像的特征;用于 `safety_checker` 的输入。
    """

    # 初始化方法,定义所需参数及其类型
    def __init__(
        # 变分自编码器(VAE)模型,用于图像编码和解码
        vae: FlaxAutoencoderKL,
        # 冻结的文本编码器模型
        text_encoder: FlaxCLIPTextModel,
        # 文本分词器
        tokenizer: CLIPTokenizer,
        # 去噪模型
        unet: FlaxUNet2DConditionModel,
        # 控制网模型
        controlnet: FlaxControlNetModel,
        # 图像去噪的调度器
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        # 安全检查模块
        safety_checker: FlaxStableDiffusionSafetyChecker,
        # 特征提取器
        feature_extractor: CLIPImageProcessor,
        # 数据类型,默认为 32 位浮点数
        dtype: jnp.dtype = jnp.float32,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置数据类型属性
        self.dtype = dtype

        # 检查安全检查器是否为 None
        if safety_checker is None:
            # 记录警告,告知用户已禁用安全检查器
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # 注册各个模块,方便后续使用
        self.register_modules(
            vae=vae,  # 变分自编码器
            text_encoder=text_encoder,  # 文本编码器
            tokenizer=tokenizer,  # 分词器
            unet=unet,  # UNet 模型
            controlnet=controlnet,  # 控制网络
            scheduler=scheduler,  # 调度器
            safety_checker=safety_checker,  # 安全检查器
            feature_extractor=feature_extractor,  # 特征提取器
        )
        # 计算 VAE 的缩放因子
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def prepare_text_inputs(self, prompt: Union[str, List[str]]):
        # 检查 prompt 类型是否为字符串或列表
        if not isinstance(prompt, (str, list)):
            # 如果类型不符,抛出值错误
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # 使用分词器处理输入文本
        text_input = self.tokenizer(
            prompt,  # 输入的提示文本
            padding="max_length",  # 填充到最大长度
            max_length=self.tokenizer.model_max_length,  # 设置最大长度为分词器的最大模型长度
            truncation=True,  # 如果超过最大长度,则截断
            return_tensors="np",  # 返回 NumPy 格式的张量
        )

        # 返回处理后的输入 ID
        return text_input.input_ids

    def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
        # 检查图像类型是否为 PIL.Image 或列表
        if not isinstance(image, (Image.Image, list)):
            # 如果类型不符,抛出值错误
            raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")

        # 如果输入是单个图像,将其转换为列表
        if isinstance(image, Image.Image):
            image = [image]

        # 对所有图像进行预处理,并合并为一个数组
        processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])

        # 返回处理后的图像数组
        return processed_images

    def _get_has_nsfw_concepts(self, features, params):
        # 使用安全检查器检查是否存在不适当内容概念
        has_nsfw_concepts = self.safety_checker(features, params)
        # 返回检查结果
        return has_nsfw_concepts
    # 定义一个安全检查的私有方法,接收图像、模型参数和是否使用 JIT 编译的标志
    def _run_safety_checker(self, images, safety_model_params, jit=False):
        # 当 jit 为 True 时,safety_model_params 应该已经被复制
        # 将输入的图像数组转换为 PIL 图像格式
        pil_images = [Image.fromarray(image) for image in images]
        # 使用特征提取器处理 PIL 图像,返回其像素值
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        # 如果启用 JIT 编译
        if jit:
            # 对特征进行分片处理
            features = shard(features)
            # 检查特征中是否存在 NSFW 概念
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            # 取消特征的分片
            has_nsfw_concepts = unshard(has_nsfw_concepts)
            # 取消模型参数的复制
            safety_model_params = unreplicate(safety_model_params)
        else:
            # 否则,直接获取 NSFW 概念的检查结果
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        # 初始化一个标志,指示图像是否已经被复制
        images_was_copied = False
        # 遍历每个 NSFW 概念的检查结果
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            # 如果检测到 NSFW 概念
            if has_nsfw_concept:
                # 如果还没有复制图像
                if not images_was_copied:
                    # 标记为已复制,并进行图像复制
                    images_was_copied = True
                    images = images.copy()

                # 将对应的图像替换为全黑图像
                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

            # 如果存在任何 NSFW 概念
            if any(has_nsfw_concepts):
                # 发出警告,提示可能检测到不适宜内容
                warnings.warn(
                    "Potential NSFW content was detected in one or more images. A black image will be returned"
                    " instead. Try again with a different prompt and/or seed."
                )

        # 返回处理后的图像和 NSFW 概念的检查结果
        return images, has_nsfw_concepts

    # 定义一个生成图像的私有方法,接收多个参数以控制生成过程
    def _generate(
        self,
        prompt_ids: jnp.ndarray,  # 输入的提示 ID 数组
        image: jnp.ndarray,  # 输入的图像数据
        params: Union[Dict, FrozenDict],  # 模型参数,可能是字典或不可变字典
        prng_seed: jax.Array,  # 随机种子,用于随机数生成
        num_inference_steps: int,  # 推理步骤的数量
        guidance_scale: float,  # 指导比例,用于控制生成质量
        latents: Optional[jnp.ndarray] = None,  # 潜在变量,默认值为 None
        neg_prompt_ids: Optional[jnp.ndarray] = None,  # 负提示 ID,默认值为 None
        controlnet_conditioning_scale: float = 1.0,  # 控制网络的条件缩放比例
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    # 定义可调用的方法,接收多个参数以控制生成过程
    def __call__(
        self,
        prompt_ids: jnp.ndarray,  # 输入的提示 ID 数组
        image: jnp.ndarray,  # 输入的图像数据
        params: Union[Dict, FrozenDict],  # 模型参数,可能是字典或不可变字典
        prng_seed: jax.Array,  # 随机种子,用于随机数生成
        num_inference_steps: int = 50,  # 默认推理步骤的数量为 50
        guidance_scale: Union[float, jnp.ndarray] = 7.5,  # 默认指导比例为 7.5
        latents: jnp.ndarray = None,  # 潜在变量,默认值为 None
        neg_prompt_ids: jnp.ndarray = None,  # 负提示 ID,默认值为 None
        controlnet_conditioning_scale: Union[float, jnp.ndarray] = 1.0,  # 默认控制网络的条件缩放比例为 1.0
        return_dict: bool = True,  # 默认返回字典格式
        jit: bool = False,  # 默认不启用 JIT 编译
# 静态参数为 pipe 和 num_inference_steps,任何更改都会触发重新编译。
# 非静态参数是(分片)输入张量,这些张量在它们的第一维上被映射(因此为 `0`)。
@partial(
    jax.pmap,  # 使用 JAX 的 pmap 并行映射功能
    in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0),  # 指定输入张量的轴
    static_broadcasted_argnums=(0, 5),  # 指定静态广播参数的索引
)
def _p_generate(  # 定义生成函数
    pipe,  # 生成管道对象
    prompt_ids,  # 提示 ID
    image,  # 输入图像
    params,  # 生成参数
    prng_seed,  # 随机数生成种子
    num_inference_steps,  # 推理步骤数
    guidance_scale,  # 指导尺度
    latents,  # 潜在变量
    neg_prompt_ids,  # 负提示 ID
    controlnet_conditioning_scale,  # 控制网条件尺度
):
    return pipe._generate(  # 调用生成管道的生成方法
        prompt_ids,  # 提示 ID
        image,  # 输入图像
        params,  # 生成参数
        prng_seed,  # 随机数生成种子
        num_inference_steps,  # 推理步骤数
        guidance_scale,  # 指导尺度
        latents,  # 潜在变量
        neg_prompt_ids,  # 负提示 ID
        controlnet_conditioning_scale,  # 控制网条件尺度
    )


@partial(jax.pmap, static_broadcasted_argnums=(0,))  # 使用 JAX 的 pmap,并指定静态广播参数
def _p_get_has_nsfw_concepts(pipe, features, params):  # 定义检查是否有 NSFW 概念的函数
    return pipe._get_has_nsfw_concepts(features, params)  # 调用管道的相关方法


def unshard(x: jnp.ndarray):  # 定义反分片函数,接受一个张量
    # einops.rearrange(x, 'd b ... -> (d b) ...')  # 注释掉的排列操作
    num_devices, batch_size = x.shape[:2]  # 获取设备数量和批量大小
    rest = x.shape[2:]  # 获取其余维度
    return x.reshape(num_devices * batch_size, *rest)  # 重新调整形状以合并设备和批量维度


def preprocess(image, dtype):  # 定义图像预处理函数
    image = image.convert("RGB")  # 将图像转换为 RGB 模式
    w, h = image.size  # 获取图像的宽和高
    w, h = (x - x % 64 for x in (w, h))  # 将宽高调整为64的整数倍
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])  # 调整图像大小,使用 Lanczos 插值法
    image = jnp.array(image).astype(dtype) / 255.0  # 转换为 NumPy 数组并归一化到 [0, 1]
    image = image[None].transpose(0, 3, 1, 2)  # 添加新维度并调整通道顺序
    return image  # 返回处理后的图像

.\diffusers\pipelines\controlnet\__init__.py

# 导入类型检查工具
from typing import TYPE_CHECKING

# 从 utils 模块导入必要的工具和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 导入慢导入标志
    OptionalDependencyNotAvailable,  # 导入可选依赖不可用异常
    _LazyModule,  # 导入延迟模块工具
    get_objects_from_module,  # 导入从模块获取对象的函数
    is_flax_available,  # 导入检查 Flax 可用性的函数
    is_torch_available,  # 导入检查 Torch 可用性的函数
    is_transformers_available,  # 导入检查 Transformers 可用性的函数
)

# 初始化一个空字典用于存储虚拟对象
_dummy_objects = {}
# 初始化一个空字典用于存储导入结构
_import_structure = {}

try:
    # 检查 Transformers 和 Torch 是否可用
    if not (is_transformers_available() and is_torch_available()):
        # 如果不可用,抛出异常
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 导入虚拟对象
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新虚拟对象字典
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    # 如果依赖可用,更新导入结构字典
    _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
    _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
    _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
    _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
    _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
    _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
    _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
    _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]

try:
    # 检查 Transformers 和 Flax 是否可用
    if not (is_transformers_available() and is_flax_available()):
        # 如果不可用,抛出异常
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 导入虚拟 Flax 和 Transformers 对象
    from ...utils import dummy_flax_and_transformers_objects  # noqa F403

    # 更新虚拟对象字典
    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    # 如果依赖可用,更新导入结构字典
    _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]

# 如果类型检查或慢导入标志被设置
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        # 检查 Transformers 和 Torch 是否可用
        if not (is_transformers_available() and is_torch_available()):
            # 如果不可用,抛出异常
            raise OptionalDependencyNotAvailable()

    # 捕获可选依赖不可用的异常
    except OptionalDependencyNotAvailable:
        # 导入虚拟的 Torch 和 Transformers 对象
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 如果依赖可用,导入相应模块
        from .multicontrolnet import MultiControlNetModel
        from .pipeline_controlnet import StableDiffusionControlNetPipeline
        from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
        from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
        from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
        from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
        from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
        from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline

    try:
        # 检查 Transformers 和 Flax 是否可用
        if not (is_transformers_available() and is_flax_available()):
            # 如果不可用,抛出异常
            raise OptionalDependencyNotAvailable()
    # 捕获可选依赖项不可用的异常
        except OptionalDependencyNotAvailable:
            # 从 dummy 模块导入所有内容,忽略 F403 警告
            from ...utils.dummy_flax_and_transformers_objects import *  # noqa F403
        else:
            # 从 pipeline_flax_controlnet 模块导入 FlaxStableDiffusionControlNetPipeline
            from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
# 如果之前的条件不满足,执行以下代码
else:
    # 导入 sys 模块,用于访问和操作 Python 解释器的运行时环境
    import sys

    # 将当前模块替换为一个延迟加载的模块对象
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名称
        globals()["__file__"],  # 当前文件的路径
        _import_structure,  # 模块的导入结构
        module_spec=__spec__,  # 模块的规格对象
    )
    # 遍历假对象字典,将每个对象属性设置到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)  # 设置模块属性

.\diffusers\pipelines\controlnet_hunyuandit\pipeline_hunyuandit_controlnet.py

# 版权声明,指明文件的版权归 HunyuanDiT 和 HuggingFace 团队所有
# 本文件在 Apache 2.0 许可证下授权使用
# 除非遵循许可证,否则不能使用此文件
# 许可证的副本可以在以下网址获取
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律规定或书面协议另有约定,否则软件在"按现状"基础上提供,不附带任何明示或暗示的保证
# 查看许可证以了解特定语言的权限和限制

# 导入用于获取函数信息的 inspect 模块
import inspect
# 导入类型提示所需的类型
from typing import Callable, Dict, List, Optional, Tuple, Union

# 导入 numpy 库
import numpy as np
# 导入 PyTorch 库
import torch
# 从 transformers 库导入相关模型和分词器
from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel

# 从 diffusers 库导入 StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

# 导入多管道回调类
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
# 导入图像处理类
from ...image_processor import PipelineImageInput, VaeImageProcessor
# 导入自动编码器和模型
from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel
# 导入 2D 旋转位置嵌入函数
from ...models.embeddings import get_2d_rotary_pos_embed
# 导入稳定扩散安全检查器
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# 导入扩散调度器
from ...schedulers import DDPMScheduler
# 导入实用工具函数
from ...utils import (
    is_torch_xla_available,  # 检查是否可用 XLA
    logging,  # 导入日志记录模块
    replace_example_docstring,  # 替换示例文档字符串的工具
)
# 导入 PyTorch 相关的随机张量函数
from ...utils.torch_utils import randn_tensor
# 导入扩散管道工具类
from ..pipeline_utils import DiffusionPipeline

# 检查是否可用 XLA,并根据结果导入相应模块
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm  # 导入 XLA 核心模型

    XLA_AVAILABLE = True  # 设置 XLA 可用标志为 True
else:
    XLA_AVAILABLE = False  # 设置 XLA 可用标志为 False

# 创建一个日志记录器实例,记录当前模块的日志
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 示例文档字符串,用于说明使用方法
EXAMPLE_DOC_STRING = """
# 示例代码展示如何使用 HunyuanDiT 进行图像生成
    Examples:
        ```py
        # 从 diffusers 库导入所需的模型和管道
        from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiTControlNetPipeline
        # 导入 PyTorch 库
        import torch

        # 从预训练模型加载 HunyuanDiT2DControlNetModel,并指定数据类型为 float16
        controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
            "Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny", torch_dtype=torch.float16
        )

        # 从预训练模型加载 HunyuanDiTControlNetPipeline,传入 controlnet 和数据类型
        pipe = HunyuanDiTControlNetPipeline.from_pretrained(
            "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
        )
        # 将管道移动到 CUDA 设备以加速处理
        pipe.to("cuda")

        # 从 diffusers.utils 导入加载图像的工具
        from diffusers.utils import load_image

        # 从指定 URL 加载条件图像
        cond_image = load_image(
            "https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Canny/resolve/main/canny.jpg?download=true"
        )

        ## HunyuanDiT 支持英语和中文提示,因此也可以使用英文提示
        # 定义图像生成的提示内容,描述夜晚的场景
        prompt = "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围"
        # prompt="At night, an ancient Chinese-style lion statue stands in front of the hotel, its eyes gleaming as if guarding the building. The background is the hotel entrance at night, with a close-up, eye-level, and centered composition. This photo presents a realistic photographic style, embodies Chinese sculpture culture, and reveals a mysterious atmosphere."
        # 使用提示、图像尺寸、条件图像和推理步骤生成图像,并获取生成的第一张图像
        image = pipe(
            prompt,
            height=1024,
            width=1024,
            control_image=cond_image,
            num_inference_steps=50,
        ).images[0]
        ```  

"""

文档字符串,通常用于描述模块或类的功能

"""

定义一个标准宽高比的 NumPy 数组

STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)

定义一个标准尺寸的列表,每个比例对应不同的宽高组合

STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]

根据标准尺寸计算每个形状的面积,并将结果存储在 NumPy 数组中

STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]

定义一个支持的尺寸列表,包含不同的宽高组合

SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 768), # 16:9
(768, 1280), # 9:16
]

定义一个函数,用于将目标宽高映射到标准形状

def map_to_standard_shapes(target_width, target_height):
# 计算目标宽高比
target_ratio = target_width / target_height
# 找到与目标宽高比最接近的标准宽高比的索引
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
# 找到与目标面积最接近的标准形状的索引
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
# 获取对应的标准宽和高
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
# 返回标准宽和高
return width, height

定义一个函数,用于计算源图像的缩放裁剪区域以适应目标大小

def get_resize_crop_region_for_grid(src, tgt_size):
# 获取目标尺寸的高度和宽度
th = tw = tgt_size
# 获取源图像的高度和宽度
h, w = src

# 计算源图像的宽高比
r = h / w

# 根据宽高比决定缩放方式
# 如果高度大于宽度
if r > 1:
    # 将目标高度作为缩放高度
    resize_height = th
    # 根据高度缩放计算对应的宽度
    resize_width = int(round(th / h * w))
else:
    # 否则,将目标宽度作为缩放宽度
    resize_width = tw
    # 根据宽度缩放计算对应的高度
    resize_height = int(round(tw / w * h))

# 计算裁剪区域的顶部和左边位置
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))

# 返回裁剪区域的起始和结束坐标
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg 复制的函数

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
根据 guidance_rescalenoise_cfg 进行重新缩放。基于论文Common Diffusion Noise Schedules and
Sample Steps are Flawed
中的发现。见第3.4节
"""
# 计算噪声预测文本的标准差
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
# 计算噪声配置的标准差
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# 重新缩放来自引导的结果(修复过度曝光问题)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# 按照引导缩放因子与原始引导结果进行混合,以避免生成“单调”的图像
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
# 返回重新缩放后的噪声配置
return noise_cfg

定义 HunyuanDiT 控制网络管道类,继承自 DiffusionPipeline

class HunyuanDiTControlNetPipeline(DiffusionPipeline):
r"""
使用 HunyuanDiT 进行英语/中文到图像生成的管道。

该模型继承自 [`DiffusionPipeline`]. 请查看超类文档以获取库为所有管道实现的通用方法
(例如下载或保存,在特定设备上运行等)。

HunyuanDiT 使用两个文本编码器:[mT5](https://huggingface.co/google/mt5-base) 和 [双语 CLIP](自行微调)
"""
# 参数说明
Args:
    vae ([`AutoencoderKL`]):  # 变分自编码器模型,用于将图像编码和解码为潜在表示,这里使用'sdxl-vae-fp16-fix'
        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. We use
        `sdxl-vae-fp16-fix`.
    text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):  # 冻结的文本编码器,使用CLIP模型
        Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 
        HunyuanDiT uses a fine-tuned [bilingual CLIP].
    tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):  # 文本标记化器,可以是BertTokenizer或CLIPTokenizer
        A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
    transformer ([`HunyuanDiT2DModel`]):  # HunyuanDiT模型,由腾讯Hunyuan设计
        The HunyuanDiT model designed by Tencent Hunyuan.
    text_encoder_2 (`T5EncoderModel`):  # mT5嵌入模型,特别是't5-v1_1-xxl'
        The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
    tokenizer_2 (`MT5Tokenizer`):  # mT5嵌入模型的标记化器
        The tokenizer for the mT5 embedder.
    scheduler ([`DDPMScheduler`]):  # 调度器,用于与HunyuanDiT结合,去噪编码的图像潜在表示
        A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
    controlnet ([`HunyuanDiT2DControlNetModel`] or `List[HunyuanDiT2DControlNetModel]` or [`HunyuanDiT2DControlNetModel`]):  # 提供额外的条件信息以辅助去噪过程
        Provides additional conditioning to the `unet` during the denoising process. If you set multiple
        ControlNets as a list, the outputs from each ControlNet are added together to create one combined
        additional conditioning.
"""

# 定义模型在CPU上卸载的顺序
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
# 可选组件列表,可能会在初始化中使用
_optional_components = [
    "safety_checker",  # 安全检查器
    "feature_extractor",  # 特征提取器
    "text_encoder_2",  # 第二个文本编码器
    "tokenizer_2",  # 第二个标记化器
    "text_encoder",  # 第一个文本编码器
    "tokenizer",  # 第一个标记化器
]
# 从CPU卸载中排除的组件
_exclude_from_cpu_offload = ["safety_checker"]  # 不允许卸载安全检查器
# 回调张量输入的列表,用于传递给模型
_callback_tensor_inputs = [
    "latents",  # 潜在变量
    "prompt_embeds",  # 提示的嵌入表示
    "negative_prompt_embeds",  # 负提示的嵌入表示
    "prompt_embeds_2",  # 第二个提示的嵌入表示
    "negative_prompt_embeds_2",  # 第二个负提示的嵌入表示
]

# 初始化方法定义,接收多个参数以构造模型
def __init__(
    self,
    vae: AutoencoderKL,  # 变分自编码器模型
    text_encoder: BertModel,  # 文本编码器
    tokenizer: BertTokenizer,  # 文本标记化器
    transformer: HunyuanDiT2DModel,  # HunyuanDiT模型
    scheduler: DDPMScheduler,  # 调度器
    safety_checker: StableDiffusionSafetyChecker,  # 安全检查器
    feature_extractor: CLIPImageProcessor,  # 特征提取器
    controlnet: Union[  # 控制网络,可以是单个或多个模型
        HunyuanDiT2DControlNetModel,
        List[HunyuanDiT2DControlNetModel],
        Tuple[HunyuanDiT2DControlNetModel],
        HunyuanDiT2DMultiControlNetModel,
    ],
    text_encoder_2=T5EncoderModel,  # 第二个文本编码器,默认使用T5模型
    tokenizer_2=MT5Tokenizer,  # 第二个标记化器,默认使用MT5标记化器
    requires_safety_checker: bool = True,  # 是否需要安全检查器,默认是True
# 初始化父类
):
    super().__init__()

    # 注册多个模块,提供必要的组件以供使用
    self.register_modules(
        vae=vae,  # 注册变分自编码器
        text_encoder=text_encoder,  # 注册文本编码器
        tokenizer=tokenizer,  # 注册分词器
        tokenizer_2=tokenizer_2,  # 注册第二个分词器
        transformer=transformer,  # 注册变换器
        scheduler=scheduler,  # 注册调度器
        safety_checker=safety_checker,  # 注册安全检查器
        feature_extractor=feature_extractor,  # 注册特征提取器
        text_encoder_2=text_encoder_2,  # 注册第二个文本编码器
        controlnet=controlnet,  # 注册控制网络
    )

    # 检查安全检查器是否为 None 并且需要使用安全检查器
    if safety_checker is None and requires_safety_checker:
        # 记录警告信息,提醒用户禁用安全检查器的后果
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    # 检查安全检查器不为 None 且特征提取器为 None
    if safety_checker is not None and feature_extractor is None:
        # 抛出错误,提醒用户必须定义特征提取器
        raise ValueError(
            "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    # 计算 VAE 的缩放因子,如果存在 VAE 配置则使用其通道数量,否则默认为 8
    self.vae_scale_factor = (
        2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
    )
    # 初始化图像处理器,传入 VAE 缩放因子
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # 注册到配置中,指明是否需要安全检查器
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    # 设置默认样本大小,根据变换器配置或默认为 128
    self.default_sample_size = (
        self.transformer.config.sample_size
        if hasattr(self, "transformer") and self.transformer is not None
        else 128
    )

# 从其他模块复制的方法,用于编码提示
def encode_prompt(
    self,
    prompt: str,  # 输入的提示文本
    device: torch.device = None,  # 设备参数,指定在哪个设备上处理
    dtype: torch.dtype = None,  # 数据类型参数,指定张量的数据类型
    num_images_per_prompt: int = 1,  # 每个提示生成的图像数量
    do_classifier_free_guidance: bool = True,  # 是否执行无分类器的引导
    negative_prompt: Optional[str] = None,  # 可选的负面提示文本
    prompt_embeds: Optional[torch.Tensor] = None,  # 可选的提示嵌入张量
    negative_prompt_embeds: Optional[torch.Tensor] = None,  # 可选的负面提示嵌入张量
    prompt_attention_mask: Optional[torch.Tensor] = None,  # 可选的提示注意力掩码
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,  # 可选的负面提示注意力掩码
    max_sequence_length: Optional[int] = None,  # 可选的最大序列长度
    text_encoder_index: int = 0,  # 文本编码器索引,默认值为 0
# 从其他模块复制的方法,用于运行安全检查器
# 定义运行安全检查器的方法,接收图像、设备和数据类型作为参数
def run_safety_checker(self, image, device, dtype):
    # 如果安全检查器未定义,设置无敏感内容标志为 None
    if self.safety_checker is None:
        has_nsfw_concept = None
    else:
        # 如果输入图像是张量,则后处理为 PIL 格式
        if torch.is_tensor(image):
            feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
        else:
            # 如果输入图像不是张量,则将其转换为 PIL 格式
            feature_extractor_input = self.image_processor.numpy_to_pil(image)
        # 使用特征提取器处理输入图像并将其转换为指定设备上的张量
        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
        # 调用安全检查器,返回处理后的图像和无敏感内容标志
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
        )
    # 返回处理后的图像和无敏感内容标志
    return image, has_nsfw_concept

    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 复制
    def prepare_extra_step_kwargs(self, generator, eta):
        # 为调度器步骤准备额外参数,因为不是所有调度器的参数签名相同
        # eta(η)仅在 DDIMScheduler 中使用,其他调度器将忽略
        # eta 在 DDIM 论文中的对应关系:https://arxiv.org/abs/2010.02502
        # 其值应在 [0, 1] 之间

        # 检查调度器的步骤是否接受 eta 参数
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        # 如果接受 eta,则将其添加到额外参数字典中
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # 检查调度器的步骤是否接受 generator 参数
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        # 如果接受 generator,则将其添加到额外参数字典中
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        # 返回额外参数字典
        return extra_step_kwargs

    # 从 diffusers.pipelines.hunyuandit.pipeline_hunyuandit.HunyuanDiTPipeline.check_inputs 复制
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        prompt_embeds_2=None,
        negative_prompt_embeds_2=None,
        prompt_attention_mask_2=None,
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    # 从 diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 复制
# 准备潜在变量的函数,接收多个参数以配置潜在变量的形状和生成方式
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        # 根据批大小、通道数、高度和宽度计算潜在变量的形状
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        # 检查生成器列表的长度是否与批大小匹配
        if isinstance(generator, list) and len(generator) != batch_size:
            # 如果不匹配,抛出一个值错误
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # 如果潜在变量为 None,则生成随机潜在变量
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # 如果已提供潜在变量,将其移动到指定设备
            latents = latents.to(device)

        # 按调度器要求的标准差缩放初始噪声
        latents = latents * self.scheduler.init_noise_sigma
        # 返回处理后的潜在变量
        return latents

    # 准备图像的函数,从外部调用
    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        # 检查图像是否为张量,如果是则不处理
        if isinstance(image, torch.Tensor):
            pass
        else:
            # 否则对图像进行预处理,调整为指定的高度和宽度
            image = self.image_processor.preprocess(image, height=height, width=width)

        # 获取图像的批大小
        image_batch_size = image.shape[0]

        # 如果图像批大小为1,则重复次数为批大小
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # 否则图像批大小与提示批大小相同
            repeat_by = num_images_per_prompt

        # 沿着维度0重复图像
        image = image.repeat_interleave(repeat_by, dim=0)

        # 将图像移动到指定设备,并转换为指定数据类型
        image = image.to(device=device, dtype=dtype)

        # 如果启用了无分类器自由引导,并且未启用猜测模式,则将图像复制两次
        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)

        # 返回处理后的图像
        return image

    # 获取指导比例的属性
    @property
    def guidance_scale(self):
        # 返回当前的指导比例
        return self._guidance_scale

    # 获取指导重标定的属性
    @property
    def guidance_rescale(self):
        # 返回当前的指导重标定值
        return self._guidance_rescale

    # 此属性定义了类似于论文中指导权重的定义
    @property
    def do_classifier_free_guidance(self):
        # 如果指导比例大于1,则启用无分类器自由引导
        return self._guidance_scale > 1

    # 获取时间步数的属性
    @property
    def num_timesteps(self):
        # 返回当前的时间步数
        return self._num_timesteps

    # 获取中断状态的属性
    @property
    def interrupt(self):
        # 返回当前中断状态
        return self._interrupt

    # 在不计算梯度的情况下运行,替换示例文档字符串
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
# 定义一个可调用的类方法
    def __call__(
        # 提示内容,可以是字符串或字符串列表
        self,
        prompt: Union[str, List[str]] = None,
        # 输出图像的高度
        height: Optional[int] = None,
        # 输出图像的宽度
        width: Optional[int] = None,
        # 推理步骤的数量,默认为 50
        num_inference_steps: Optional[int] = 50,
        # 引导比例,默认为 5.0
        guidance_scale: Optional[float] = 5.0,
        # 控制图像输入,默认为 None
        control_image: PipelineImageInput = None,
        # 控制网条件比例,可以是单一值或值列表,默认为 1.0
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        # 负提示内容,可以是字符串或字符串列表,默认为 None
        negative_prompt: Optional[Union[str, List[str]]] = None,
        # 每个提示生成的图像数量,默认为 1
        num_images_per_prompt: Optional[int] = 1,
        # 用于生成的随机性,默认为 0.0
        eta: Optional[float] = 0.0,
        # 随机数生成器,可以是单个或列表,默认为 None
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        # 潜在变量,默认为 None
        latents: Optional[torch.Tensor] = None,
        # 提示的嵌入,默认为 None
        prompt_embeds: Optional[torch.Tensor] = None,
        # 第二组提示的嵌入,默认为 None
        prompt_embeds_2: Optional[torch.Tensor] = None,
        # 负提示的嵌入,默认为 None
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        # 第二组负提示的嵌入,默认为 None
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        # 提示的注意力掩码,默认为 None
        prompt_attention_mask: Optional[torch.Tensor] = None,
        # 第二组提示的注意力掩码,默认为 None
        prompt_attention_mask_2: Optional[torch.Tensor] = None,
        # 负提示的注意力掩码,默认为 None
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        # 第二组负提示的注意力掩码,默认为 None
        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
        # 输出类型,默认为 "pil"
        output_type: Optional[str] = "pil",
        # 是否返回字典格式,默认为 True
        return_dict: bool = True,
        # 在步骤结束时的回调函数
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        # 回调时的张量输入列表,默认为 ["latents"]
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        # 引导重标定,默认为 0.0
        guidance_rescale: float = 0.0,
        # 原始图像大小,默认为 (1024, 1024)
        original_size: Optional[Tuple[int, int]] = (1024, 1024),
        # 目标图像大小,默认为 None
        target_size: Optional[Tuple[int, int]] = None,
        # 裁剪坐标,默认为 (0, 0)
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        # 是否使用分辨率分箱,默认为 True
        use_resolution_binning: bool = True,

# `.\diffusers\pipelines\controlnet_hunyuandit\__init__.py`

```py
# 从 typing 模块导入 TYPE_CHECKING,用于静态类型检查
from typing import TYPE_CHECKING

# 从父模块的 utils 导入多个工具函数和常量
from ...utils import (
    DIFFUSERS_SLOW_IMPORT,  # 导入慢导入的标志
    OptionalDependencyNotAvailable,  # 导入可选依赖不可用的异常
    _LazyModule,  # 导入延迟加载模块的类
    get_objects_from_module,  # 导入从模块获取对象的函数
    is_torch_available,  # 导入检查 PyTorch 是否可用的函数
    is_transformers_available,  # 导入检查 Transformers 是否可用的函数
)

# 创建一个空字典,用于存储假对象
_dummy_objects = {}
# 创建一个空字典,用于存储导入结构
_import_structure = {}

# 尝试检查是否可用的依赖
try:
    # 如果 Transformers 和 Torch 不可用,抛出异常
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
# 捕获可选依赖不可用的异常
except OptionalDependencyNotAvailable:
    # 从 utils 导入假对象(dummy objects),避免直接依赖
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    # 更新 _dummy_objects 字典,包含假对象
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
# 如果依赖可用,更新导入结构
else:
    # 将 HunyuanDiTControlNetPipeline 加入导入结构
    _import_structure["pipeline_hunyuandit_controlnet"] = ["HunyuanDiTControlNetPipeline"]

# 检查类型是否在检查模式或是否需要慢导入
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # 尝试检查是否可用的依赖
    try:
        # 如果 Transformers 和 Torch 不可用,抛出异常
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    # 捕获可选依赖不可用的异常
    except OptionalDependencyNotAvailable:
        # 从 utils 导入所有假对象,避免直接依赖
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        # 导入真实的 HunyuanDiTControlNetPipeline 类
        from .pipeline_hunyuandit_controlnet import HunyuanDiTControlNetPipeline

# 如果不在类型检查或不需要慢导入
else:
    # 导入 sys 模块
    import sys

    # 将当前模块替换为延迟加载模块,传递必要的信息
    sys.modules[__name__] = _LazyModule(
        __name__,  # 模块名称
        globals()["__file__"],  # 模块文件路径
        _import_structure,  # 导入结构
        module_spec=__spec__,  # 模块规格
    )

    # 遍历假对象并设置到当前模块
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
posted @ 2024-10-22 12:35  绝不原创的飞龙  阅读(10)  评论(0编辑  收藏  举报