diffusers-源码解析-七-

diffusers 源码解析(七)

.\diffusers\models\autoencoders\vq_model.py

# 版权声明,指明版权归 HuggingFace 团队所有
# 
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 你不得在不遵守许可证的情况下使用此文件。
# 可以在以下网址获取许可证副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律要求或书面同意,否则软件在“按原样”基础上分发,
# 不提供任何形式的担保或条件,无论是明示或暗示的。
# 有关许可证的具体条款和条件,请参阅许可证。
# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入可选类型、元组和联合类型
from typing import Optional, Tuple, Union

# 导入 PyTorch 库
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn

# 从配置工具中导入 ConfigMixin 和注册配置的函数
from ...configuration_utils import ConfigMixin, register_to_config
# 从工具模块导入 BaseOutput 类
from ...utils import BaseOutput
# 从加速工具中导入应用前向钩子的函数
from ...utils.accelerate_utils import apply_forward_hook
# 从自动编码器模块导入解码器、解码器输出、编码器和向量量化器
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
# 从建模工具中导入 ModelMixin 类
from ..modeling_utils import ModelMixin


# 定义一个数据类,用于表示 VQModel 编码方法的输出
@dataclass
class VQEncoderOutput(BaseOutput):
    """
    VQModel 编码方法的输出。

    参数:
        latents (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`):
            模型最后一层的编码输出样本。
    """

    # 定义一个属性 latents,类型为 torch.Tensor
    latents: torch.Tensor


# 定义 VQModel 类,继承自 ModelMixin 和 ConfigMixin
class VQModel(ModelMixin, ConfigMixin):
    r"""
    用于解码潜在表示的 VQ-VAE 模型。

    该模型继承自 [`ModelMixin`]。请查看超类文档,以了解其为所有模型实现的通用方法
    (例如下载或保存)。
    # 函数参数说明部分
    Parameters:
        # 输入图像的通道数,默认为3
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        # 输出图像的通道数,默认为3
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        # 下采样块类型的元组,默认为包含一个类型的元组
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        # 上采样块类型的元组,默认为包含一个类型的元组
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        # 块输出通道数的元组,默认为包含一个值的元组
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        # 每个块的层数,默认为1
        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
        # 激活函数类型,默认为"silu"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        # 潜在空间的通道数,默认为3
        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
        # 输入样本的大小,默认为32
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        # VQ-VAE中的代码本向量数量,默认为256
        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
        # 归一化层的组数,默认为32
        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
        # VQ-VAE中代码本向量的隐藏维度,可选
        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
        # 缩放因子,默认为0.18215,主要用于训练时的标准化
        scaling_factor (`float`, *optional*, defaults to `0.18215`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        # 归一化层的类型,默认为"group",可选为"group"或"spatial"
        norm_type (`str`, *optional*, defaults to `"group"`):
            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
    """

    # 注册配置的构造函数
    @register_to_config
    def __init__(
        # 输入通道参数,默认为3
        self,
        in_channels: int = 3,
        # 输出通道参数,默认为3
        out_channels: int = 3,
        # 下采样块类型,默认为("DownEncoderBlock2D",)
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        # 上采样块类型,默认为("UpDecoderBlock2D",)
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        # 块输出通道参数,默认为(64,)
        block_out_channels: Tuple[int, ...] = (64,),
        # 每块层数参数,默认为1
        layers_per_block: int = 1,
        # 激活函数参数,默认为"silu"
        act_fn: str = "silu",
        # 潜在通道数参数,默认为3
        latent_channels: int = 3,
        # 样本大小参数,默认为32
        sample_size: int = 32,
        # VQ-VAE代码本向量数量,默认为256
        num_vq_embeddings: int = 256,
        # 归一化层组数参数,默认为32
        norm_num_groups: int = 32,
        # VQ-VAE代码本向量隐藏维度,默认为None
        vq_embed_dim: Optional[int] = None,
        # 缩放因子参数,默认为0.18215
        scaling_factor: float = 0.18215,
        # 归一化层类型参数,默认为"group"
        norm_type: str = "group",  # group, spatial
        # 是否在中间块添加注意力,默认为True
        mid_block_add_attention=True,
        # 是否从代码本查找,默认为False
        lookup_from_codebook=False,
        # 是否强制上溯,默认为False
        force_upcast=False,
    # 初始化方法,调用父类构造函数
        ):
            super().__init__()
    
            # 将初始化参数传递给编码器
            self.encoder = Encoder(
                in_channels=in_channels,  # 输入通道数
                out_channels=latent_channels,  # 潜在通道数
                down_block_types=down_block_types,  # 下采样块类型
                block_out_channels=block_out_channels,  # 块输出通道数
                layers_per_block=layers_per_block,  # 每个块的层数
                act_fn=act_fn,  # 激活函数
                norm_num_groups=norm_num_groups,  # 归一化的组数
                double_z=False,  # 是否使用双潜变量
                mid_block_add_attention=mid_block_add_attention,  # 中间块是否添加注意力机制
            )
    
            # 如果未提供,使用潜在通道数作为嵌入维度
            vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
    
            # 创建量化卷积层,将潜在通道数映射到嵌入维度
            self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
            # 初始化向量量化器
            self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
            # 创建后量化卷积层,将嵌入维度映射回潜在通道数
            self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
    
            # 将初始化参数传递给解码器
            self.decoder = Decoder(
                in_channels=latent_channels,  # 潜在通道数
                out_channels=out_channels,  # 输出通道数
                up_block_types=up_block_types,  # 上采样块类型
                block_out_channels=block_out_channels,  # 块输出通道数
                layers_per_block=layers_per_block,  # 每个块的层数
                act_fn=act_fn,  # 激活函数
                norm_num_groups=norm_num_groups,  # 归一化的组数
                norm_type=norm_type,  # 归一化类型
                mid_block_add_attention=mid_block_add_attention,  # 中间块是否添加注意力机制
            )
    
        # 应用前向钩子,定义编码方法
        @apply_forward_hook
        def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
            # 将输入 x 传递给编码器以获取潜在表示
            h = self.encoder(x)
            # 通过量化卷积层处理潜在表示
            h = self.quant_conv(h)
    
            # 如果不需要返回字典,返回潜在表示
            if not return_dict:
                return (h,)
    
            # 返回包含潜在表示的自定义输出对象
            return VQEncoderOutput(latents=h)
    
        # 应用前向钩子,定义解码方法
        @apply_forward_hook
        def decode(
            self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
        ) -> Union[DecoderOutput, torch.Tensor]:
            # 如果不强制不量化,则通过量化层处理潜在表示
            if not force_not_quantize:
                quant, commit_loss, _ = self.quantize(h)
            # 如果从代码本中查找,则获取代码本条目
            elif self.config.lookup_from_codebook:
                quant = self.quantize.get_codebook_entry(h, shape)
                # 初始化承诺损失为零
                commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
            else:
                # 否则直接使用输入
                quant = h
                # 初始化承诺损失为零
                commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
            # 通过后量化卷积层处理量化结果
            quant2 = self.post_quant_conv(quant)
            # 将量化结果传递给解码器以获取输出
            dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
    
            # 如果不需要返回字典,返回解码结果和承诺损失
            if not return_dict:
                return dec, commit_loss
    
            # 返回自定义输出对象,包括解码结果和承诺损失
            return DecoderOutput(sample=dec, commit_loss=commit_loss)
    
        # 定义前向传播方法
        def forward(
            self, sample: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
        r"""  # 文档字符串,描述该方法的功能和参数
        The [`VQModel`] forward method.  # 指明这是 VQModel 类的前向传播方法

        Args:  # 参数说明部分
            sample (`torch.Tensor`): Input sample.  # 输入样本,类型为 torch.Tensor
            return_dict (`bool`, *optional*, defaults to `True`):  # 可选参数,指示是否返回字典
                Whether or not to return a [`models.autoencoders.vq_model.VQEncoderOutput`] instead of a plain tuple.  # 说明返回值的类型

        Returns:  # 返回值说明部分
            [`~models.autoencoders.vq_model.VQEncoderOutput`] or `tuple`:  # 返回值可以是 VQEncoderOutput 对象或元组
                If return_dict is True, a [`~models.autoencoders.vq_model.VQEncoderOutput`] is returned, otherwise a  # 如果 return_dict 为 True,则返回 VQEncoderOutput
                plain `tuple` is returned.  # 否则返回一个普通的元组
        """

        h = self.encode(sample).latents  # 调用 encode 方法对输入样本进行编码,并获取其潜在表示
        dec = self.decode(h)  # 调用 decode 方法对潜在表示进行解码,获取解码结果

        if not return_dict:  # 如果 return_dict 为 False
            return dec.sample, dec.commit_loss  # 返回解码结果的样本和承诺损失
        return dec  # 否则返回解码结果对象

.\diffusers\models\autoencoders\__init__.py

# 从当前包中导入 AsymmetricAutoencoderKL 类
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
# 从当前包中导入 AutoencoderKL 类
from .autoencoder_kl import AutoencoderKL
# 从当前包中导入 AutoencoderKLCogVideoX 类
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
# 从当前包中导入 AutoencoderKLTemporalDecoder 类
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
# 从当前包中导入 AutoencoderOobleck 类
from .autoencoder_oobleck import AutoencoderOobleck
# 从当前包中导入 AutoencoderTiny 类
from .autoencoder_tiny import AutoencoderTiny
# 从当前包中导入 ConsistencyDecoderVAE 类
from .consistency_decoder_vae import ConsistencyDecoderVAE
# 从当前包中导入 VQModel 类
from .vq_model import VQModel

.\diffusers\models\controlnet.py

# 版权声明,标明版权所有者及其权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 2.0 许可证进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件只能在符合许可证的情况下使用
# you may not use this file except in compliance with the License.
# 可以在以下地址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则软件按“原样”分发,不提供任何形式的担保或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言所规定的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass  # 从dataclasses模块导入dataclass装饰器
from typing import Any, Dict, List, Optional, Tuple, Union  # 导入类型提示

import torch  # 导入torch库,用于张量操作
from torch import nn  # 从torch库导入神经网络模块
from torch.nn import functional as F  # 从torch.nn导入功能性操作模块

# 导入配置相关的混合类和注册功能
from ..configuration_utils import ConfigMixin, register_to_config
# 导入原始模型的加载混合类
from ..loaders.single_file_model import FromOriginalModelMixin
# 导入基础输出类和日志记录功能
from ..utils import BaseOutput, logging
# 导入注意力处理器相关的组件
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入增加的KV注意力处理器
    CROSS_ATTENTION_PROCESSORS,  # 导入交叉注意力处理器
    AttentionProcessor,  # 导入注意力处理器基类
    AttnAddedKVProcessor,  # 导入增加KV的注意力处理器
    AttnProcessor,  # 导入基本的注意力处理器
)
# 导入嵌入相关的组件
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
# 导入模型相关的混合类
from .modeling_utils import ModelMixin
# 导入UNet的二维块相关组件
from .unets.unet_2d_blocks import (
    CrossAttnDownBlock2D,  # 导入二维交叉注意力下采样块
    DownBlock2D,  # 导入二维下采样块
    UNetMidBlock2D,  # 导入UNet的中间块
    UNetMidBlock2DCrossAttn,  # 导入具有交叉注意力的UNet中间块
    get_down_block,  # 导入获取下采样块的函数
)
# 导入UNet的条件模型
from .unets.unet_2d_condition import UNet2DConditionModel

# 获取logger实例,用于记录日志信息
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@dataclass  # 使用dataclass装饰器定义一个数据类
class ControlNetOutput(BaseOutput):
    """
    ControlNetModel的输出。

    参数:
        down_block_res_samples (`tuple[torch.Tensor]`):
            不同分辨率下的下采样激活元组,每个张量形状为`(batch_size, channel * resolution, height // resolution, width // resolution)`。
            输出可用于对原始UNet的下采样激活进行条件化。
        mid_down_block_re_sample (`torch.Tensor`):
            中间块(最低采样分辨率)的激活。每个张量形状为
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`。
            输出可用于对原始UNet的中间块激活进行条件化。
    """

    down_block_res_samples: Tuple[torch.Tensor]  # 下采样块的激活张量元组
    mid_block_res_sample: torch.Tensor  # 中间块的激活张量


class ControlNetConditioningEmbedding(nn.Module):
    """
    引用 https://arxiv.org/abs/2302.05543: “Stable Diffusion使用类似于VQ-GAN的预处理方法
    将整个512 × 512图像数据集转换为较小的64 × 64“潜在图像”,以实现稳定训练。
    这要求ControlNets将基于图像的条件转换为64 × 64特征空间,以匹配卷积大小。
    我们使用一个包含四个卷积层的小型网络E(·),卷积核为4 × 4,步幅为2 × 2。
    # 文档字符串,描述这个模块的功能,提到使用 ReLU 激活函数,通道数为 16, 32, 64, 128,采用高斯权重初始化,并与整个模型共同训练,以将图像空间条件编码为特征图
    """

    # 初始化函数,用于定义该类的基本属性
    def __init__(
        # 条件嵌入通道数
        conditioning_embedding_channels: int,
        # 条件通道数,默认为 3(即 RGB 图像)
        conditioning_channels: int = 3,
        # 输出通道的元组,定义卷积层的通道数
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 定义输入卷积层,接收条件通道并输出第一个块的通道数
        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        # 创建一个空的模块列表,用于存储后续的卷积块
        self.blocks = nn.ModuleList([])

        # 遍历 block_out_channels 列表,构建多个卷积块
        for i in range(len(block_out_channels) - 1):
            # 当前块的输入通道数
            channel_in = block_out_channels[i]
            # 下一个块的输出通道数
            channel_out = block_out_channels[i + 1]
            # 添加一个卷积层,保持输入通道数
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            # 添加另一个卷积层,改变输出通道数,同时步幅为 2,进行下采样
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        # 定义输出卷积层,将最后一个块的通道数映射到条件嵌入通道数,并使用零初始化
        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    # 前向传播函数,定义输入如何通过网络传递
    def forward(self, conditioning):
        # 通过输入卷积层处理条件输入,得到嵌入
        embedding = self.conv_in(conditioning)
        # 应用 SiLU 激活函数
        embedding = F.silu(embedding)

        # 遍历所有定义的卷积块,逐层处理嵌入
        for block in self.blocks:
            # 通过当前卷积块处理嵌入
            embedding = block(embedding)
            # 再次应用 SiLU 激活函数
            embedding = F.silu(embedding)

        # 通过输出卷积层处理嵌入
        embedding = self.conv_out(embedding)

        # 返回最终的嵌入结果
        return embedding
# 定义一个 ControlNet 模型类,继承自 ModelMixin, ConfigMixin, 和 FromOriginalModelMixin
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """
    A ControlNet model.
    """  # 文档字符串,描述该类是一个 ControlNet 模型

    _supports_gradient_checkpointing = True  # 设置支持梯度检查点的标志为真

    @register_to_config  # 注册到配置中
    def __init__(  # 初始化方法,构造 ControlNetModel 的实例
        self,  # 指向实例本身的引用
        in_channels: int = 4,  # 输入通道数,默认为 4
        conditioning_channels: int = 3,  # 条件通道数,默认为 3
        flip_sin_to_cos: bool = True,  # 是否将正弦转换为余弦,默认为真
        freq_shift: int = 0,  # 频率偏移量,默认为 0
        down_block_types: Tuple[str, ...] = (  # 下采样块类型的元组
            "CrossAttnDownBlock2D",  # 第一个下采样块类型
            "CrossAttnDownBlock2D",  # 第二个下采样块类型
            "CrossAttnDownBlock2D",  # 第三个下采样块类型
            "DownBlock2D",  # 第四个下采样块类型
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",  # 中间块类型,默认为 UNet 中间块
        only_cross_attention: Union[bool, Tuple[bool]] = False,  # 是否仅使用交叉注意力,默认为假
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),  # 每个块的输出通道数
        layers_per_block: int = 2,  # 每个块的层数,默认为 2
        downsample_padding: int = 1,  # 下采样的填充,默认为 1
        mid_block_scale_factor: float = 1,  # 中间块的缩放因子,默认为 1
        act_fn: str = "silu",  # 激活函数类型,默认为 "silu"
        norm_num_groups: Optional[int] = 32,  # 规范化的组数,默认为 32
        norm_eps: float = 1e-5,  # 规范化的 epsilon 值,默认为 1e-5
        cross_attention_dim: int = 1280,  # 交叉注意力的维度,默认为 1280
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,  # 每个块的变换层数,默认为 1
        encoder_hid_dim: Optional[int] = None,  # 编码器隐藏维度,可选
        encoder_hid_dim_type: Optional[str] = None,  # 编码器隐藏维度类型,可选
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,  # 注意力头的维度,默认为 8
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,  # 注意力头数量,可选
        use_linear_projection: bool = False,  # 是否使用线性投影,默认为假
        class_embed_type: Optional[str] = None,  # 类嵌入类型,可选
        addition_embed_type: Optional[str] = None,  # 附加嵌入类型,可选
        addition_time_embed_dim: Optional[int] = None,  # 附加时间嵌入维度,可选
        num_class_embeds: Optional[int] = None,  # 类嵌入数量,可选
        upcast_attention: bool = False,  # 是否上调注意力,默认为假
        resnet_time_scale_shift: str = "default",  # ResNet 时间缩放偏移,默认为 "default"
        projection_class_embeddings_input_dim: Optional[int] = None,  # 投影类嵌入输入维度,可选
        controlnet_conditioning_channel_order: str = "rgb",  # ControlNet 条件通道顺序,默认为 "rgb"
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),  # 条件嵌入输出通道数
        global_pool_conditions: bool = False,  # 是否使用全局池化条件,默认为假
        addition_embed_type_num_heads: int = 64,  # 附加嵌入类型的头数量,默认为 64
    @classmethod
    def from_unet(  # 从 UNet 创建 ControlNetModel 的类方法
        cls,  # 指向类本身的引用
        unet: UNet2DConditionModel,  # 传入的 UNet2DConditionModel 实例
        controlnet_conditioning_channel_order: str = "rgb",  # ControlNet 条件通道顺序,默认为 "rgb"
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),  # 条件嵌入输出通道数
        load_weights_from_unet: bool = True,  # 是否从 UNet 加载权重,默认为真
        conditioning_channels: int = 3,  # 条件通道数,默认为 3
    @property  # 定义一个属性装饰器
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors  # 注释,说明该属性是从另一个模块复制过来的
    # 定义一个返回模型中所有注意力处理器的字典的方法
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r""" 
        返回:
            `dict` 的注意力处理器:包含模型中所有注意力处理器的字典,按其权重名称索引。
        """
        # 初始化一个空字典用于存储处理器
        processors = {}
    
        # 定义一个递归函数来添加注意力处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模块具有获取处理器的方法,将其添加到处理器字典
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用以处理子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            # 返回更新后的处理器字典
            return processors
    
        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数添加处理器
            fn_recursive_add_processors(name, module, processors)
    
        # 返回包含所有处理器的字典
        return processors
    
    # 从 UNet2DConditionModel 的 set_attn_processor 方法复制而来
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
    
        参数:
            processor (`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或处理器类的字典,将被设置为**所有** `Attention` 层的处理器。
                
                如果 `processor` 是一个字典,则键需要定义相应交叉注意力处理器的路径。当设置可训练的注意力处理器时,强烈推荐这样做。
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 检查传入的处理器字典长度是否与注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )
    
        # 定义一个递归函数来设置注意力处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块具有设置处理器的方法,进行处理器设置
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用以处理子模块
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数设置处理器
            fn_recursive_attn_processor(name, module, processor)
    
    # 从 UNet2DConditionModel 的 set_default_attn_processor 方法复制而来
    # 设置默认的注意力处理器,禁用自定义注意力处理器
    def set_default_attn_processor(self):
        # 文档字符串,说明该函数的作用
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        # 检查所有注意力处理器是否为已添加的 KV 注意力处理器
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建添加 KV 的注意力处理器
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否为交叉注意力处理器
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建常规的注意力处理器
            processor = AttnProcessor()
        else:
            # 引发错误,说明注意力处理器类型不支持
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )
    
        # 设置选择的注意力处理器
        self.set_attn_processor(processor)
    
    # 从 diffusers.models.unets.unet_2d_condition 中复制的函数
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 如果模块是特定类型,则设置其梯度检查点标志
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
            module.gradient_checkpointing = value
    
    # 前向传播方法,处理输入样本和相关参数
    def forward(
        self,
        sample: torch.Tensor,  # 输入样本张量
        timestep: Union[torch.Tensor, float, int],  # 当前时间步
        encoder_hidden_states: torch.Tensor,  # 编码器的隐藏状态
        controlnet_cond: torch.Tensor,  # ControlNet 的条件输入
        conditioning_scale: float = 1.0,  # 条件缩放因子
        class_labels: Optional[torch.Tensor] = None,  # 可选的类别标签
        timestep_cond: Optional[torch.Tensor] = None,  # 可选的时间步条件
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,  # 可选的额外条件参数
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可选的交叉注意力参数
        guess_mode: bool = False,  # 是否启用猜测模式
        return_dict: bool = True,  # 是否以字典形式返回结果
# 将给定的 PyTorch 模块的所有参数初始化为零
def zero_module(module):
    # 遍历模块的所有参数
    for p in module.parameters():
        # 将当前参数 p 的值初始化为零
        nn.init.zeros_(p)
    # 返回已修改的模块
    return module

.\diffusers\models\controlnet_flax.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入类型提示
from typing import Optional, Tuple, Union

# 导入 Flax 库
import flax
import flax.linen as nn
# 导入 JAX 库
import jax
import jax.numpy as jnp
# 从 Flax 导入冻结字典
from flax.core.frozen_dict import FrozenDict

# 导入配置相关的工具
from ..configuration_utils import ConfigMixin, flax_register_to_config
# 导入基础输出类
from ..utils import BaseOutput
# 导入时间步嵌入和时间步类
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
# 导入 Flax 模型混合类
from .modeling_flax_utils import FlaxModelMixin
# 导入 UNet 2D 块
from .unets.unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
)

# 定义 FlaxControlNetOutput 数据类,继承自 BaseOutput
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
    """
    The output of [`FlaxControlNetModel`].
    该类表示 FlaxControlNetModel 的输出。

    Args:
        down_block_res_samples (`jnp.ndarray`): 下层块的结果样本
        mid_block_res_sample (`jnp.ndarray`): 中间块的结果样本
    """

    # 定义下层块结果样本的类型
    down_block_res_samples: jnp.ndarray
    # 定义中间块结果样本的类型
    mid_block_res_sample: jnp.ndarray


# 定义 FlaxControlNetConditioningEmbedding 模块
class FlaxControlNetConditioningEmbedding(nn.Module):
    # 定义输入的条件嵌入通道数
    conditioning_embedding_channels: int
    # 定义每个块的输出通道数
    block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
    # 定义数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32

    # 设置模块的组件
    def setup(self) -> None:
        # 创建输入卷积层,输出通道数为第一个块的输出通道数
        self.conv_in = nn.Conv(
            self.block_out_channels[0],
            kernel_size=(3, 3),  # 卷积核大小
            padding=((1, 1), (1, 1)),  # 填充方式
            dtype=self.dtype,  # 数据类型
        )

        # 初始化块列表
        blocks = []
        # 遍历每对相邻的块输出通道数
        for i in range(len(self.block_out_channels) - 1):
            # 获取当前输入通道数
            channel_in = self.block_out_channels[i]
            # 获取下一个输出通道数
            channel_out = self.block_out_channels[i + 1]
            # 创建第一个卷积层
            conv1 = nn.Conv(
                channel_in,  # 输入通道数
                kernel_size=(3, 3),  # 卷积核大小
                padding=((1, 1), (1, 1)),  # 填充方式
                dtype=self.dtype,  # 数据类型
            )
            # 将卷积层添加到块列表
            blocks.append(conv1)
            # 创建第二个卷积层,带有步幅
            conv2 = nn.Conv(
                channel_out,  # 输出通道数
                kernel_size=(3, 3),  # 卷积核大小
                strides=(2, 2),  # 步幅
                padding=((1, 1), (1, 1)),  # 填充方式
                dtype=self.dtype,  # 数据类型
            )
            # 将卷积层添加到块列表
            blocks.append(conv2)
        # 将所有块存储为类的属性
        self.blocks = blocks

        # 创建输出卷积层,输出通道数为条件嵌入通道数
        self.conv_out = nn.Conv(
            self.conditioning_embedding_channels,  # 输出通道数
            kernel_size=(3, 3),  # 卷积核大小
            padding=((1, 1), (1, 1)),  # 填充方式
            kernel_init=nn.initializers.zeros_init(),  # 权重初始化为零
            bias_init=nn.initializers.zeros_init(),  # 偏置初始化为零
            dtype=self.dtype,  # 数据类型
        )
    # 定义调用方法,接收条件输入并返回处理后的嵌入
        def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
            # 通过输入卷积层处理条件输入,生成嵌入
            embedding = self.conv_in(conditioning)
            # 应用 SiLU 激活函数到嵌入
            embedding = nn.silu(embedding)
    
            # 遍历所有块进行嵌入的逐层处理
            for block in self.blocks:
                # 通过当前块处理嵌入
                embedding = block(embedding)
                # 再次应用 SiLU 激活函数到嵌入
                embedding = nn.silu(embedding)
    
            # 通过输出卷积层处理嵌入,得到最终结果
            embedding = self.conv_out(embedding)
    
            # 返回最终处理后的嵌入
            return embedding
# 注册类到 Flax 配置管理
@flax_register_to_config
# 定义 FlaxControlNetModel 类,继承自 nn.Module, FlaxModelMixin 和 ConfigMixin
class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    一个 ControlNet 模型。

    该模型继承自 [`FlaxModelMixin`]。请查看超类文档以了解它为所有模型实现的通用方法
    (例如下载或保存)。

    此模型也是 Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    的子类。可以将其用作常规的 Flax Linen 模块,并参考 Flax 文档以了解与其
    一般用法和行为相关的所有事项。

    支持 JAX 的固有特性,例如:

    - [即时编译 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [自动微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [向量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [并行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    参数:
        sample_size (`int`, *可选*):
            输入样本的大小。
        in_channels (`int`, *可选*, 默认为 4):
            输入样本中的通道数。
        down_block_types (`Tuple[str]`, *可选*, 默认为 `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            使用的下采样块的元组。
        block_out_channels (`Tuple[int]`, *可选*, 默认为 `(320, 640, 1280, 1280)`):
            每个块的输出通道元组。
        layers_per_block (`int`, *可选*, 默认为 2):
            每个块的层数。
        attention_head_dim (`int` 或 `Tuple[int]`, *可选*, 默认为 8):
            注意力头的维度。
        num_attention_heads (`int` 或 `Tuple[int]`, *可选*):
            注意力头的数量。
        cross_attention_dim (`int`, *可选*, 默认为 768):
            跨注意力特征的维度。
        dropout (`float`, *可选*, 默认为 0):
            下采样、上采样和瓶颈块的 dropout 概率。
        flip_sin_to_cos (`bool`, *可选*, 默认为 `True`):
            是否在时间嵌入中将 sin 转换为 cos。
        freq_shift (`int`, *可选*, 默认为 0): 应用于时间嵌入的频率偏移。
        controlnet_conditioning_channel_order (`str`, *可选*, 默认为 `rgb`):
            条件图像的通道顺序。如果是 `bgr`,将转换为 `rgb`。
        conditioning_embedding_out_channels (`tuple`, *可选*, 默认为 `(16, 32, 96, 256)`):
            `conditioning_embedding` 层中每个块的输出通道元组。
    """

    # 设置输入样本的默认大小
    sample_size: int = 32
    # 设置输入样本的默认通道数
    in_channels: int = 4
    # 定义下采样块的类型元组,包括三次 CrossAttnDownBlock2D 和一次 DownBlock2D
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    )
    # 定义是否仅使用交叉注意力,默认为 False
    only_cross_attention: Union[bool, Tuple[bool, ...]] = False
    # 定义每个块的输出通道数的元组
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
    # 定义每个块的层数,默认为 2
    layers_per_block: int = 2
    # 定义注意力头的维度,默认为 8
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    # 可选的注意力头数量,默认为 None
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
    # 定义交叉注意力的维度,默认为 1280
    cross_attention_dim: int = 1280
    # 定义 dropout 概率,默认为 0.0
    dropout: float = 0.0
    # 定义是否使用线性投影,默认为 False
    use_linear_projection: bool = False
    # 定义数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 定义是否翻转正弦和余弦的布尔值,默认为 True
    flip_sin_to_cos: bool = True
    # 定义频率偏移,默认为 0
    freq_shift: int = 0
    # 定义 ControlNet 条件通道的顺序,默认为 "rgb"
    controlnet_conditioning_channel_order: str = "rgb"
    # 定义条件嵌入输出通道数的元组
    conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)

    # 初始化权重的方法,接收一个随机数生成器
    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # 初始化输入张量的形状
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        # 创建一个全零的样本张量
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        # 创建一个全为 1 的时间步张量
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        # 创建一个全零的编码器隐藏状态张量
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
        # 创建 ControlNet 条件的形状
        controlnet_cond_shape = (1, 3, self.sample_size * 8, self.sample_size * 8)
        # 创建一个全零的 ControlNet 条件张量
        controlnet_cond = jnp.zeros(controlnet_cond_shape, dtype=jnp.float32)

        # 将 rng 分成两个部分,一个用于参数,一个用于 dropout
        params_rng, dropout_rng = jax.random.split(rng)
        # 创建一个包含参数和 dropout 随机数生成器的字典
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 调用初始化方法,返回包含参数的字典
        return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]

    # 定义可调用方法,接收样本、时间步、编码器隐藏状态等参数
    def __call__(
        self,
        sample: jnp.ndarray,
        timesteps: Union[jnp.ndarray, float, int],
        encoder_hidden_states: jnp.ndarray,
        controlnet_cond: jnp.ndarray,
        # 定义条件缩放因子,默认为 1.0
        conditioning_scale: float = 1.0,
        # 定义是否返回字典,默认为 True
        return_dict: bool = True,
        # 定义是否处于训练模式,默认为 False
        train: bool = False,

.\diffusers\models\controlnet_hunyuan.py

# 版权信息,声明代码的版权所有者和年份
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版许可使用本文件
# Licensed under the Apache License, Version 2.0 (the "License");
# 本文件的使用需遵循许可证的规定
# you may not use this file except in compliance with the License.
# 可通过以下网址获取许可证的副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是按“现状”提供的
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# 不提供任何明示或暗示的保证或条件
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证中关于权限和限制的具体条款
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Dict, Optional, Union  # 导入用于类型提示的类型

import torch  # 导入 PyTorch 库
from torch import nn  # 从 PyTorch 导入神经网络模块

from ..configuration_utils import ConfigMixin, register_to_config  # 从配置工具导入配置混合和注册功能
from ..utils import logging  # 从工具模块导入日志功能
from .attention_processor import AttentionProcessor  # 导入注意力处理器
from .controlnet import BaseOutput, Tuple, zero_module  # 导入控制网络的基础输出和相关类型
from .embeddings import (  # 从嵌入模块导入多个类
    HunyuanCombinedTimestepTextSizeStyleEmbedding,
    PatchEmbed,
    PixArtAlphaTextProjection,
)
from .modeling_utils import ModelMixin  # 导入模型混合类
from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock  # 导入 Hunyuan 二维变换器块

logger = logging.get_logger(__name__)  # 初始化日志记录器,使用当前模块名作为标识

@dataclass  # 将 HunyuanControlNetOutput 类标记为数据类
class HunyuanControlNetOutput(BaseOutput):  # 定义 HunyuanControlNetOutput 类,继承自 BaseOutput
    controlnet_block_samples: Tuple[torch.Tensor]  # 定义一个包含控制网络块样本的元组属性

class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):  # 定义 HunyuanDiT2DControlNetModel 类,继承自多个混合类
    @register_to_config  # 将该方法注册到配置系统
    def __init__(  # 定义构造函数
        self,
        conditioning_channels: int = 3,  # 初始化条件通道数,默认为 3
        num_attention_heads: int = 16,  # 初始化注意力头数,默认为 16
        attention_head_dim: int = 88,  # 初始化注意力头维度,默认为 88
        in_channels: Optional[int] = None,  # 可选的输入通道数,默认为 None
        patch_size: Optional[int] = None,  # 可选的补丁大小,默认为 None
        activation_fn: str = "gelu-approximate",  # 初始化激活函数,默认为“gelu-approximate”
        sample_size=32,  # 初始化样本大小,默认为 32
        hidden_size=1152,  # 初始化隐藏层大小,默认为 1152
        transformer_num_layers: int = 40,  # 初始化变换器层数,默认为 40
        mlp_ratio: float = 4.0,  # 初始化 MLP 比率,默认为 4.0
        cross_attention_dim: int = 1024,  # 初始化交叉注意力维度,默认为 1024
        cross_attention_dim_t5: int = 2048,  # 初始化 T5 交叉注意力维度,默认为 2048
        pooled_projection_dim: int = 1024,  # 初始化池化投影维度,默认为 1024
        text_len: int = 77,  # 初始化文本长度,默认为 77
        text_len_t5: int = 256,  # 初始化 T5 文本长度,默认为 256
        use_style_cond_and_image_meta_size: bool = True,  # 初始化样式条件和图像元大小的使用标志,默认为 True
    # 初始化父类
        ):
            super().__init__()
            # 设置注意力头数量
            self.num_heads = num_attention_heads
            # 计算内部维度
            self.inner_dim = num_attention_heads * attention_head_dim
    
            # 创建文本嵌入投影层
            self.text_embedder = PixArtAlphaTextProjection(
                in_features=cross_attention_dim_t5,  # 输入特征维度
                hidden_size=cross_attention_dim_t5 * 4,  # 隐藏层大小
                out_features=cross_attention_dim,  # 输出特征维度
                act_fn="silu_fp32",  # 激活函数
            )
    
            # 创建文本嵌入的可学习参数
            self.text_embedding_padding = nn.Parameter(
                torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)  # 初始化随机张量
            )
    
            # 创建位置嵌入层
            self.pos_embed = PatchEmbed(
                height=sample_size,  # 高度
                width=sample_size,  # 宽度
                in_channels=in_channels,  # 输入通道数
                embed_dim=hidden_size,  # 嵌入维度
                patch_size=patch_size,  # 每个块的大小
                pos_embed_type=None,  # 位置嵌入类型
            )
    
            # 创建时间额外嵌入层
            self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
                hidden_size,  # 隐藏层大小
                pooled_projection_dim=pooled_projection_dim,  # 池化投影维度
                seq_len=text_len_t5,  # 序列长度
                cross_attention_dim=cross_attention_dim_t5,  # 跨注意力维度
                use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,  # 是否使用样式条件和图像元大小
            )
    
            # 初始化控制网络模块列表
            self.controlnet_blocks = nn.ModuleList([])
    
            # 初始化 HunyuanDiT 模块列表
            self.blocks = nn.ModuleList(
                [
                    HunyuanDiTBlock(
                        dim=self.inner_dim,  # 模块维度
                        num_attention_heads=self.config.num_attention_heads,  # 注意力头数量
                        activation_fn=activation_fn,  # 激活函数
                        ff_inner_dim=int(self.inner_dim * mlp_ratio),  # 前馈层内部维度
                        cross_attention_dim=cross_attention_dim,  # 跨注意力维度
                        qk_norm=True,  # 是否使用 QK 归一化
                        skip=False,  # 是否跳过,首个模型的前半部分总为 False
                    )
                    for layer in range(transformer_num_layers // 2 - 1)  # 根据层数生成 HunyuanDiTBlock
                ]
            )
            # 初始化输入层
            self.input_block = zero_module(nn.Linear(hidden_size, hidden_size))  
            # 根据模块数量创建控制网络块
            for _ in range(len(self.blocks)):
                controlnet_block = nn.Linear(hidden_size, hidden_size)  # 初始化控制网络块
                controlnet_block = zero_module(controlnet_block)  # 零初始化
                self.controlnet_blocks.append(controlnet_block)  # 添加到控制网络模块列表
    
        @property  # 声明为属性
    # 定义一个返回注意力处理器的函数,返回值为字典类型,字典包含模型中所有的注意力处理器
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 的注意力处理器: 一个字典,包含模型中所有注意力处理器,按其权重名称索引。
        """
        # 初始化一个空字典,用于存储处理器
        processors = {}

        # 定义递归函数,将注意力处理器添加到字典中
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否有 "get_processor" 方法
            if hasattr(module, "get_processor"):
                # 如果有,调用该方法并将处理器存入字典
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用函数,将子模块的处理器添加到字典中
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新后的处理器字典
            return processors

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 递归调用函数,添加所有子模块的处理器
            fn_recursive_add_processors(name, module, processors)

        # 返回包含所有处理器的字典
        return processors

    # 定义一个设置注意力处理器的函数
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。

        参数:
            processor (`dict` of `AttentionProcessor` 或 `AttentionProcessor`):
                实例化的处理器类或处理器类字典,将作为所有 `Attention` 层的处理器。如果 `processor` 是字典,键需要定义
                对应的交叉注意力处理器的路径。强烈建议在设置可训练的注意力处理器时使用。
        """
        # 获取当前注意力处理器字典的键数量
        count = len(self.attn_processors.keys())

        # 检查传入的处理器字典数量是否与当前注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器的数量 {len(processor)} 与注意力层的数量: {count} 不匹配。请确保传入 {count} 个处理器类。"
            )

        # 定义递归函数,为模块设置注意力处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否有 "set_processor" 方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置处理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中获取对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用函数,为子模块设置处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 递归调用函数,设置所有子模块的处理器
            fn_recursive_attn_processor(name, module, processor)

    # 定义一个类方法,从 transformer 创建对象
    @classmethod
    def from_transformer(
        cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
    # 开始方法定义
        ):
            # 获取变换器的配置
            config = transformer.config
            # 获取激活函数
            activation_fn = config.activation_fn
            # 获取注意力头的维度
            attention_head_dim = config.attention_head_dim
            # 获取交叉注意力的维度
            cross_attention_dim = config.cross_attention_dim
            # 获取 T5 模型的交叉注意力维度
            cross_attention_dim_t5 = config.cross_attention_dim_t5
            # 获取隐藏层的大小
            hidden_size = config.hidden_size
            # 获取输入通道的数量
            in_channels = config.in_channels
            # 获取多层感知器的比率
            mlp_ratio = config.mlp_ratio
            # 获取注意力头的数量
            num_attention_heads = config.num_attention_heads
            # 获取补丁的大小
            patch_size = config.patch_size
            # 获取样本的大小
            sample_size = config.sample_size
            # 获取文本的长度
            text_len = config.text_len
            # 获取 T5 模型的文本长度
            text_len_t5 = config.text_len_t5
    
            # 设置条件通道
            conditioning_channels = conditioning_channels
            # 设置变换器层数,如果未提供,则使用配置中的默认值
            transformer_num_layers = transformer_num_layers or config.transformer_num_layers
    
            # 实例化 ControlNet 对象,传入多个参数
            controlnet = cls(
                conditioning_channels=conditioning_channels,
                transformer_num_layers=transformer_num_layers,
                activation_fn=activation_fn,
                attention_head_dim=attention_head_dim,
                cross_attention_dim=cross_attention_dim,
                cross_attention_dim_t5=cross_attention_dim_t5,
                hidden_size=hidden_size,
                in_channels=in_channels,
                mlp_ratio=mlp_ratio,
                num_attention_heads=num_attention_heads,
                patch_size=patch_size,
                sample_size=sample_size,
                text_len=text_len,
                text_len_t5=text_len_t5,
            )
            # 如果需要从变换器加载权重
            if load_weights_from_transformer:
                # 加载状态字典,忽略缺失的键
                key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
                # 记录警告,显示缺失的键
                logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
            # 返回创建的 ControlNet 对象
            return controlnet
    
        # 定义前向传播方法
        def forward(
            # 隐藏状态输入
            hidden_states,
            # 时间步长
            timestep,
            # 控制网条件张量
            controlnet_cond: torch.Tensor,
            # 条件缩放因子,默认值为 1.0
            conditioning_scale: float = 1.0,
            # 编码器的隐藏状态,可选
            encoder_hidden_states=None,
            # 文本嵌入的掩码,可选
            text_embedding_mask=None,
            # T5 编码器的隐藏状态,可选
            encoder_hidden_states_t5=None,
            # T5 文本嵌入的掩码,可选
            text_embedding_mask_t5=None,
            # 图像元数据的大小,可选
            image_meta_size=None,
            # 风格参数,可选
            style=None,
            # 图像旋转嵌入,可选
            image_rotary_emb=None,
            # 是否返回字典格式的输出,默认值为 True
            return_dict=True,
# HunyuanDiT2DMultiControlNetModel 类,用于封装多个 HunyuanDiT2DControlNetModel 实例
class HunyuanDiT2DMultiControlNetModel(ModelMixin):
    r"""
    `HunyuanDiT2DMultiControlNetModel` 是用于 Multi-HunyuanDiT2DControlNetModel 的封装类

    该模块为多个 `HunyuanDiT2DControlNetModel` 实例提供封装。`forward()` API 设计上与 `HunyuanDiT2DControlNetModel` 兼容。

    参数:
        controlnets (`List[HunyuanDiT2DControlNetModel]`):
            在去噪过程中为 unet 提供额外的条件。必须将多个 `HunyuanDiT2DControlNetModel` 作为列表设置。
    """

    # 初始化方法,接收控制网络列表并调用父类构造函数
    def __init__(self, controlnets):
        super().__init__()  # 调用父类构造函数以初始化基类
        self.nets = nn.ModuleList(controlnets)  # 将控制网络列表封装为一个可训练的模块列表

    # 前向传播方法,处理输入并生成输出
    def forward(
        self,
        hidden_states,  # 输入的隐藏状态
        timestep,  # 当前时间步
        controlnet_cond: torch.Tensor,  # 控制网络的条件张量
        conditioning_scale: float = 1.0,  # 条件缩放因子,默认值为 1.0
        encoder_hidden_states=None,  # 可选的编码器隐藏状态
        text_embedding_mask=None,  # 文本嵌入的掩码
        encoder_hidden_states_t5=None,  # 可选的 T5 编码器隐藏状态
        text_embedding_mask_t5=None,  # T5 文本嵌入的掩码
        image_meta_size=None,  # 图像元数据大小
        style=None,  # 样式信息
        image_rotary_emb=None,  # 图像旋转嵌入
        return_dict=True,  # 是否以字典形式返回结果,默认为 True
    ):
        """
        [`HunyuanDiT2DControlNetModel`] 的前向传播方法。

        参数:
        hidden_states (`torch.Tensor`,形状为 `(batch size, dim, height, width)`):
            输入张量。
        timestep ( `torch.LongTensor`,*可选*):
            用于指示去噪步骤。
        controlnet_cond ( `torch.Tensor` ):
            ControlNet 的条件输入。
        conditioning_scale ( `float` ):
            指示条件的比例。
        encoder_hidden_states ( `torch.Tensor`,形状为 `(batch size, sequence len, embed dims)`,*可选*):
            交叉注意力层的条件嵌入。这是 `BertModel` 的输出。
        text_embedding_mask: torch.Tensor
            形状为 `(batch, key_tokens)` 的注意力掩码,应用于 `encoder_hidden_states`。这是 `BertModel` 的输出。
        encoder_hidden_states_t5 ( `torch.Tensor`,形状为 `(batch size, sequence len, embed dims)`,*可选*):
            交叉注意力层的条件嵌入。这是 T5 文本编码器的输出。
        text_embedding_mask_t5: torch.Tensor
            形状为 `(batch, key_tokens)` 的注意力掩码,应用于 `encoder_hidden_states`。这是 T5 文本编码器的输出。
        image_meta_size (torch.Tensor):
            条件嵌入,指示图像大小
        style: torch.Tensor:
            条件嵌入,指示样式
        image_rotary_emb (`torch.Tensor`):
            在注意力计算中应用于查询和键张量的图像旋转嵌入。
        return_dict: bool
            是否返回字典。
        """
        # 遍历 controlnet_cond、conditioning_scale 和自有网络的组合
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            # 调用 controlnet 处理输入,返回区块样本
            block_samples = controlnet(
                hidden_states=hidden_states,  # 输入隐藏状态
                timestep=timestep,  # 输入时间步
                controlnet_cond=image,  # 输入图像条件
                conditioning_scale=scale,  # 输入条件比例
                encoder_hidden_states=encoder_hidden_states,  # 输入 BERT 编码的隐藏状态
                text_embedding_mask=text_embedding_mask,  # 输入 BERT 的注意力掩码
                encoder_hidden_states_t5=encoder_hidden_states_t5,  # 输入 T5 编码的隐藏状态
                text_embedding_mask_t5=text_embedding_mask_t5,  # 输入 T5 的注意力掩码
                image_meta_size=image_meta_size,  # 输入图像元数据大小
                style=style,  # 输入样式条件
                image_rotary_emb=image_rotary_emb,  # 输入图像旋转嵌入
                return_dict=return_dict,  # 指示是否返回字典
            )

            # 合并样本
            if i == 0:  # 如果是第一个样本
                control_block_samples = block_samples  # 初始化样本
            else:  # 如果不是第一个样本
                # 合并现有样本和新样本
                control_block_samples = [
                    control_block_sample + block_sample  # 对应位置样本相加
                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
                ]
                control_block_samples = (control_block_samples,)  # 转换为元组

        # 返回合并后的样本
        return control_block_samples

.\diffusers\models\controlnet_sd3.py

# 版权所有 2024 Stability AI, HuggingFace 团队和 InstantX 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)进行许可;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下位置获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是以“原样”基础分发的,
# 不提供任何形式的保证或条件,无论是明示或暗示的。
# 有关许可证的特定权限和限制,请参见许可证。

# 从 dataclasses 模块导入 dataclass 装饰器,用于简化类的定义
from dataclasses import dataclass
# 从 typing 模块导入类型提示的相关类型
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 PyTorch 库及其神经网络模块
import torch
import torch.nn as nn

# 导入配置和注册功能相关的模块
from ..configuration_utils import ConfigMixin, register_to_config
# 导入模型加载的混合接口
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
# 导入联合变换器块的定义
from ..models.attention import JointTransformerBlock
# 导入注意力处理相关的模块
from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
# 导入变换器 2D 模型输出的定义
from ..models.modeling_outputs import Transformer2DModelOutput
# 导入模型的通用功能混合接口
from ..models.modeling_utils import ModelMixin
# 导入工具函数和常量
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
# 导入控制网络相关的基础输出和零模块
from .controlnet import BaseOutput, zero_module
# 导入组合时间步文本投影嵌入和补丁嵌入的定义
from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed

# 创建日志记录器实例,用于记录信息和调试
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义数据类 SD3ControlNetOutput,用于存储控制网络块的样本输出
@dataclass
class SD3ControlNetOutput(BaseOutput):
    # 控制网络块的样本,使用元组存储张量
    controlnet_block_samples: Tuple[torch.Tensor]

# 定义 SD3ControlNetModel 类,集成多种混合接口以实现模型功能
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    # 支持梯度检查点,允许节省内存
    _supports_gradient_checkpointing = True

    @register_to_config
    # 初始化方法,设置模型的各种参数,提供默认值
    def __init__(
        self,
        sample_size: int = 128,  # 输入样本大小
        patch_size: int = 2,  # 补丁大小
        in_channels: int = 16,  # 输入通道数
        num_layers: int = 18,  # 模型层数
        attention_head_dim: int = 64,  # 注意力头的维度
        num_attention_heads: int = 18,  # 注意力头的数量
        joint_attention_dim: int = 4096,  # 联合注意力的维度
        caption_projection_dim: int = 1152,  # 标题投影的维度
        pooled_projection_dim: int = 2048,  # 池化投影的维度
        out_channels: int = 16,  # 输出通道数
        pos_embed_max_size: int = 96,  # 位置嵌入的最大尺寸
    ):
        # 初始化父类
        super().__init__()
        # 默认输出通道设置为输入通道
        default_out_channels = in_channels
        # 输出通道为指定值或默认值
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        # 内部维度等于注意力头数量乘以每个头的维度
        self.inner_dim = num_attention_heads * attention_head_dim

        # 创建位置嵌入对象,用于处理图像补丁
        self.pos_embed = PatchEmbed(
            height=sample_size,  # 输入图像高度
            width=sample_size,   # 输入图像宽度
            patch_size=patch_size,  # 图像补丁大小
            in_channels=in_channels,  # 输入通道数量
            embed_dim=self.inner_dim,  # 嵌入维度
            pos_embed_max_size=pos_embed_max_size,  # 最大位置嵌入大小
        )
        # 创建时间和文本的联合嵌入
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim,  # 嵌入维度
            pooled_projection_dim=pooled_projection_dim  # 聚合投影维度
        )
        # 定义上下文嵌入的线性层
        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

        # 注意力头维度加倍以适应混合
        # 需要在实际检查点中处理
        self.transformer_blocks = nn.ModuleList(
            [
                # 创建多个联合变换块
                JointTransformerBlock(
                    dim=self.inner_dim,  # 块的维度
                    num_attention_heads=num_attention_heads,  # 注意力头数量
                    attention_head_dim=self.config.attention_head_dim,  # 每个头的维度
                    context_pre_only=False,  # 是否仅上下文先行
                )
                for i in range(num_layers)  # 根据层数生成块
            ]
        )

        # 控制网络块
        self.controlnet_blocks = nn.ModuleList([])  # 初始化空的控制网络块列表
        for _ in range(len(self.transformer_blocks)):  # 根据变换块数量创建控制网络块
            controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)  # 创建线性层
            controlnet_block = zero_module(controlnet_block)  # 零化模块以初始化
            self.controlnet_blocks.append(controlnet_block)  # 添加到控制网络块列表
        # 创建位置嵌入输入对象
        pos_embed_input = PatchEmbed(
            height=sample_size,  # 输入图像高度
            width=sample_size,   # 输入图像宽度
            patch_size=patch_size,  # 图像补丁大小
            in_channels=in_channels,  # 输入通道数量
            embed_dim=self.inner_dim,  # 嵌入维度
            pos_embed_type=None,  # 不使用位置嵌入类型
        )
        # 零化位置嵌入输入
        self.pos_embed_input = zero_module(pos_embed_input)

        # 关闭梯度检查点
        self.gradient_checkpointing = False

    # 从 diffusers.models.unets.unet_3d_condition.UNet3DConditionModel 复制的启用前向分块方法
    # 定义一个方法,启用前馈层的分块处理
        def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
            """
            设置注意力处理器使用前馈分块。
            
            参数:
                chunk_size (`int`, *optional*):
                    前馈层的分块大小。如果未指定,将对每个维度为`dim`的张量单独运行前馈层。
                dim (`int`, *optional*, defaults to `0`):
                    应该进行前馈计算的维度。可以选择dim=0(批次)或dim=1(序列长度)。
            """
            # 检查dim是否在允许的范围内
            if dim not in [0, 1]:
                raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
    
            # 默认的分块大小为1
            chunk_size = chunk_size or 1
    
            # 定义一个递归函数,处理每个模块的前馈分块设置
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模块有设置前馈分块的方法,则调用它
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 递归处理子模块
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 对当前对象的所有子模块应用分块设置
            for module in self.children():
                fn_recursive_feed_forward(module, chunk_size, dim)
    
        @property
        # 从其他模型复制的属性,返回注意力处理器
        def attn_processors(self) -> Dict[str, AttentionProcessor]:
            r"""
            返回:
                `dict` 注意力处理器:包含模型中所有注意力处理器的字典,按权重名称索引。
            """
            # 定义一个空字典来存储处理器
            processors = {}
    
            # 定义递归函数,添加处理器到字典中
            def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
                # 如果模块有获取处理器的方法,则将其添加到字典中
                if hasattr(module, "get_processor"):
                    processors[f"{name}.processor"] = module.get_processor()
    
                # 递归处理子模块
                for sub_name, child in module.named_children():
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
                return processors
    
            # 对当前对象的所有子模块添加处理器
            for name, module in self.named_children():
                fn_recursive_add_processors(name, module, processors)
    
            # 返回所有处理器的字典
            return processors
    
        # 从其他模型复制的设置方法
    # 定义设置注意力处理器的方法,接收一个注意力处理器或处理器字典
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。

        参数:
            processor (`dict` of `AttentionProcessor` 或 `AttentionProcessor`):
                实例化的处理器类或将作为处理器设置到**所有** `Attention` 层的处理器类字典。

                如果 `processor` 是字典,键需要定义对应的交叉注意力处理器的路径。当设置可训练的注意力处理器时,强烈建议使用字典。

        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())

        # 如果传入的是字典且字典的长度与当前处理器数量不匹配,则抛出错误
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )

        # 定义递归处理注意力处理器的函数
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果传入的不是字典,则直接设置处理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中获取对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历子模块,递归调用自身
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前对象的子模块
        for name, module in self.named_children():
            # 对每个子模块调用递归处理器设置函数
            fn_recursive_attn_processor(name, module, processor)

    # 从 diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections 复制的方法
    def fuse_qkv_projections(self):
        """
        启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)被融合。
        对于交叉注意力模块,键和值的投影矩阵被融合。

        <提示 警告={true}>

        此 API 是 🧪 实验性的。

        </提示>
        """
        # 初始化原始注意力处理器为 None
        self.original_attn_processors = None

        # 检查所有注意力处理器,确保没有添加的 KV 投影
        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` 不支持具有添加的 KV 投影的模型。")

        # 保存当前的注意力处理器以备后用
        self.original_attn_processors = self.attn_processors

        # 遍历模型中的所有模块
        for module in self.modules():
            # 如果模块是 Attention 类型
            if isinstance(module, Attention):
                # 融合投影矩阵
                module.fuse_projections(fuse=True)

        # 设置新的融合注意力处理器
        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 复制的方法
    # 定义一个方法来禁用已启用的融合 QKV 投影
    def unfuse_qkv_projections(self):
        """如果启用了融合的 QKV 投影,则禁用它。

        <Tip warning={true}>

        此 API 是 🧪 实验性的。

        </Tip>

        """
        # 如果原始注意力处理器不为空,则恢复到原始设置
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    # 定义一个方法来设置梯度检查点
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模块有梯度检查点属性,则设置其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # 定义一个类方法从 Transformer 创建 ControlNet 实例
    @classmethod
    def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
        # 获取 Transformer 的配置
        config = transformer.config
        # 设置层数,如果未指定则使用配置中的层数
        config["num_layers"] = num_layers or config.num_layers
        # 创建 ControlNet 实例,传入配置参数
        controlnet = cls(**config)

        # 如果需要从 Transformer 加载权重
        if load_weights_from_transformer:
            # 加载位置嵌入的权重
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            # 加载时间文本嵌入的权重
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            # 加载上下文嵌入器的权重
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            # 加载变换器块的权重,严格模式为 False
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

            # 将位置嵌入输入初始化为零模块
            controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)

        # 返回创建的 ControlNet 实例
        return controlnet

    # 定义前向传播方法
    def forward(
        # 输入的隐藏状态张量
        hidden_states: torch.FloatTensor,
        # 控制网条件张量
        controlnet_cond: torch.Tensor,
        # 条件缩放因子,默认值为 1.0
        conditioning_scale: float = 1.0,
        # 编码器隐藏状态张量,默认为 None
        encoder_hidden_states: torch.FloatTensor = None,
        # 池化投影张量,默认为 None
        pooled_projections: torch.FloatTensor = None,
        # 时间步长张量,默认为 None
        timestep: torch.LongTensor = None,
        # 联合注意力参数,默认为 None
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 是否返回字典格式的输出,默认为 True
        return_dict: bool = True,
# SD3MultiControlNetModel 类,继承自 ModelMixin
class SD3MultiControlNetModel(ModelMixin):
    r"""
    `SD3ControlNetModel` 的包装类,用于 Multi-SD3ControlNet

    该模块是多个 `SD3ControlNetModel` 实例的包装。`forward()` API 设计与 `SD3ControlNetModel` 兼容。

    参数:
        controlnets (`List[SD3ControlNetModel]`):
            在去噪过程中为 unet 提供额外的条件。必须将多个 `SD3ControlNetModel` 作为列表设置。
    """

    # 初始化函数,接收控制网列表并调用父类构造
    def __init__(self, controlnets):
        super().__init__()  # 调用父类的初始化方法
        self.nets = nn.ModuleList(controlnets)  # 将控制网列表存储为模块列表

    # 前向传播函数,接收多个输入参数以处理数据
    def forward(
        self,
        hidden_states: torch.FloatTensor,  # 隐藏状态张量
        controlnet_cond: List[torch.tensor],  # 控制网条件列表
        conditioning_scale: List[float],  # 条件缩放因子列表
        pooled_projections: torch.FloatTensor,  # 池化的投影张量
        encoder_hidden_states: torch.FloatTensor = None,  # 可选编码器隐藏状态
        timestep: torch.LongTensor = None,  # 可选时间步长
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可选的联合注意力参数
        return_dict: bool = True,  # 返回格式,默认为字典
    ) -> Union[SD3ControlNetOutput, Tuple]:  # 返回类型可以是输出对象或元组
        # 遍历控制网条件、缩放因子和控制网
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
            # 调用控制网的前向传播以获取块样本
            block_samples = controlnet(
                hidden_states=hidden_states,  # 传递隐藏状态
                timestep=timestep,  # 传递时间步长
                encoder_hidden_states=encoder_hidden_states,  # 传递编码器隐藏状态
                pooled_projections=pooled_projections,  # 传递池化投影
                controlnet_cond=image,  # 传递控制网条件
                conditioning_scale=scale,  # 传递条件缩放因子
                joint_attention_kwargs=joint_attention_kwargs,  # 传递联合注意力参数
                return_dict=return_dict,  # 传递返回格式
            )

            # 合并样本
            if i == 0:  # 如果是第一个控制网
                control_block_samples = block_samples  # 直接使用块样本
            else:  # 如果不是第一个控制网
                # 将当前块样本与之前的样本逐元素相加
                control_block_samples = [
                    control_block_sample + block_sample
                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
                ]
                control_block_samples = (tuple(control_block_samples),)  # 将合并结果转为元组

        # 返回合并后的控制块样本
        return control_block_samples

.\diffusers\models\controlnet_sparsectrl.py

# 版权所有 2024 HuggingFace 团队。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版("许可证")进行许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件
# 根据许可证分发是按“原样”基础进行的,
# 不提供任何种类的保证或条件,无论是明示或暗示的。
# 请参阅许可证以获取与权限和
# 限制相关的特定语言。

from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Any, Dict, List, Optional, Tuple, Union  # 导入类型提示所需的类型

import torch  # 导入 PyTorch 库
from torch import nn  # 从 PyTorch 导入神经网络模块
from torch.nn import functional as F  # 导入 PyTorch 神经网络功能模块,通常用于定义激活函数等

from ..configuration_utils import ConfigMixin, register_to_config  # 从父级模块导入配置混合类和注册函数
from ..loaders import FromOriginalModelMixin  # 从父级模块导入原始模型混合类
from ..utils import BaseOutput, logging  # 从父级模块导入基础输出类和日志记录工具
from .attention_processor import (  # 从当前模块导入注意力处理相关类
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入新增键值注意力处理器
    CROSS_ATTENTION_PROCESSORS,  # 导入交叉注意力处理器
    AttentionProcessor,  # 导入注意力处理器基类
    AttnAddedKVProcessor,  # 导入新增键值的注意力处理器
    AttnProcessor,  # 导入普通注意力处理器
)
from .embeddings import TimestepEmbedding, Timesteps  # 从当前模块导入时间步嵌入和时间步类
from .modeling_utils import ModelMixin  # 从当前模块导入模型混合类
from .unets.unet_2d_blocks import UNetMidBlock2DCrossAttn  # 从 2D UNet 模块导入中间块交叉注意力类
from .unets.unet_2d_condition import UNet2DConditionModel  # 从 2D UNet 模块导入条件模型
from .unets.unet_motion_model import CrossAttnDownBlockMotion, DownBlockMotion  # 从运动模型模块导入相关类

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,便于后续日志记录使用

@dataclass  # 将类标记为数据类,自动生成初始化等方法
class SparseControlNetOutput(BaseOutput):  # 定义 SparseControlNetOutput 类,继承自 BaseOutput
    """
    [`SparseControlNetModel`] 的输出。

    参数:
        down_block_res_samples (`tuple[torch.Tensor]`):
            一个包含每个下采样块在不同分辨率下激活的元组。每个张量的形状应为
            `(batch_size, channel * resolution, height // resolution, width // resolution)`。输出可用于条件
            原始 UNet 的下采样激活。
        mid_down_block_re_sample (`torch.Tensor`):
            中间块(最低采样分辨率)的激活。每个张量的形状应为
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`。
            输出可用于条件原始 UNet 的中间块激活。
    """

    down_block_res_samples: Tuple[torch.Tensor]  # 定义下采样块结果样本的属性
    mid_block_res_sample: torch.Tensor  # 定义中间块结果样本的属性


class SparseControlNetConditioningEmbedding(nn.Module):  # 定义 SparseControlNetConditioningEmbedding 类,继承自 nn.Module
    def __init__(  # 初始化方法
        self,
        conditioning_embedding_channels: int,  # 条件嵌入通道数
        conditioning_channels: int = 3,  # 条件通道数,默认为 3
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),  # 块输出通道数的元组,包含多个值
    ):
        # 初始化父类,调用父类的构造函数
        super().__init__()

        # 定义输入卷积层,接受条件通道数并输出块的第一个通道数,卷积核大小为3,填充为1
        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
        # 创建一个空的模块列表,用于存储后续的卷积块
        self.blocks = nn.ModuleList([])

        # 遍历块输出通道数列表,构建卷积块
        for i in range(len(block_out_channels) - 1):
            # 当前通道数
            channel_in = block_out_channels[i]
            # 下一层的通道数
            channel_out = block_out_channels[i + 1]
            # 添加一个卷积层,输入通道数和输出通道数均为当前通道数
            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
            # 添加一个卷积层,输入通道数为当前通道数,输出通道数为下一层通道数,步幅为2
            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        # 定义输出卷积层,接受最后一个块的输出通道数并输出条件嵌入通道数,卷积核大小为3,填充为1
        self.conv_out = zero_module(
            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    # 前向传播函数,接受一个张量作为输入,返回一个张量作为输出
    def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
        # 通过输入卷积层处理条件张量,得到嵌入
        embedding = self.conv_in(conditioning)
        # 应用激活函数 SiLU
        embedding = F.silu(embedding)

        # 遍历每个卷积块,依次处理嵌入
        for block in self.blocks:
            # 通过当前卷积块处理嵌入
            embedding = block(embedding)
            # 再次应用激活函数 SiLU
            embedding = F.silu(embedding)

        # 通过输出卷积层处理嵌入,得到最终输出
        embedding = self.conv_out(embedding)
        # 返回最终输出
        return embedding
# 定义一个稀疏控制网络模型类,继承自多个混合类以获得其功能
class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """
    根据 [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
    Models](https://arxiv.org/abs/2311.16933) 的描述,定义一个稀疏控制网络模型。
    """

    # 支持梯度检查点,允许在训练时节省内存
    _supports_gradient_checkpointing = True

    # 将初始化方法注册到配置中
    @register_to_config
    def __init__(
        # 输入通道数,默认为4
        in_channels: int = 4,
        # 条件通道数,默认为4
        conditioning_channels: int = 4,
        # 是否将正弦函数翻转为余弦函数,默认为True
        flip_sin_to_cos: bool = True,
        # 频率偏移量,默认为0
        freq_shift: int = 0,
        # 下采样块的类型,默认为三个交叉注意力块和一个下块
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlockMotion",
            "CrossAttnDownBlockMotion",
            "CrossAttnDownBlockMotion",
            "DownBlockMotion",
        ),
        # 是否仅使用交叉注意力,默认为False
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        # 每个块的输出通道数,默认为指定的四个值
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        # 每个块的层数,默认为2
        layers_per_block: int = 2,
        # 下采样时的填充大小,默认为1
        downsample_padding: int = 1,
        # 中间块的缩放因子,默认为1
        mid_block_scale_factor: float = 1,
        # 激活函数类型,默认为"silu"
        act_fn: str = "silu",
        # 归一化的组数,默认为32
        norm_num_groups: Optional[int] = 32,
        # 归一化的epsilon值,默认为1e-5
        norm_eps: float = 1e-5,
        # 交叉注意力的维度,默认为768
        cross_attention_dim: int = 768,
        # 每个块的变换器层数,默认为1
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        # 每个中间块的变换器层数,默认为None
        transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
        # 每个块的时间变换器层数,默认为1
        temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        # 注意力头的维度,默认为8
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        # 注意力头的数量,默认为None
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        # 是否使用线性投影,默认为False
        use_linear_projection: bool = False,
        # 是否提升注意力计算精度,默认为False
        upcast_attention: bool = False,
        # ResNet时间尺度偏移,默认为"default"
        resnet_time_scale_shift: str = "default",
        # 条件嵌入的输出通道数,默认为指定的四个值
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        # 是否全局池条件,默认为False
        global_pool_conditions: bool = False,
        # 控制网络条件通道的顺序,默认为"rgb"
        controlnet_conditioning_channel_order: str = "rgb",
        # 最大的运动序列长度,默认为32
        motion_max_seq_length: int = 32,
        # 运动部分的注意力头数量,默认为8
        motion_num_attention_heads: int = 8,
        # 是否拼接条件掩码,默认为True
        concat_conditioning_mask: bool = True,
        # 是否使用简化的条件嵌入,默认为True
        use_simplified_condition_embedding: bool = True,
    # 定义一个类方法,用于从UNet模型创建稀疏控制网络模型
    @classmethod
    def from_unet(
        cls,
        # 输入的UNet模型
        unet: UNet2DConditionModel,
        # 控制网络条件通道的顺序,默认为"rgb"
        controlnet_conditioning_channel_order: str = "rgb",
        # 条件嵌入的输出通道数,默认为指定的四个值
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        # 是否从UNet加载权重,默认为True
        load_weights_from_unet: bool = True,
        # 条件通道数,默认为3
        conditioning_channels: int = 3,
    # 实例化一个 [`SparseControlNetModel`],来源于 [`UNet2DConditionModel`]。
    ) -> "SparseControlNetModel":
        r"""
        实例化一个 [`SparseControlNetModel`],源自 [`UNet2DConditionModel`]。
        
        参数:
            unet (`UNet2DConditionModel`):
                需要复制到 [`SparseControlNetModel`] 的 UNet 模型权重。所有适用的配置选项也会被复制。
        """
        # 获取 UNet 配置中的 transformer_layers_per_block,默认值为 1
        transformer_layers_per_block = (
            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
        )
        # 获取 UNet 配置中的 down_block_types
        down_block_types = unet.config.down_block_types
    
        # 遍历每种下采样块类型
        for i in range(len(down_block_types)):
            # 检查下采样块类型是否包含 "CrossAttn"
            if "CrossAttn" in down_block_types[i]:
                # 替换为 "CrossAttnDownBlockMotion"
                down_block_types[i] = "CrossAttnDownBlockMotion"
            # 检查下采样块类型是否包含 "Down"
            elif "Down" in down_block_types[i]:
                # 替换为 "DownBlockMotion"
                down_block_types[i] = "DownBlockMotion"
            # 如果类型无效,抛出异常
            else:
                raise ValueError("Invalid `block_type` encountered. Must be a cross-attention or down block")
    
        # 创建 SparseControlNetModel 实例
        controlnet = cls(
            in_channels=unet.config.in_channels,  # 输入通道数
            conditioning_channels=conditioning_channels,  # 条件通道数
            flip_sin_to_cos=unet.config.flip_sin_to_cos,  # 是否翻转正弦到余弦
            freq_shift=unet.config.freq_shift,  # 频率偏移
            down_block_types=unet.config.down_block_types,  # 下采样块类型
            only_cross_attention=unet.config.only_cross_attention,  # 仅使用交叉注意力
            block_out_channels=unet.config.block_out_channels,  # 块输出通道数
            layers_per_block=unet.config.layers_per_block,  # 每个块的层数
            downsample_padding=unet.config.downsample_padding,  # 下采样填充
            mid_block_scale_factor=unet.config.mid_block_scale_factor,  # 中间块缩放因子
            act_fn=unet.config.act_fn,  # 激活函数
            norm_num_groups=unet.config.norm_num_groups,  # 归一化组数
            norm_eps=unet.config.norm_eps,  # 归一化的 epsilon
            cross_attention_dim=unet.config.cross_attention_dim,  # 交叉注意力维度
            transformer_layers_per_block=transformer_layers_per_block,  # 每个块的 transformer 层数
            attention_head_dim=unet.config.attention_head_dim,  # 注意力头维度
            num_attention_heads=unet.config.num_attention_heads,  # 注意力头数量
            use_linear_projection=unet.config.use_linear_projection,  # 是否使用线性投影
            upcast_attention=unet.config.upcast_attention,  # 是否上升注意力
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,  # ResNet 时间缩放偏移
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,  # 条件嵌入输出通道
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,  # 控制网条件通道顺序
        )
    
        # 如果需要从 UNet 加载权重
        if load_weights_from_unet:
            # 加载输入卷积层的权重
            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict(), strict=False)
            # 加载时间投影层的权重
            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict(), strict=False)
            # 加载时间嵌入层的权重
            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict(), strict=False)
            # 加载下采样块的权重
            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
            # 加载中间块的权重
            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
    
        # 返回控制网模型实例
        return controlnet
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回:
            `dict` 的注意力处理器:一个字典,包含模型中使用的所有注意力处理器,
            按其权重名称索引。
        """
        # 初始化一个空字典,用于存储处理器
        processors = {}
    
        # 定义一个递归函数,用于添加处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否具有获取处理器的函数
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,使用名称作为键
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用该函数以处理子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            return processors
    
        # 遍历当前对象的子模块
        for name, module in self.named_children():
            # 调用递归函数以添加所有处理器
            fn_recursive_add_processors(name, module, processors)
    
        # 返回所有处理器的字典
        return processors
    
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制的
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
    
        参数:
            processor (`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或处理器类的字典,将作为 **所有** `Attention` 层的处理器设置。
    
                如果 `processor` 是字典,键需要定义相应的交叉注意力处理器的路径。
                在设置可训练的注意力处理器时,强烈推荐这样做。
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 如果传入的是字典,检查其长度是否与注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传递了一个处理器字典,但处理器数量 {len(processor)} 与"
                f" 注意力层数量 {count} 不匹配。请确保传递 {count} 个处理器类。"
            )
    
        # 定义一个递归函数,用于设置处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否具有设置处理器的函数
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用该函数以处理子模块
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的子模块
        for name, module in self.named_children():
            # 调用递归函数以设置所有处理器
            fn_recursive_attn_processor(name, module, processor)
    
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 复制的
    # 设置默认的注意力处理器
    def set_default_attn_processor(self):
        # 禁用自定义注意力处理器,并设置默认的注意力实现
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        # 检查所有注意力处理器是否属于添加的键值注意力处理器
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建添加的键值注意力处理器实例
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否属于交叉注意力处理器
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建标准注意力处理器实例
            processor = AttnProcessor()
        else:
            # 如果注意力处理器类型不符合要求,则抛出错误
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )
    
        # 设置所选的注意力处理器
        self.set_attn_processor(processor)
    
    # 从 diffusers.models.unets.unet_2d_condition 中复制的设置梯度检查点的方法
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 检查模块是否属于指定的类型
        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
            # 设置模块的梯度检查点属性
            module.gradient_checkpointing = value
    
    # 前向传播方法
    def forward(
        self,
        # 输入的样本张量
        sample: torch.Tensor,
        # 时间步长,可以是张量、浮点数或整数
        timestep: Union[torch.Tensor, float, int],
        # 编码器的隐藏状态张量
        encoder_hidden_states: torch.Tensor,
        # 控制网络条件张量
        controlnet_cond: torch.Tensor,
        # 条件缩放因子,默认为 1.0
        conditioning_scale: float = 1.0,
        # 可选的时间步条件张量
        timestep_cond: Optional[torch.Tensor] = None,
        # 可选的注意力掩码张量
        attention_mask: Optional[torch.Tensor] = None,
        # 可选的交叉注意力参数字典
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 可选的条件掩码张量
        conditioning_mask: Optional[torch.Tensor] = None,
        # 猜测模式,默认为 False
        guess_mode: bool = False,
        # 返回字典,默认为 True
        return_dict: bool = True,
# 从 diffusers.models.controlnet.zero_module 复制而来
def zero_module(module: nn.Module) -> nn.Module:
    # 遍历传入模块的所有参数
    for p in module.parameters():
        # 将每个参数初始化为零
        nn.init.zeros_(p)
    # 返回已初始化的模块
    return module

.\diffusers\models\controlnet_xs.py

# 版权信息,声明该文件归 HuggingFace 团队所有,所有权利保留
# 
# 根据 Apache 许可证第 2.0 版("许可证")进行授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,否则根据许可证分发的软件是按“原样”基础进行分发,
# 不提供任何形式的担保或条件,无论是明示或暗示的。
# 请参阅许可证以获取有关权限和限制的特定信息。
from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from math import gcd  # 从 math 模块导入 gcd 函数,用于计算最大公约数
from typing import Any, Dict, List, Optional, Tuple, Union  # 导入类型提示相关的类型

import torch  # 导入 PyTorch 库
import torch.utils.checkpoint  # 导入 PyTorch 的 checkpoint 工具,用于保存内存
from torch import Tensor, nn  # 从 torch 模块导入 Tensor 类和 nn 模块

from ..configuration_utils import ConfigMixin, register_to_config  # 从上层模块导入配置相关的类和函数
from ..utils import BaseOutput, is_torch_version, logging  # 从上层模块导入工具类和函数
from ..utils.torch_utils import apply_freeu  # 从上层模块导入特定的 PyTorch 工具函数
from .attention_processor import (  # 从当前包导入注意力处理器相关的类
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from .controlnet import ControlNetConditioningEmbedding  # 从当前包导入 ControlNet 的条件嵌入类
from .embeddings import TimestepEmbedding, Timesteps  # 从当前包导入时间步嵌入相关的类
from .modeling_utils import ModelMixin  # 从当前包导入模型混合类
from .unets.unet_2d_blocks import (  # 从当前包导入 2D U-Net 模块相关的类
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    Downsample2D,
    ResnetBlock2D,
    Transformer2DModel,
    UNetMidBlock2DCrossAttn,
    Upsample2D,
)
from .unets.unet_2d_condition import UNet2DConditionModel  # 从当前包导入带条件的 2D U-Net 模型类


logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器


@dataclass  # 将该类声明为数据类
class ControlNetXSOutput(BaseOutput):  # 定义 ControlNetXSOutput 类,继承自 BaseOutput
    """
    [`UNetControlNetXSModel`] 的输出。

    参数:
        sample (`Tensor`,形状为 `(batch_size, num_channels, height, width)`):
            `UNetControlNetXSModel` 的输出。与 `ControlNetOutput` 不同,此输出不是要与基础模型输出相加,而是已经是最终输出。
    """

    sample: Tensor = None  # 定义一个可选的 Tensor 属性 sample,默认为 None


class DownBlockControlNetXSAdapter(nn.Module):  # 定义 DownBlockControlNetXSAdapter 类,继承自 nn.Module
    """与基础模型的对应组件一起形成 `ControlNetXSCrossAttnDownBlock2D` 的组件"""

    def __init__(  # 定义初始化方法
        self,
        resnets: nn.ModuleList,  # 传入一个 ResNet 组件的模块列表
        base_to_ctrl: nn.ModuleList,  # 传入基础模型到 ControlNet 的模块列表
        ctrl_to_base: nn.ModuleList,  # 传入 ControlNet 到基础模型的模块列表
        attentions: Optional[nn.ModuleList] = None,  # 可选的注意力模块列表,默认为 None
        downsampler: Optional[nn.Conv2d] = None,  # 可选的下采样模块,默认为 None
    ):
        super().__init__()  # 调用父类的初始化方法
        self.resnets = resnets  # 保存 ResNet 组件列表
        self.base_to_ctrl = base_to_ctrl  # 保存基础模型到 ControlNet 的模块列表
        self.ctrl_to_base = ctrl_to_base  # 保存 ControlNet 到基础模型的模块列表
        self.attentions = attentions  # 保存注意力模块列表
        self.downsamplers = downsampler  # 保存下采样模块


class MidBlockControlNetXSAdapter(nn.Module):  # 定义 MidBlockControlNetXSAdapter 类,继承自 nn.Module
    """与基础模型的对应组件一起形成 `ControlNetXSCrossAttnMidBlock2D` 的组件"""
    # 初始化类的构造函数
        def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
            # 调用父类的构造函数
            super().__init__()
            # 将传入的 midblock 参数赋值给实例变量 midblock
            self.midblock = midblock
            # 将传入的 base_to_ctrl 参数赋值给实例变量 base_to_ctrl
            self.base_to_ctrl = base_to_ctrl
            # 将传入的 ctrl_to_base 参数赋值给实例变量 ctrl_to_base
            self.ctrl_to_base = ctrl_to_base
# 定义一个名为 UpBlockControlNetXSAdapter 的类,继承自 nn.Module
class UpBlockControlNetXSAdapter(nn.Module):
    """与基础模型的相应组件一起组成 `ControlNetXSCrossAttnUpBlock2D`"""

    # 初始化方法,接受一个控制到基础的模块列表
    def __init__(self, ctrl_to_base: nn.ModuleList):
        super().__init__()  # 调用父类的初始化方法
        self.ctrl_to_base = ctrl_to_base  # 将传入的控制到基础模块列表保存为实例变量


# 定义一个函数,获取下行块适配器
def get_down_block_adapter(
    base_in_channels: int,  # 基础输入通道数
    base_out_channels: int,  # 基础输出通道数
    ctrl_in_channels: int,  # 控制输入通道数
    ctrl_out_channels: int,  # 控制输出通道数
    temb_channels: int,  # 时间嵌入通道数
    max_norm_num_groups: Optional[int] = 32,  # 最大归一化组数
    has_crossattn=True,  # 是否使用交叉注意力
    transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,  # 每个块的变换器层数
    num_attention_heads: Optional[int] = 1,  # 注意力头数量
    cross_attention_dim: Optional[int] = 1024,  # 交叉注意力维度
    add_downsample: bool = True,  # 是否添加下采样
    upcast_attention: Optional[bool] = False,  # 是否上调注意力
    use_linear_projection: Optional[bool] = True,  # 是否使用线性投影
):
    num_layers = 2  # 仅支持 sd + sdxl

    resnets = []  # 存储 ResNet 块的列表
    attentions = []  # 存储注意力模型的列表
    ctrl_to_base = []  # 存储控制到基础的卷积层列表
    base_to_ctrl = []  # 存储基础到控制的卷积层列表

    # 如果传入的是整数,则将其转换为与层数相同的列表
    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * num_layers

    # 遍历每层以构建网络结构
    for i in range(num_layers):
        # 第一层使用基础输入通道数,后续层使用基础输出通道数
        base_in_channels = base_in_channels if i == 0 else base_out_channels
        # 第一层使用控制输入通道数,后续层使用控制输出通道数
        ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels

        # 在应用 ResNet/注意力之前,从基础到控制的通道信息进行连接
        # 连接不需要更改通道数量
        base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))

        resnets.append(
            ResnetBlock2D(
                in_channels=ctrl_in_channels + base_in_channels,  # 从基础连接到控制的信息
                out_channels=ctrl_out_channels,  # 控制输出通道数
                temb_channels=temb_channels,  # 时间嵌入通道数
                groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups),  # 计算组数
                groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),  # 计算输出组数
                eps=1e-5,  # 小常数以避免除零
            )
        )

        # 如果需要交叉注意力,则添加对应的模型
        if has_crossattn:
            attentions.append(
                Transformer2DModel(
                    num_attention_heads,  # 注意力头数量
                    ctrl_out_channels // num_attention_heads,  # 每个头的通道数
                    in_channels=ctrl_out_channels,  # 输入通道数
                    num_layers=transformer_layers_per_block[i],  # 当前块的变换器层数
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力维度
                    use_linear_projection=use_linear_projection,  # 是否使用线性投影
                    upcast_attention=upcast_attention,  # 是否上调注意力
                    norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),  # 计算归一化组数
                )
            )

        # 在应用 ResNet/注意力之后,从控制到基础的通道信息进行相加
        # 相加需要更改通道数量
        ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))  # 添加控制到基础的卷积层
    # 判断是否需要进行下采样
    if add_downsample:
        # 在应用下采样器之前,将 base 的信息与 control 的信息连接
        # 连接操作不需要改变通道数量
        base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))

        # 创建下采样器对象,输入通道为控制通道和基础通道之和,使用卷积,输出通道为控制通道数量,命名为 "op"
        downsamplers = Downsample2D(
            ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
        )

        # 在应用下采样器之后,将控制的数据信息添加到基础数据中
        # 添加操作需要改变通道数量
        ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
    else:
        # 如果不需要下采样,则将 downsamplers 设置为 None
        downsamplers = None

    # 创建下块控制网络适配器,传入残差网络和连接的控制基础模块
    down_block_components = DownBlockControlNetXSAdapter(
        resnets=nn.ModuleList(resnets),
        base_to_ctrl=nn.ModuleList(base_to_ctrl),
        ctrl_to_base=nn.ModuleList(ctrl_to_base),
    )

    # 如果存在交叉注意力,则将注意力模块添加到下块组件中
    if has_crossattn:
        down_block_components.attentions = nn.ModuleList(attentions)
    # 如果下采样器不为 None,则将下采样器添加到下块组件中
    if downsamplers is not None:
        down_block_components.downsamplers = downsamplers

    # 返回下块组件
    return down_block_components
# 定义一个函数,用于获取中间块适配器,接受多个参数以配置其行为
def get_mid_block_adapter(
    # 基础通道数
    base_channels: int,
    # 控制通道数
    ctrl_channels: int,
    # 可选的时间嵌入通道数
    temb_channels: Optional[int] = None,
    # 最大归一化组数量,默认为32
    max_norm_num_groups: Optional[int] = 32,
    # 每个块的变换层数,默认为1
    transformer_layers_per_block: int = 1,
    # 可选的注意力头数量,默认为1
    num_attention_heads: Optional[int] = 1,
    # 可选的交叉注意力维度,默认为1024
    cross_attention_dim: Optional[int] = 1024,
    # 是否提升注意力精度,默认为False
    upcast_attention: bool = False,
    # 是否使用线性投影,默认为True
    use_linear_projection: bool = True,
):
    # 在中间块应用之前,从基础通道到控制通道的信息进行拼接
    # 拼接不需要改变通道数
    base_to_ctrl = make_zero_conv(base_channels, base_channels)

    # 创建一个中间块对象,使用交叉注意力
    midblock = UNetMidBlock2DCrossAttn(
        # 设置每个块的变换层数
        transformer_layers_per_block=transformer_layers_per_block,
        # 输入通道为控制通道和基础通道的和
        in_channels=ctrl_channels + base_channels,
        # 输出通道为控制通道数
        out_channels=ctrl_channels,
        # 时间嵌入通道数
        temb_channels=temb_channels,
        # 归一化组数量必须能够同时整除输入和输出通道数
        resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
        # 交叉注意力的维度
        cross_attention_dim=cross_attention_dim,
        # 注意力头的数量
        num_attention_heads=num_attention_heads,
        # 是否使用线性投影
        use_linear_projection=use_linear_projection,
        # 是否提升注意力精度
        upcast_attention=upcast_attention,
    )

    # 在中间块应用之后,从控制通道到基础通道的信息进行相加
    # 相加需要改变通道数
    ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)

    # 返回一个中间块控制适配器的实例,包含拼接层、中间块和相加层
    return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base)


# 定义一个函数,用于获取上块适配器,接受输出通道数、前一层输出通道数和控制跳跃通道
def get_up_block_adapter(
    # 输出通道数
    out_channels: int,
    # 前一层的输出通道数
    prev_output_channel: int,
    # 控制跳跃通道列表
    ctrl_skip_channels: List[int],
):
    # 初始化控制到基础的卷积层列表
    ctrl_to_base = []
    # 设置层数为3,仅支持 sd 和 sdxl
    num_layers = 3  
    # 循环构建每一层的控制到基础卷积层
    for i in range(num_layers):
        # 第一层使用前一层输出通道,其他层使用输出通道
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        # 将控制跳跃通道与当前输入通道连接
        ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))

    # 返回一个上块控制适配器的实例,使用nn.ModuleList管理控制到基础卷积层
    return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base))


# 定义一个控制网络适配器类,继承自ModelMixin和ConfigMixin
class ControlNetXSAdapter(ModelMixin, ConfigMixin):
    r"""
    控制网络适配器模型。使用时,将其传递给 `UNetControlNetXSModel`(以及一个
    `UNet2DConditionModel` 基础模型)。

    该模型继承自[`ModelMixin`]和[`ConfigMixin`]。请查看超类文档,了解其通用
    方法(例如下载或保存)。

    与`UNetControlNetXSModel`一样,`ControlNetXSAdapter`与StableDiffusion和StableDiffusion-XL兼容。其
    默认参数与StableDiffusion兼容。
    # 参数部分说明
    Parameters:
        # conditioning_channels: 条件输入的通道数(例如:一张图像),默认值为3
        conditioning_channels (`int`, defaults to 3):
            # 条件图像的通道顺序。若为 `bgr`,则转换为 `rgb`
            conditioning_channel_order (`str`, defaults to `"rgb"`):
            # `controlnet_cond_embedding` 层中每个块的输出通道的元组,默认值为 (16, 32, 96, 256)
            conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
            # time_embedding_mix: 如果为0,则仅使用控制适配器的时间嵌入;如果为1,则仅使用基础 UNet 的时间嵌入;否则,两者结合
            time_embedding_mix (`float`, defaults to 1.0):
            # learn_time_embedding: 是否应学习时间嵌入,若是则 `UNetControlNetXSModel` 会结合基础模型和控制适配器的时间嵌入,若否则只使用基础模型的时间嵌入
            learn_time_embedding (`bool`, defaults to `False`):
            # num_attention_heads: 注意力头的数量,默认值为 [4]
            num_attention_heads (`list[int]`, defaults to `[4]`):
            # block_out_channels: 每个块的输出通道的元组,默认值为 [4, 8, 16, 16]
            block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`):
            # base_block_out_channels: 基础 UNet 中每个块的输出通道的元组,默认值为 [320, 640, 1280, 1280]
            base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`):
            # cross_attention_dim: 跨注意力特征的维度,默认值为 1024
            cross_attention_dim (`int`, defaults to 1024):
            # down_block_types: 要使用的下采样块的元组,默认值为 ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
            down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`):
            # sample_size: 输入/输出样本的高度和宽度,默认值为 96
            sample_size (`int`, defaults to 96):
            # transformer_layers_per_block: 每个块的变换器块数量,默认值为 1,仅与某些块相关
            transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
            # upcast_attention: 是否应始终提升注意力计算的精度,默认值为 True
            upcast_attention (`bool`, defaults to `True`):
            # max_norm_num_groups: 分组归一化中的最大组数,默认值为 32,实际数量为不大于 max_norm_num_groups 的相应通道的最大除数
            max_norm_num_groups (`int`, defaults to 32):
    # 注释部分结束
    """
    
    # 注册到配置中
    @register_to_config
    # 初始化方法,设置 ControlNetXSAdapter 的基本参数
        def __init__(
            # 条件通道数,默认为 3
            self,
            conditioning_channels: int = 3,
            # 条件通道的颜色顺序,默认为 RGB
            conditioning_channel_order: str = "rgb",
            # 输出通道数的元组,定义各层的输出通道
            conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
            # 时间嵌入混合因子,默认为 1.0
            time_embedding_mix: float = 1.0,
            # 是否学习时间嵌入,默认为 False
            learn_time_embedding: bool = False,
            # 注意力头数,默认为 4,可以是整数或整数元组
            num_attention_heads: Union[int, Tuple[int]] = 4,
            # 块输出通道的元组,定义每个块的输出通道
            block_out_channels: Tuple[int] = (4, 8, 16, 16),
            # 基础块输出通道的元组
            base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
            # 交叉注意力维度,默认为 1024
            cross_attention_dim: int = 1024,
            # 各层的块类型元组
            down_block_types: Tuple[str] = (
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            # 采样大小,默认为 96
            sample_size: Optional[int] = 96,
            # 每个块的变换器层数,可以是整数或整数元组
            transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # 是否上溢注意力,默认为 True
            upcast_attention: bool = True,
            # 最大归一化组数,默认为 32
            max_norm_num_groups: int = 32,
            # 是否使用线性投影,默认为 True
            use_linear_projection: bool = True,
        # 类方法,从 UNet 创建 ControlNetXSAdapter
        @classmethod
        def from_unet(
            cls,
            # 传入的 UNet2DConditionModel 对象
            unet: UNet2DConditionModel,
            # 尺寸比例,默认为 None
            size_ratio: Optional[float] = None,
            # 可选的块输出通道列表
            block_out_channels: Optional[List[int]] = None,
            # 可选的注意力头数列表
            num_attention_heads: Optional[List[int]] = None,
            # 是否学习时间嵌入,默认为 False
            learn_time_embedding: bool = False,
            # 时间嵌入混合因子,默认为 1.0
            time_embedding_mix: int = 1.0,
            # 条件通道数,默认为 3
            conditioning_channels: int = 3,
            # 条件通道的颜色顺序,默认为 RGB
            conditioning_channel_order: str = "rgb",
            # 输出通道数的元组,默认为 (16, 32, 96, 256)
            conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
        # 前向传播方法,处理输入参数
        def forward(self, *args, **kwargs):
            # 抛出错误,指示不能单独运行 ControlNetXSAdapter
            raise ValueError(
                "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
            )
# 定义一个 UNet 融合 ControlNet-XS 适配器的模型类
class UNetControlNetXSModel(ModelMixin, ConfigMixin):
    r"""
    A UNet fused with a ControlNet-XS adapter model

    此模型继承自 [`ModelMixin`] 和 [`ConfigMixin`]。有关所有模型实现的通用方法(如下载或保存),请检查超类文档。

    `UNetControlNetXSModel` 与 StableDiffusion 和 StableDiffusion-XL 兼容。其默认参数与 StableDiffusion 兼容。

    它的参数要么传递给底层的 `UNet2DConditionModel`,要么与 `ControlNetXSAdapter` 完全相同。有关详细信息,请参阅它们的文档。
    """

    # 启用梯度检查点支持
    _supports_gradient_checkpointing = True

    # 注册到配置的方法
    @register_to_config
    def __init__(
        self,
        # unet 配置
        # 样本尺寸,默认值为 96
        sample_size: Optional[int] = 96,
        # 下采样块类型的元组
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        # 上采样块类型的元组
        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        # 每个块的输出通道数
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        # 归一化的组数,默认为 32
        norm_num_groups: Optional[int] = 32,
        # 交叉注意力维度,默认为 1024
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        # 每个块的变换器层数,默认为 1
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # 注意力头的数量,默认为 8
        num_attention_heads: Union[int, Tuple[int]] = 8,
        # 附加嵌入类型,默认为 None
        addition_embed_type: Optional[str] = None,
        # 附加时间嵌入维度,默认为 None
        addition_time_embed_dim: Optional[int] = None,
        # 是否上溯注意力,默认为 True
        upcast_attention: bool = True,
        # 是否使用线性投影,默认为 True
        use_linear_projection: bool = True,
        # 时间条件投影维度,默认为 None
        time_cond_proj_dim: Optional[int] = None,
        # 类别嵌入输入维度,默认为 None
        projection_class_embeddings_input_dim: Optional[int] = None,
        # 附加控制网配置
        # 时间嵌入混合系数,默认为 1.0
        time_embedding_mix: float = 1.0,
        # 控制条件通道数,默认为 3
        ctrl_conditioning_channels: int = 3,
        # 控制条件嵌入输出通道的元组
        ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
        # 控制条件通道顺序,默认为 "rgb"
        ctrl_conditioning_channel_order: str = "rgb",
        # 是否学习时间嵌入,默认为 False
        ctrl_learn_time_embedding: bool = False,
        # 控制块输出通道的元组
        ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
        # 控制注意力头的数量,默认为 4
        ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
        # 控制最大归一化组数,默认为 32
        ctrl_max_norm_num_groups: int = 32,
    # 定义类方法,从 UNet 创建模型
    @classmethod
    def from_unet(
        cls,
        # UNet2DConditionModel 实例
        unet: UNet2DConditionModel,
        # 可选的 ControlNetXSAdapter 实例
        controlnet: Optional[ControlNetXSAdapter] = None,
        # 可选的大小比例
        size_ratio: Optional[float] = None,
        # 可选的控制块输出通道列表
        ctrl_block_out_channels: Optional[List[float]] = None,
        # 可选的时间嵌入混合系数
        time_embedding_mix: Optional[float] = None,
        # 可选的控制额外参数字典
        ctrl_optional_kwargs: Optional[Dict] = None,
    # 冻结 UNet2DConditionModel 基本部分的权重,其他部分可用于微调
    def freeze_unet_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
        tuning."""
        # 将所有参数的梯度计算设置为可用
        for param in self.parameters():
            param.requires_grad = True
    
        # 解冻 ControlNetXSAdapter 相关部分
        base_parts = [
            "base_time_proj",
            "base_time_embedding",
            "base_add_time_proj",
            "base_add_embedding",
            "base_conv_in",
            "base_conv_norm_out",
            "base_conv_act",
            "base_conv_out",
        ]
        # 获取存在的基本部分的属性,过滤掉 None
        base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
        # 冻结基本部分的所有参数
        for part in base_parts:
            for param in part.parameters():
                param.requires_grad = False
    
        # 冻结每个下采样块的基本参数
        for d in self.down_blocks:
            d.freeze_base_params()
        # 冻结中间块的基本参数
        self.mid_block.freeze_base_params()
        # 冻结每个上采样块的基本参数
        for u in self.up_blocks:
            u.freeze_base_params()
    
    # 设置模块的梯度检查点功能
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模块具有梯度检查点属性,则设置其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
    
    @property
    # 从 UNet2DConditionModel 中复制的属性,用于获取注意力处理器
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 用于递归设置处理器的字典
        processors = {}
    
        # 递归添加处理器的辅助函数
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模块具有获取处理器的方法,则将其添加到字典中
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍历模块的所有子模块,递归调用处理器添加函数
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            return processors
    
        # 遍历当前对象的所有子模块,并调用处理器添加函数
        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
    
        # 返回所有处理器的字典
        return processors
    
    # 从 UNet2DConditionModel 中复制的设置注意力处理器的方法
    # 定义设置注意力处理器的方法,参数为单个处理器或处理器字典
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
    
        参数:
            processor(`dict` 或 `AttentionProcessor`): 
                实例化的处理器类或将作为处理器设置的处理器类字典
                对于**所有** `Attention` 层。
    
                如果 `processor` 是字典,键需要定义对应交叉注意力处理器的路径。
                当设置可训练的注意力处理器时,强烈推荐这样做。
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 如果传入的处理器为字典且数量不匹配,则抛出异常
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与"
                f" 注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )
    
        # 定义递归设置注意力处理器的函数
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块具有 set_processor 方法,则设置处理器
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历模块的子模块,递归调用处理器设置函数
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的子模块,调用递归设置函数
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 复制的
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器并设置默认的注意力实现。
        """
        # 如果所有处理器都是添加的 KV 注意力处理器,则设置处理器为 AttnAddedKVProcessor
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        # 如果所有处理器都是交叉注意力处理器,则设置处理器为 AttnProcessor
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            # 否则抛出异常,提示无法设置默认处理器
            raise ValueError(
                f"当注意力处理器类型为 {next(iter(self.attn_processors.values()))} 时,无法调用 `set_default_attn_processor`"
            )
    
        # 调用设置注意力处理器的方法
        self.set_attn_processor(processor)
    
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu 复制的
    # 定义启用 FreeU 机制的方法,接收四个浮点数参数
        def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
            # 文档字符串,描述该方法的用途和参数含义
            r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
    
            The suffixes after the scaling factors represent the stage blocks where they are being applied.
    
            Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
            are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
    
            Args:
                s1 (`float`):
                    Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                    mitigate the "oversmoothing effect" in the enhanced denoising process.
                s2 (`float`):
                    Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                    mitigate the "oversmoothing effect" in the enhanced denoising process.
                b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
                b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
            """
            # 遍历上采样模块并为每个模块设置对应的 scaling 因子
            for i, upsample_block in enumerate(self.up_blocks):
                # 设置上采样块的 s1 属性为传入的 s1 值
                setattr(upsample_block, "s1", s1)
                # 设置上采样块的 s2 属性为传入的 s2 值
                setattr(upsample_block, "s2", s2)
                # 设置上采样块的 b1 属性为传入的 b1 值
                setattr(upsample_block, "b1", b1)
                # 设置上采样块的 b2 属性为传入的 b2 值
                setattr(upsample_block, "b2", b2)
    
        # 定义禁用 FreeU 机制的方法
        def disable_freeu(self):
            """Disables the FreeU mechanism."""
            # 定义 FreeU 机制中需要清除的键集合
            freeu_keys = {"s1", "s2", "b1", "b2"}
            # 遍历上采样模块
            for i, upsample_block in enumerate(self.up_blocks):
                # 遍历需要清除的键
                for k in freeu_keys:
                    # 检查模块是否具有该属性或属性值非 None
                    if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                        # 将属性值设置为 None,禁用 FreeU
                        setattr(upsample_block, k, None)
    
        # 定义融合 QKV 投影的方法
        def fuse_qkv_projections(self):
            """
            Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
            are fused. For cross-attention modules, key and value projection matrices are fused.
    
            <Tip warning={true}>
    
            This API is 🧪 experimental.
    
            </Tip>
            """
            # 初始化原始注意力处理器为 None
            self.original_attn_processors = None
    
            # 遍历注意力处理器
            for _, attn_processor in self.attn_processors.items():
                # 检查是否有添加的 KV 投影,不支持融合
                if "Added" in str(attn_processor.__class__.__name__):
                    raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
    
            # 记录原始注意力处理器
            self.original_attn_processors = self.attn_processors
    
            # 遍历所有模块
            for module in self.modules():
                # 检查模块是否为注意力模块
                if isinstance(module, Attention):
                    # 执行投影融合
                    module.fuse_projections(fuse=True)
    
            # 设置注意力处理器为融合后的处理器
            self.set_attn_processor(FusedAttnProcessor2_0())
    
        # 此部分代码未提供,可能是禁用 QKV 投影的方法
    # 定义一个方法,用于禁用已启用的融合 QKV 投影
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        # 如果原始的注意力处理器不为 None,则设置为原始处理器
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    # 定义前向传播方法,接收多个输入参数
    def forward(
        self,
        sample: Tensor,  # 输入样本,类型为 Tensor
        timestep: Union[torch.Tensor, float, int],  # 时间步,支持多种类型
        encoder_hidden_states: torch.Tensor,  # 编码器的隐藏状态,类型为 Tensor
        controlnet_cond: Optional[torch.Tensor] = None,  # 可选的控制网络条件,类型为 Tensor
        conditioning_scale: Optional[float] = 1.0,  # 条件缩放因子,默认为 1.0
        class_labels: Optional[torch.Tensor] = None,  # 可选的类标签,类型为 Tensor
        timestep_cond: Optional[torch.Tensor] = None,  # 可选的时间步条件,类型为 Tensor
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码,类型为 Tensor
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可选的交叉注意力参数,类型为字典
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,  # 可选的附加条件参数,类型为字典
        return_dict: bool = True,  # 是否返回字典格式的结果,默认为 True
        apply_control: bool = True,  # 是否应用控制逻辑,默认为 True
# 定义一个名为 ControlNetXSCrossAttnDownBlock2D 的类,继承自 nn.Module
class ControlNetXSCrossAttnDownBlock2D(nn.Module):
    # 初始化方法,定义类的属性和参数
    def __init__(
        self,
        base_in_channels: int,  # 基础输入通道数
        base_out_channels: int,  # 基础输出通道数
        ctrl_in_channels: int,  # 控制输入通道数
        ctrl_out_channels: int,  # 控制输出通道数
        temb_channels: int,  # 时间嵌入通道数
        norm_num_groups: int = 32,  # 规范化组数
        ctrl_max_norm_num_groups: int = 32,  # 控制最大规范化组数
        has_crossattn=True,  # 是否包含交叉注意力机制
        transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,  # 每个块的变换器层数
        base_num_attention_heads: Optional[int] = 1,  # 基础注意力头数
        ctrl_num_attention_heads: Optional[int] = 1,  # 控制注意力头数
        cross_attention_dim: Optional[int] = 1024,  # 交叉注意力维度
        add_downsample: bool = True,  # 是否添加下采样
        upcast_attention: Optional[bool] = False,  # 是否上升注意力
        use_linear_projection: Optional[bool] = True,  # 是否使用线性投影
    @classmethod
    # 定义一个类方法,用于冻结基础模型的参数
    def freeze_base_params(self) -> None:
        """冻结基础 UNet2DConditionModel 的权重,保持其他部分可调,以便微调。"""
        # 解冻所有参数
        for param in self.parameters():
            param.requires_grad = True

        # 冻结基础部分的参数
        base_parts = [self.base_resnets]  # 包含基础残差网络部分
        if isinstance(self.base_attentions, nn.ModuleList):  # 如果注意力部分是一个模块列表
            base_parts.append(self.base_attentions)  # 添加基础注意力部分
        if self.base_downsamplers is not None:  # 如果存在基础下采样部分
            base_parts.append(self.base_downsamplers)  # 添加基础下采样部分
        for part in base_parts:  # 遍历基础部分
            for param in part.parameters():  # 遍历参数
                param.requires_grad = False  # 冻结参数以防止更新

    # 定义前向传播方法
    def forward(
        self,
        hidden_states_base: Tensor,  # 基础隐藏状态
        temb: Tensor,  # 时间嵌入
        encoder_hidden_states: Optional[Tensor] = None,  # 编码器隐藏状态
        hidden_states_ctrl: Optional[Tensor] = None,  # 控制隐藏状态
        conditioning_scale: Optional[float] = 1.0,  # 条件缩放因子
        attention_mask: Optional[Tensor] = None,  # 注意力掩码
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 交叉注意力的关键字参数
        encoder_attention_mask: Optional[Tensor] = None,  # 编码器注意力掩码
        apply_control: bool = True,  # 是否应用控制
class ControlNetXSCrossAttnMidBlock2D(nn.Module):
    # 定义一个名为 ControlNetXSCrossAttnMidBlock2D 的类,继承自 nn.Module
    def __init__(
        self,
        base_channels: int,  # 基础通道数
        ctrl_channels: int,  # 控制通道数
        temb_channels: Optional[int] = None,  # 时间嵌入通道数(可选)
        norm_num_groups: int = 32,  # 规范化组数
        ctrl_max_norm_num_groups: int = 32,  # 控制最大规范化组数
        transformer_layers_per_block: int = 1,  # 每个块的变换器层数
        base_num_attention_heads: Optional[int] = 1,  # 基础注意力头数
        ctrl_num_attention_heads: Optional[int] = 1,  # 控制注意力头数
        cross_attention_dim: Optional[int] = 1024,  # 交叉注意力维度
        upcast_attention: bool = False,  # 是否上升注意力
        use_linear_projection: Optional[bool] = True,  # 是否使用线性投影
    ):
        # 调用父类的构造函数以初始化继承的属性和方法
        super().__init__()

        # 在中间块应用之前,从基础信息到控制信息的连接。
        # 连接不需要改变通道数量
        self.base_to_ctrl = make_zero_conv(base_channels, base_channels)

        # 创建基础中间块,使用交叉注意力机制
        self.base_midblock = UNetMidBlock2DCrossAttn(
            # 每个块中的变换器层数量
            transformer_layers_per_block=transformer_layers_per_block,
            # 输入通道数为基础通道数
            in_channels=base_channels,
            # 嵌入通道数
            temb_channels=temb_channels,
            # ResNet 组的数量
            resnet_groups=norm_num_groups,
            # 交叉注意力维度
            cross_attention_dim=cross_attention_dim,
            # 注意力头的数量
            num_attention_heads=base_num_attention_heads,
            # 是否使用线性投影
            use_linear_projection=use_linear_projection,
            # 是否上溯注意力
            upcast_attention=upcast_attention,
        )

        # 创建控制中间块,使用交叉注意力机制
        self.ctrl_midblock = UNetMidBlock2DCrossAttn(
            # 每个块中的变换器层数量
            transformer_layers_per_block=transformer_layers_per_block,
            # 输入通道数为控制通道数加基础通道数
            in_channels=ctrl_channels + base_channels,
            # 输出通道数为控制通道数
            out_channels=ctrl_channels,
            # 嵌入通道数
            temb_channels=temb_channels,
            # norm 组数量必须同时能被输入和输出通道数整除
            resnet_groups=find_largest_factor(
                # 计算控制通道与控制通道加基础通道的最大公约数
                gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
            ),
            # 交叉注意力维度
            cross_attention_dim=cross_attention_dim,
            # 注意力头的数量
            num_attention_heads=ctrl_num_attention_heads,
            # 是否使用线性投影
            use_linear_projection=use_linear_projection,
            # 是否上溯注意力
            upcast_attention=upcast_attention,
        )

        # 在中间块应用之后,从控制信息到基础信息的相加
        # 相加需要改变通道数量
        self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)

        # 初始化梯度检查点标志为假
        self.gradient_checkpointing = False

    @classmethod
    def from_modules(
        # 类方法,接受基础中间块和控制中间块作为参数
        cls,
        base_midblock: UNetMidBlock2DCrossAttn,
        ctrl_midblock: MidBlockControlNetXSAdapter,
    ):
        # 获取中间块的基准到控制的映射
        base_to_ctrl = ctrl_midblock.base_to_ctrl
        # 获取中间块的控制到基准的映射
        ctrl_to_base = ctrl_midblock.ctrl_to_base
        # 获取中间块的实例
        ctrl_midblock = ctrl_midblock.midblock

        # 获取第一个交叉注意力模块
        def get_first_cross_attention(midblock):
            # 返回中间块的第一个注意力模块的交叉注意力层
            return midblock.attentions[0].transformer_blocks[0].attn2

        # 获取控制到基准的输出通道数
        base_channels = ctrl_to_base.out_channels
        # 获取控制到基准的输入通道数
        ctrl_channels = ctrl_to_base.in_channels
        # 获取基准中间块的每个块的转换层数
        transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
        # 获取基准中间块时间嵌入的输入特征数
        temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
        # 获取基准中间块的归一化组数
        num_groups = base_midblock.resnets[0].norm1.num_groups
        # 获取控制中间块的归一化组数
        ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
        # 获取基准中间块第一个交叉注意力模块的注意力头数
        base_num_attention_heads = get_first_cross_attention(base_midblock).heads
        # 获取控制中间块第一个交叉注意力模块的注意力头数
        ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
        # 获取基准中间块第一个交叉注意力模块的交叉注意力维度
        cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
        # 获取基准中间块第一个交叉注意力模块的上采样注意力设置
        upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
        # 获取基准中间块第一个注意力模块的线性投影使用情况
        use_linear_projection = base_midblock.attentions[0].use_linear_projection

        # 创建模型实例
        model = cls(
            # 传入基准通道数
            base_channels=base_channels,
            # 传入控制通道数
            ctrl_channels=ctrl_channels,
            # 传入时间嵌入通道数
            temb_channels=temb_channels,
            # 传入归一化组数
            norm_num_groups=num_groups,
            # 传入控制最大归一化组数
            ctrl_max_norm_num_groups=ctrl_num_groups,
            # 传入每块的转换层数
            transformer_layers_per_block=transformer_layers_per_block,
            # 传入基准注意力头数
            base_num_attention_heads=base_num_attention_heads,
            # 传入控制注意力头数
            ctrl_num_attention_heads=ctrl_num_attention_heads,
            # 传入交叉注意力维度
            cross_attention_dim=cross_attention_dim,
            # 传入上采样注意力设置
            upcast_attention=upcast_attention,
            # 传入线性投影使用情况
            use_linear_projection=use_linear_projection,
        )

        # 加载模型权重
        model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
        # 加载基准中间块的权重
        model.base_midblock.load_state_dict(base_midblock.state_dict())
        # 加载控制中间块的权重
        model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
        # 加载控制到基准的权重
        model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())

        # 返回构建好的模型
        return model

    def freeze_base_params(self) -> None:
        """冻结属于基准 UNet2DConditionModel 的权重,保留其他部分以便进行微调。"""
        # 解冻所有参数
        for param in self.parameters():
            param.requires_grad = True

        # 冻结基准部分的参数
        for param in self.base_midblock.parameters():
            param.requires_grad = False

    def forward(
        self,
        # 基准的隐藏状态
        hidden_states_base: Tensor,
        # 时间嵌入
        temb: Tensor,
        # 编码器的隐藏状态
        encoder_hidden_states: Tensor,
        # 控制的隐藏状态(可选)
        hidden_states_ctrl: Optional[Tensor] = None,
        # 条件缩放因子(可选),默认为1.0
        conditioning_scale: Optional[float] = 1.0,
        # 交叉注意力的额外参数(可选)
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 注意力掩码(可选)
        attention_mask: Optional[Tensor] = None,
        # 编码器的注意力掩码(可选)
        encoder_attention_mask: Optional[Tensor] = None,
        # 是否应用控制(默认为True)
        apply_control: bool = True,
    # 返回一个包含两个张量的元组
    ) -> Tuple[Tensor, Tensor]:
        # 如果提供了交叉注意力的参数
        if cross_attention_kwargs is not None:
            # 检查是否有 scale 参数,并发出警告
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
        # 设置基础隐藏状态
        h_base = hidden_states_base
        # 设置控制隐藏状态
        h_ctrl = hidden_states_ctrl
    
        # 创建一个包含多个参数的字典
        joint_args = {
            "temb": temb,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
            "cross_attention_kwargs": cross_attention_kwargs,
            "encoder_attention_mask": encoder_attention_mask,
        }
    
        # 如果应用控制,则连接基础和控制隐藏状态
        if apply_control:
            h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1)  # concat base -> ctrl
        # 应用基础中间块到基础隐藏状态
        h_base = self.base_midblock(h_base, **joint_args)  # apply base mid block
        # 如果应用控制,则应用控制中间块
        if apply_control:
            h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args)  # apply ctrl mid block
            # 将控制结果加到基础隐藏状态上,乘以条件缩放因子
            h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale  # add ctrl -> base
    
        # 返回基础和控制的隐藏状态
        return h_base, h_ctrl
# 定义一个名为 ControlNetXSCrossAttnUpBlock2D 的神经网络模块,继承自 nn.Module
class ControlNetXSCrossAttnUpBlock2D(nn.Module):
    # 初始化方法,定义该模块的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        out_channels: int,  # 输出通道数
        prev_output_channel: int,  # 前一层的输出通道数
        ctrl_skip_channels: List[int],  # 控制跳跃连接的通道数列表
        temb_channels: int,  # 时间嵌入通道数
        norm_num_groups: int = 32,  # 归一化的组数,默认值为32
        resolution_idx: Optional[int] = None,  # 分辨率索引,可选
        has_crossattn=True,  # 是否包含交叉注意力机制,默认值为True
        transformer_layers_per_block: int = 1,  # 每个模块的变换器层数,默认值为1
        num_attention_heads: int = 1,  # 注意力头的数量,默认值为1
        cross_attention_dim: int = 1024,  # 交叉注意力的维度,默认值为1024
        add_upsample: bool = True,  # 是否添加上采样层,默认值为True
        upcast_attention: bool = False,  # 是否提升注意力计算精度,默认值为False
        use_linear_projection: Optional[bool] = True,  # 是否使用线性投影,默认值为True
    ):
        # 调用父类的初始化方法
        super().__init__()
        resnets = []  # 初始化一个空列表,用于存放 ResNet 模块
        attentions = []  # 初始化一个空列表,用于存放注意力模块
        ctrl_to_base = []  # 初始化一个空列表,用于存放控制到基础的卷积模块

        num_layers = 3  # 仅支持3层,适用于 sd 和 sdxl

        # 记录是否包含交叉注意力和注意力头的数量
        self.has_cross_attention = has_crossattn
        self.num_attention_heads = num_attention_heads

        # 如果 transformer_layers_per_block 是整数,则将其扩展为包含 num_layers 个相同值的列表
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # 遍历每一层
        for i in range(num_layers):
            # 确定当前层的跳跃连接通道数
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 确定当前层的输入通道数
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 创建从控制通道到基础通道的零卷积,并添加到列表中
            ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))

            # 添加 ResNet 模块到 resnets 列表
            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,  # 输入通道数
                    out_channels=out_channels,  # 输出通道数
                    temb_channels=temb_channels,  # 时间嵌入通道数
                    groups=norm_num_groups,  # 归一化组数
                )
            )

            # 如果包含交叉注意力,则添加 Transformer 模块到 attentions 列表
            if has_crossattn:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,  # 注意力头数量
                        out_channels // num_attention_heads,  # 每个头的输出通道数
                        in_channels=out_channels,  # 输入通道数
                        num_layers=transformer_layers_per_block[i],  # 当前层的变换器层数
                        cross_attention_dim=cross_attention_dim,  # 交叉注意力维度
                        use_linear_projection=use_linear_projection,  # 是否使用线性投影
                        upcast_attention=upcast_attention,  # 是否提升注意力计算精度
                        norm_num_groups=norm_num_groups,  # 归一化组数
                    )
                )

        # 将 ResNet 模块列表转换为 nn.ModuleList,以便在模型中管理
        self.resnets = nn.ModuleList(resnets)
        # 如果有交叉注意力,转换 attentions 列表为 nn.ModuleList,否则填充 None
        self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
        # 将控制到基础的卷积模块列表转换为 nn.ModuleList
        self.ctrl_to_base = nn.ModuleList(ctrl_to_base)

        # 如果需要添加上采样层,初始化 Upsample2D 模块
        if add_upsample:
            self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
        else:
            self.upsamplers = None  # 如果不需要上采样,则将其设置为 None

        self.gradient_checkpointing = False  # 初始化时禁用梯度检查点
        self.resolution_idx = resolution_idx  # 设置分辨率索引
    # 从模块创建模型的类方法
        def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
            # 获取控制到基础的跳跃连接
            ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
    
            # 获取参数
            # 获取第一个交叉注意力模块
            def get_first_cross_attention(block):
                return block.attentions[0].transformer_blocks[0].attn2
    
            # 获取基础上采样块的输出通道数
            out_channels = base_upblock.resnets[0].out_channels
            # 计算输入通道数
            in_channels = base_upblock.resnets[-1].in_channels - out_channels
            # 计算前一个输出通道数
            prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
            # 获取控制跳跃连接的输入通道数
            ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections]
            # 获取时间嵌入的输入特征数
            temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
            # 获取归一化组数
            num_groups = base_upblock.resnets[0].norm1.num_groups
            # 获取分辨率索引
            resolution_idx = base_upblock.resolution_idx
            # 检查基础上采样块是否有注意力模块
            if hasattr(base_upblock, "attentions"):
                has_crossattn = True
                # 获取每个块的变换层数
                transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
                # 获取注意力头数
                num_attention_heads = get_first_cross_attention(base_upblock).heads
                # 获取交叉注意力维度
                cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
                # 获取上升注意力标志
                upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
                # 获取是否使用线性投影
                use_linear_projection = base_upblock.attentions[0].use_linear_projection
            else:
                has_crossattn = False
                transformer_layers_per_block = None
                num_attention_heads = None
                cross_attention_dim = None
                upcast_attention = None
                use_linear_projection = None
            # 检查是否需要添加上采样
            add_upsample = base_upblock.upsamplers is not None
    
            # 创建模型
            model = cls(
                # 输入通道数
                in_channels=in_channels,
                # 输出通道数
                out_channels=out_channels,
                # 前一个输出通道
                prev_output_channel=prev_output_channels,
                # 控制跳跃连接的输入通道数
                ctrl_skip_channels=ctrl_skip_channelss,
                # 时间嵌入的通道数
                temb_channels=temb_channels,
                # 归一化的组数
                norm_num_groups=num_groups,
                # 分辨率索引
                resolution_idx=resolution_idx,
                # 是否有交叉注意力
                has_crossattn=has_crossattn,
                # 每个块的变换层数
                transformer_layers_per_block=transformer_layers_per_block,
                # 注意力头数
                num_attention_heads=num_attention_heads,
                # 交叉注意力维度
                cross_attention_dim=cross_attention_dim,
                # 是否添加上采样
                add_upsample=add_upsample,
                # 上升注意力标志
                upcast_attention=upcast_attention,
                # 是否使用线性投影
                use_linear_projection=use_linear_projection,
            )
    
            # 加载权重
            model.resnets.load_state_dict(base_upblock.resnets.state_dict())
            # 如果有交叉注意力,加载其权重
            if has_crossattn:
                model.attentions.load_state_dict(base_upblock.attentions.state_dict())
            # 如果需要添加上采样,加载其权重
            if add_upsample:
                model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
            # 加载控制到基础的跳跃连接权重
            model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())
    
            # 返回创建的模型
            return model
    # 定义一个方法,用于冻结基础 UNet2DConditionModel 的参数
    def freeze_base_params(self) -> None:
        """冻结属于基础 UNet2DConditionModel 的权重,其他部分保持解冻以便微调。"""
        # 解冻所有参数,允许训练
        for param in self.parameters():
            param.requires_grad = True
    
        # 冻结基础部分的参数
        base_parts = [self.resnets]  # 将基础部分(resnets)添加到列表中
        # 检查 attentions 是否是 ModuleList 类型(可能包含 None)
        if isinstance(self.attentions, nn.ModuleList):
            base_parts.append(self.attentions)  # 如果是,则添加 attentions
        # 检查 upsamplers 是否不为 None
        if self.upsamplers is not None:
            base_parts.append(self.upsamplers)  # 如果存在,添加 upsamplers
        # 冻结基础部分的参数
        for part in base_parts:
            for param in part.parameters():
                param.requires_grad = False  # 设置参数为不可训练
    
    # 定义前向传播方法
    def forward(
        self,
        hidden_states: Tensor,  # 输入的隐藏状态
        res_hidden_states_tuple_base: Tuple[Tensor, ...],  # 基础残差隐藏状态元组
        res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],  # 控制残差隐藏状态元组
        temb: Tensor,  # 时间嵌入
        encoder_hidden_states: Optional[Tensor] = None,  # 可选的编码器隐藏状态
        conditioning_scale: Optional[float] = 1.0,  # 可选的条件缩放因子,默认值为 1.0
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可选的交叉注意力参数
        attention_mask: Optional[Tensor] = None,  # 可选的注意力掩码
        upsample_size: Optional[int] = None,  # 可选的上采样大小
        encoder_attention_mask: Optional[Tensor] = None,  # 可选的编码器注意力掩码
        apply_control: bool = True,  # 是否应用控制,默认值为 True
    # 函数返回一个 Tensor 对象
    ) -> Tensor:
        # 检查交叉注意力参数是否存在
        if cross_attention_kwargs is not None:
            # 检查参数中是否包含 "scale"
            if cross_attention_kwargs.get("scale", None) is not None:
                # 记录警告信息,表示 "scale" 参数已弃用
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
        # 判断 FreeU 是否启用,检查相关属性是否存在
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )
    
        # 定义创建自定义前向传播的方法
        def create_custom_forward(module, return_dict=None):
            # 定义自定义前向传播函数
            def custom_forward(*inputs):
                # 根据是否返回字典选择调用方式
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)
    
            return custom_forward
    
        # 定义条件应用 FreeU 的方法
        def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
            # FreeU: 仅在前两个阶段操作
            if is_freeu_enabled:
                # 应用 FreeU 操作
                return apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_h_base,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )
            else:
                # 如果未启用 FreeU,直接返回输入状态
                return hidden_states, res_h_base
    
        # 同时遍历多个列表
        for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
            self.resnets,
            self.attentions,
            self.ctrl_to_base,
            reversed(res_hidden_states_tuple_base),
            reversed(res_hidden_states_tuple_ctrl),
        ):
            # 如果应用控制,则调整隐藏状态
            if apply_control:
                hidden_states += c2b(res_h_ctrl) * conditioning_scale
    
            # 可能应用 FreeU 操作
            hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
            # 将隐藏状态和基础状态沿维度 1 拼接
            hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
    
            # 如果在训练并启用梯度检查点
            if self.training and self.gradient_checkpointing:
                # 根据 PyTorch 版本设置检查点参数
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 应用检查点以减少内存使用
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                # 直接使用残差网络处理隐藏状态
                hidden_states = resnet(hidden_states, temb)
    
            # 如果注意力模块不为 None,则进行注意力计算
            if attn is not None:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
    
        # 如果上采样器存在,应用上采样操作
        if self.upsamplers is not None:
            hidden_states = self.upsamplers(hidden_states, upsample_size)
    
        # 返回最终的隐藏状态
        return hidden_states
# 创建一个零卷积层的函数,接收输入和输出通道数
def make_zero_conv(in_channels, out_channels=None):
    # 使用 zero_module 函数初始化一个卷积层,并设置卷积核大小为1,填充为0
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))


# 初始化传入模块的参数为零的函数
def zero_module(module):
    # 遍历模块的所有参数
    for p in module.parameters():
        # 将每个参数初始化为零
        nn.init.zeros_(p)
    # 返回已初始化的模块
    return module


# 查找给定数字的最大因数的函数,最大因数不超过指定值
def find_largest_factor(number, max_factor):
    # 将最大因数设置为初始因数
    factor = max_factor
    # 如果最大因数大于或等于数字,直接返回数字
    if factor >= number:
        return number
    # 循环直到找到一个因数
    while factor != 0:
        # 计算数字与因数的余数
        residual = number % factor
        # 如果余数为零,则因数是有效的
        if residual == 0:
            return factor
        # 减小因数,继续查找
        factor -= 1
posted @ 2024-10-22 12:36  绝不原创的飞龙  阅读(59)  评论(0编辑  收藏  举报