diffusers-源码解析-十五-

diffusers 源码解析(十五)

.\diffusers\models\unets\unet_3d_condition.py

# 版权声明,声明此代码的版权信息和所有权
# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
# 版权声明,声明此代码的版权信息和所有权
# Copyright 2024 The ModelScope Team.
#
# 许可声明,声明本代码使用的 Apache 许可证 2.0 版本
# Licensed under the Apache License, Version 2.0 (the "License");
# 使用此文件前需遵守许可证规定
# you may not use this file except in compliance with the License.
# 可在以下网址获取许可证副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 免责声明,说明软件在许可下按 "原样" 提供,不附加任何明示或暗示的保证
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 许可证中规定的权限和限制说明
# See the License for the specific language governing permissions and
# limitations under the License.

# 从 dataclasses 模块导入 dataclass 装饰器
from dataclasses import dataclass
# 从 typing 模块导入所需的类型提示
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 PyTorch 库
import torch
# 导入 PyTorch 神经网络模块
import torch.nn as nn
# 导入 PyTorch 的检查点工具
import torch.utils.checkpoint

# 导入配置相关的工具类和函数
from ...configuration_utils import ConfigMixin, register_to_config
# 导入 UNet2D 条件加载器混合类
from ...loaders import UNet2DConditionLoadersMixin
# 导入基本输出类和日志工具
from ...utils import BaseOutput, logging
# 导入激活函数获取工具
from ..activations import get_activation
# 导入各种注意力处理器相关组件
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入添加键值对注意力处理器
    CROSS_ATTENTION_PROCESSORS,      # 导入交叉注意力处理器
    Attention,                       # 导入注意力类
    AttentionProcessor,              # 导入注意力处理器基类
    AttnAddedKVProcessor,            # 导入添加键值对的注意力处理器
    AttnProcessor,                   # 导入普通注意力处理器
    FusedAttnProcessor2_0,           # 导入融合注意力处理器
)
# 导入时间步嵌入和时间步类
from ..embeddings import TimestepEmbedding, Timesteps
# 导入模型混合类
from ..modeling_utils import ModelMixin
# 导入时间变换器模型
from ..transformers.transformer_temporal import TransformerTemporalModel
# 导入 3D UNet 相关的块
from .unet_3d_blocks import (
    CrossAttnDownBlock3D,          # 导入交叉注意力下采样块
    CrossAttnUpBlock3D,            # 导入交叉注意力上采样块
    DownBlock3D,                   # 导入下采样块
    UNetMidBlock3DCrossAttn,      # 导入 UNet 中间交叉注意力块
    UpBlock3D,                     # 导入上采样块
    get_down_block,                # 导入获取下采样块的函数
    get_up_block,                  # 导入获取上采样块的函数
)

# 创建日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义 UNet3DConditionOutput 数据类,继承自 BaseOutput
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    [`UNet3DConditionModel`] 的输出类。

    参数:
        sample (`torch.Tensor` 的形状为 `(batch_size, num_channels, num_frames, height, width)`):
            基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
    """

    sample: torch.Tensor  # 定义样本输出,类型为 PyTorch 张量

# 定义 UNet3DConditionModel 类,继承自多个混合类
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    条件 3D UNet 模型,接受噪声样本、条件状态和时间步,并返回形状为样本的输出。

    此模型继承自 [`ModelMixin`]。有关其通用方法的文档,请参阅超类文档(如下载或保存)。
    # 参数说明部分
    Parameters:
        # 输入/输出样本的高度和宽度,类型可以为整数或元组,默认为 None
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        # 输入样本的通道数,默认为 4
        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
        # 输出的通道数,默认为 4
        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
        # 使用的下采样块类型的元组,默认为指定的四种块
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        # 使用的上采样块类型的元组,默认为指定的四种块
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        # 每个块的输出通道数的元组,默认为 (320, 640, 1280, 1280)
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        # 每个块的层数,默认为 2
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        # 下采样卷积使用的填充,默认为 1
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        # 中间块使用的缩放因子,默认为 1.0
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        # 使用的激活函数,默认为 "silu"
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        # 用于归一化的组数,默认为 32;如果为 None,则跳过归一化和激活层
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers is skipped in post-processing.
        # 归一化使用的 epsilon 值,默认为 1e-5
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        # 交叉注意力特征的维度,默认为 1024
        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
        # 注意力头的维度,默认为 64
        attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
        # 注意力头的数量,类型为整数,默认为 None
        num_attention_heads (`int`, *optional*): The number of attention heads.
        # 时间条件投影层的维度,默认为 None
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of `cond_proj` layer in the timestep embedding.
    """

    # 是否支持梯度检查点,默认为 False
    _supports_gradient_checkpointing = False

    # 将此类注册到配置中
    @register_to_config
    # 初始化方法,用于创建类的实例
        def __init__(
            # 样本大小,默认为 None
            self,
            sample_size: Optional[int] = None,
            # 输入通道数量,默认为 4
            in_channels: int = 4,
            # 输出通道数量,默认为 4
            out_channels: int = 4,
            # 下采样块类型的元组,定义模型的下采样结构
            down_block_types: Tuple[str, ...] = (
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "CrossAttnDownBlock3D",
                "DownBlock3D",
            ),
            # 上采样块类型的元组,定义模型的上采样结构
            up_block_types: Tuple[str, ...] = (
                "UpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
                "CrossAttnUpBlock3D",
            ),
            # 每个块的输出通道数量,定义模型每个层的通道设置
            block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
            # 每个块的层数,默认为 2
            layers_per_block: int = 2,
            # 下采样时的填充大小,默认为 1
            downsample_padding: int = 1,
            # 中间块的缩放因子,默认为 1
            mid_block_scale_factor: float = 1,
            # 激活函数类型,默认为 "silu"
            act_fn: str = "silu",
            # 归一化组的数量,默认为 32
            norm_num_groups: Optional[int] = 32,
            # 归一化的 epsilon 值,默认为 1e-5
            norm_eps: float = 1e-5,
            # 跨注意力维度,默认为 1024
            cross_attention_dim: int = 1024,
            # 注意力头的维度,可以是单一整数或整数元组,默认为 64
            attention_head_dim: Union[int, Tuple[int]] = 64,
            # 注意力头的数量,可选参数
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 时间条件投影维度,可选参数
            time_cond_proj_dim: Optional[int] = None,
        @property
        # 从 UNet2DConditionModel 复制的属性,获取注意力处理器
        # 返回所有注意力处理器的字典,以权重名称为索引
        def attn_processors(self) -> Dict[str, AttentionProcessor]:
            r"""
            Returns:
                `dict` of attention processors: A dictionary containing all attention processors used in the model with
                indexed by its weight name.
            """
            # 初始化处理器字典
            processors = {}
    
            # 递归添加处理器的函数
            def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
                # 如果模块有获取处理器的方法,添加到处理器字典中
                if hasattr(module, "get_processor"):
                    processors[f"{name}.processor"] = module.get_processor()
    
                # 遍历子模块,递归调用该函数
                for sub_name, child in module.named_children():
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
                # 返回处理器字典
                return processors
    
            # 遍历当前类的子模块,调用递归添加处理器的函数
            for name, module in self.named_children():
                fn_recursive_add_processors(name, module, processors)
    
            # 返回所有处理器
            return processors
    
        # 从 UNet2DConditionModel 复制的设置注意力切片的方法
        # 从 UNet2DConditionModel 复制的设置注意力处理器的方法
    # 定义一个方法用于设置注意力处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
    
        参数:
            processor(`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或一个处理器类的字典,将作为所有 `Attention` 层的处理器。
    
                如果 `processor` 是一个字典,键需要定义相应的交叉注意力处理器的路径。
                在设置可训练的注意力处理器时,强烈推荐这样做。
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 如果传入的处理器是字典,且数量不等于注意力层数量,抛出错误
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了一个处理器字典,但处理器的数量 {len(processor)} 与"
                f" 注意力层的数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )
    
        # 定义一个递归函数来设置每个模块的处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置处理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中获取相应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历子模块并递归调用处理器设置
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的所有子模块,并调用递归设置函数
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    # 定义一个方法来启用前馈层的分块处理
        def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
            """
            设置注意力处理器以使用 [前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
    
            参数:
                chunk_size (`int`, *可选*):
                    前馈层的分块大小。如果未指定,将对维度为`dim`的每个张量单独运行前馈层。
                dim (`int`, *可选*, 默认为`0`):
                    应对哪个维度进行前馈计算的分块。可以选择 dim=0(批次)或 dim=1(序列长度)。
            """
            # 确保 dim 参数为 0 或 1
            if dim not in [0, 1]:
                raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
    
            # 默认的分块大小为 1
            chunk_size = chunk_size or 1
    
            # 定义一个递归函数来设置每个模块的分块前馈处理
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模块具有设置分块前馈的属性,则设置它
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 遍历子模块,递归调用函数
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 遍历当前实例的子模块,应用递归函数
            for module in self.children():
                fn_recursive_feed_forward(module, chunk_size, dim)
    
        # 定义一个方法来禁用前馈层的分块处理
        def disable_forward_chunking(self):
            # 定义一个递归函数来禁用分块前馈处理
            def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
                # 如果模块具有设置分块前馈的属性,则设置为 None
                if hasattr(module, "set_chunk_feed_forward"):
                    module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
                # 遍历子模块,递归调用函数
                for child in module.children():
                    fn_recursive_feed_forward(child, chunk_size, dim)
    
            # 遍历当前实例的子模块,应用递归函数,禁用分块
            for module in self.children():
                fn_recursive_feed_forward(module, None, 0)
    
        # 从 diffusers.models.unets.unet_2d_condition 中复制的方法,设置默认注意力处理器
        def set_default_attn_processor(self):
            """
            禁用自定义注意力处理器并设置默认注意力实现。
            """
            # 检查所有注意力处理器是否为添加的 KV 处理器
            if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnAddedKVProcessor()  # 设置为添加的 KV 处理器
            # 检查所有注意力处理器是否为交叉注意力处理器
            elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
                processor = AttnProcessor()  # 设置为普通注意力处理器
            else:
                # 抛出异常,若注意力处理器类型不符合预期
                raise ValueError(
                    f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
                )
    
            # 设置选定的注意力处理器
            self.set_attn_processor(processor)
    
        # 定义一个私有方法来设置模块的梯度检查点
        def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
            # 检查模块是否属于特定类型
            if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
                module.gradient_checkpointing = value  # 设置梯度检查点值
    
        # 从 diffusers.models.unets.unet_2d_condition 中复制的方法,启用自由度
    # 启用 FreeU 机制,参数为两个缩放因子和两个增强因子的值
    def enable_freeu(self, s1, s2, b1, b2):
        r"""从 https://arxiv.org/abs/2309.11497 启用 FreeU 机制。

        缩放因子的后缀表示它们应用的阶段块。

        请参考 [官方仓库](https://github.com/ChenyangSi/FreeU) 以获取在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中已知效果良好的值组合。

        Args:
            s1 (`float`):
                第1阶段的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            s2 (`float`):
                第2阶段的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            b1 (`float`): 第1阶段的缩放因子,用于增强骨干特征的贡献。
            b2 (`float`): 第2阶段的缩放因子,用于增强骨干特征的贡献。
        """
        # 遍历上采样块,给每个块设置缩放因子和增强因子
        for i, upsample_block in enumerate(self.up_blocks):
            # 设置第1阶段的缩放因子
            setattr(upsample_block, "s1", s1)
            # 设置第2阶段的缩放因子
            setattr(upsample_block, "s2", s2)
            # 设置第1阶段的增强因子
            setattr(upsample_block, "b1", b1)
            # 设置第2阶段的增强因子
            setattr(upsample_block, "b2", b2)

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu 复制
    # 禁用 FreeU 机制
    def disable_freeu(self):
        """禁用 FreeU 机制。"""
        # 定义 FreeU 机制的关键属性
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍历上采样块
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍历 FreeU 关键属性
            for k in freeu_keys:
                # 如果上采样块有该属性,或者该属性值不为 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 将属性值设置为 None,禁用 FreeU
                    setattr(upsample_block, k, None)

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 复制
    # 启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。对于交叉注意力模块,键和值投影矩阵被融合。

        <Tip warning={true}>

        此 API 是 🧪 实验性的。

        </Tip>
        """
        # 保存原始的注意力处理器
        self.original_attn_processors = None

        # 遍历注意力处理器
        for _, attn_processor in self.attn_processors.items():
            # 如果注意力处理器的类名中包含“Added”
            if "Added" in str(attn_processor.__class__.__name__):
                # 抛出错误,表示不支持具有附加 KV 投影的模型
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # 保存当前的注意力处理器
        self.original_attn_processors = self.attn_processors

        # 遍历所有模块
        for module in self.modules():
            # 如果模块是 Attention 类型
            if isinstance(module, Attention):
                # 融合投影
                module.fuse_projections(fuse=True)

        # 设置注意力处理器为融合的注意力处理器
        self.set_attn_processor(FusedAttnProcessor2_0())

    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 复制
    # 定义一个方法,用于禁用已启用的融合 QKV 投影
    def unfuse_qkv_projections(self):
        """禁用已启用的融合 QKV 投影。
    
        <Tip warning={true}>
    
        该 API 是 🧪 实验性的。
    
        </Tip>
    
        """
        # 如果存在原始的注意力处理器,则设置当前的注意力处理器为原始处理器
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
    
    # 定义前向传播方法,接受多个参数进行计算
    def forward(
        self,
        sample: torch.Tensor,  # 输入样本,张量格式
        timestep: Union[torch.Tensor, float, int],  # 当前时间步,可以是张量、浮点数或整数
        encoder_hidden_states: torch.Tensor,  # 编码器的隐藏状态,张量格式
        class_labels: Optional[torch.Tensor] = None,  # 类别标签,默认为 None
        timestep_cond: Optional[torch.Tensor] = None,  # 时间步条件,默认为 None
        attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码,默认为 None
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 跨注意力的关键字参数,默认为 None
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,  # 降级块的附加残差,默认为 None
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # 中间块的附加残差,默认为 None
        return_dict: bool = True,  # 是否返回字典格式的结果,默认为 True

.\diffusers\models\unets\unet_i2vgen_xl.py

# 版权声明,表明版权归2024年阿里巴巴DAMO-VILAB和HuggingFace团队所有
# 提供Apache许可证2.0版本的使用条款
# 说明只能在遵循许可证的情况下使用此文件
# 可在指定网址获取许可证副本
#
# 除非适用法律或书面协议另有约定,否则软件按“原样”分发
# 不提供任何形式的担保或条件
# 请参见许可证以获取与权限和限制相关的具体信息

from typing import Any, Dict, Optional, Tuple, Union  # 导入类型提示工具,用于类型注解

import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch的神经网络模块
import torch.utils.checkpoint  # 导入PyTorch的检查点工具

from ...configuration_utils import ConfigMixin, register_to_config  # 从配置工具导入类和函数
from ...loaders import UNet2DConditionLoadersMixin  # 导入2D条件加载器混合类
from ...utils import logging  # 导入日志工具
from ..activations import get_activation  # 导入激活函数获取工具
from ..attention import Attention, FeedForward  # 导入注意力机制和前馈网络
from ..attention_processor import (  # 从注意力处理器模块导入多个处理器
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps  # 导入时间步嵌入和时间步类
from ..modeling_utils import ModelMixin  # 导入模型混合类
from ..transformers.transformer_temporal import TransformerTemporalModel  # 导入时间变换器模型
from .unet_3d_blocks import (  # 从3D U-Net块模块导入多个类
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .unet_3d_condition import UNet3DConditionOutput  # 导入3D条件输出类

logger = logging.get_logger(__name__)  # 创建日志记录器,用于记录当前模块的信息

class I2VGenXLTransformerTemporalEncoder(nn.Module):  # 定义一个名为I2VGenXLTransformerTemporalEncoder的类,继承自nn.Module
    def __init__(  # 构造函数,用于初始化类的实例
        self,
        dim: int,  # 输入的特征维度
        num_attention_heads: int,  # 注意力头的数量
        attention_head_dim: int,  # 每个注意力头的维度
        activation_fn: str = "geglu",  # 激活函数类型,默认使用geglu
        upcast_attention: bool = False,  # 是否提升注意力计算的精度
        ff_inner_dim: Optional[int] = None,  # 前馈网络的内部维度,默认为None
        dropout: int = 0.0,  # dropout概率,默认为0.0
    ):
        super().__init__()  # 调用父类构造函数
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)  # 初始化层归一化层
        self.attn1 = Attention(  # 初始化注意力层
            query_dim=dim,  # 查询维度
            heads=num_attention_heads,  # 注意力头数量
            dim_head=attention_head_dim,  # 每个头的维度
            dropout=dropout,  # dropout概率
            bias=False,  # 不使用偏置
            upcast_attention=upcast_attention,  # 是否提升注意力计算精度
            out_bias=True,  # 输出使用偏置
        )
        self.ff = FeedForward(  # 初始化前馈网络
            dim,  # 输入维度
            dropout=dropout,  # dropout概率
            activation_fn=activation_fn,  # 激活函数类型
            final_dropout=False,  # 最后层不使用dropout
            inner_dim=ff_inner_dim,  # 内部维度
            bias=True,  # 使用偏置
        )

    def forward(  # 定义前向传播方法
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态
    # 该方法返回处理后的隐藏状态张量
    ) -> torch.Tensor:
        # 对隐藏状态进行归一化处理
        norm_hidden_states = self.norm1(hidden_states)
        # 计算注意力输出,使用归一化后的隐藏状态
        attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
        # 将注意力输出与原始隐藏状态相加,更新隐藏状态
        hidden_states = attn_output + hidden_states
        # 如果隐藏状态是四维,则去掉第一维
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
    
        # 通过前馈网络处理隐藏状态
        ff_output = self.ff(hidden_states)
        # 将前馈输出与当前隐藏状态相加,更新隐藏状态
        hidden_states = ff_output + hidden_states
        # 如果隐藏状态是四维,则去掉第一维
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
    
        # 返回最终的隐藏状态
        return hidden_states
# 定义 I2VGenXL UNet 类,继承多个混入类以增加功能
class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    I2VGenXL UNet。一个条件3D UNet模型,接收噪声样本、条件状态和时间步,
    返回与样本形状相同的输出。

    该模型继承自 [`ModelMixin`]。有关所有模型实现的通用方法(如下载或保存),
    请查看超类文档。

    参数:
        sample_size (`int` 或 `Tuple[int, int]`, *可选*, 默认值为 `None`):
            输入/输出样本的高度和宽度。
        in_channels (`int`, *可选*, 默认值为 4): 输入样本的通道数。
        out_channels (`int`, *可选*, 默认值为 4): 输出样本的通道数。
        down_block_types (`Tuple[str]`, *可选*, 默认值为 `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            使用的下采样块的元组。
        up_block_types (`Tuple[str]`, *可选*, 默认值为 `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
            使用的上采样块的元组。
        block_out_channels (`Tuple[int]`, *可选*, 默认值为 `(320, 640, 1280, 1280)`):
            每个块的输出通道元组。
        layers_per_block (`int`, *可选*, 默认值为 2): 每个块的层数。
        norm_num_groups (`int`, *可选*, 默认值为 32): 用于归一化的组数。
            如果为 `None`,则跳过后处理中的归一化和激活层。
        cross_attention_dim (`int`, *可选*, 默认值为 1280): 跨注意力特征的维度。
        attention_head_dim (`int`, *可选*, 默认值为 64): 注意力头的维度。
        num_attention_heads (`int`, *可选*): 注意力头的数量。
    """

    # 设置不支持梯度检查点的属性为 False
    _supports_gradient_checkpointing = False

    @register_to_config
    # 初始化方法,接受多种可选参数以设置模型配置
    def __init__(
        self,
        sample_size: Optional[int] = None,  # 输入/输出样本大小,默认为 None
        in_channels: int = 4,  # 输入样本的通道数,默认为 4
        out_channels: int = 4,  # 输出样本的通道数,默认为 4
        down_block_types: Tuple[str, ...] = (  # 下采样块的类型,默认为指定的元组
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types: Tuple[str, ...] = (  # 上采样块的类型,默认为指定的元组
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),  # 每个块的输出通道,默认为指定的元组
        layers_per_block: int = 2,  # 每个块的层数,默认为 2
        norm_num_groups: Optional[int] = 32,  # 归一化组数,默认为 32
        cross_attention_dim: int = 1024,  # 跨注意力特征的维度,默认为 1024
        attention_head_dim: Union[int, Tuple[int]] = 64,  # 注意力头的维度,默认为 64
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,  # 注意力头的数量,默认为 None
    @property
    # 该属性从 UNet2DConditionModel 的 attn_processors 复制
    # 定义返回注意力处理器的函数,返回类型为字典,键为字符串,值为 AttentionProcessor 对象
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 创建一个空字典,用于存储处理器
        processors = {}

        # 定义一个递归函数,用于添加处理器到字典
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否有 get_processor 方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为名称加上 ".processor"
                processors[f"{name}.processor"] = module.get_processor()

            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用,处理子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新后的处理器字典
            return processors

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,将处理器添加到字典中
            fn_recursive_add_processors(name, module, processors)

        # 返回包含所有处理器的字典
        return processors

    # 从 diffusers.models.unets.unet_2d_condition 中复制的设置注意力处理器的函数
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())

        # 如果传入的是字典且数量不匹配,则引发错误
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # 定义一个递归函数,用于设置处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否有 set_processor 方法
            if hasattr(module, "set_processor"):
                # 如果传入的处理器不是字典,直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中移除并设置对应的处理器
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用,设置子模块的处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,为每个模块设置处理器
            fn_recursive_attn_processor(name, module, processor)

    # 从 diffusers.models.unets.unet_3d_condition 中复制的启用前向分块的函数
    # 启用前馈层的分块处理
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        设置注意力处理器使用[前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。

        参数:
            chunk_size (`int`, *可选*):
                前馈层的块大小。如果未指定,将对维度为`dim`的每个张量单独运行前馈层。
            dim (`int`, *可选*, 默认为`0`):
                前馈计算应分块的维度。可以选择dim=0(批次)或dim=1(序列长度)。
        """
        # 检查维度是否在有效范围内
        if dim not in [0, 1]:
            # 抛出错误,确保dim只为0或1
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # 默认块大小为1
        chunk_size = chunk_size or 1

        # 定义递归函数,用于设置每个模块的前馈分块
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模块有设置分块前馈的方法,调用该方法
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 递归遍历子模块
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 对当前对象的所有子模块应用前馈分块设置
        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # 从diffusers.models.unets.unet_3d_condition.UNet3DConditionModel复制的禁用前馈分块的方法
    def disable_forward_chunking(self):
        # 定义递归函数,用于禁用模块的前馈分块
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模块有设置分块前馈的方法,调用该方法
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 递归遍历子模块
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 对当前对象的所有子模块应用禁用前馈分块设置
        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    # 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel复制的设置默认注意力处理器的方法
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器并设置默认的注意力实现。
        """
        # 检查所有注意力处理器是否属于已添加的KV注意力处理器类
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,则设置为已添加KV处理器
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否属于交叉注意力处理器类
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,则设置为标准注意力处理器
            processor = AttnProcessor()
        else:
            # 抛出错误,说明当前的注意力处理器类型不被支持
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 设置当前对象的注意力处理器为选择的处理器
        self.set_attn_processor(processor)

    # 从diffusers.models.unets.unet_3d_condition.UNet3DConditionModel复制的设置梯度检查点的方法
    # 设置梯度检查点,指定模块和布尔值
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 检查模块是否为指定的类型之一
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            # 设置模块的梯度检查点属性为指定值
            module.gradient_checkpointing = value

    # 从 UNet2DConditionModel 中复制的启用 FreeU 方法
    def enable_freeu(self, s1, s2, b1, b2):
        r"""启用 FreeU 机制,详情见 https://arxiv.org/abs/2309.11497.

        后缀表示缩放因子应用的阶段块。

        请参考 [官方库](https://github.com/ChenyangSi/FreeU) 以获取适用于不同管道(如 Stable Diffusion v1, v2 和 Stable Diffusion XL)的有效值组合。

        参数:
            s1 (`float`):
                阶段 1 的缩放因子,用于减弱跳过特征的贡献,以缓解增强去噪过程中的“过平滑效应”。
            s2 (`float`):
                阶段 2 的缩放因子,用于减弱跳过特征的贡献,以缓解增强去噪过程中的“过平滑效应”。
            b1 (`float`): 阶段 1 的缩放因子,用于放大主干特征的贡献。
            b2 (`float`): 阶段 2 的缩放因子,用于放大主干特征的贡献。
        """
        # 遍历上采样块,索引 i 和块对象 upsample_block
        for i, upsample_block in enumerate(self.up_blocks):
            # 设置上采样块的属性 s1 为给定值 s1
            setattr(upsample_block, "s1", s1)
            # 设置上采样块的属性 s2 为给定值 s2
            setattr(upsample_block, "s2", s2)
            # 设置上采样块的属性 b1 为给定值 b1
            setattr(upsample_block, "b1", b1)
            # 设置上采样块的属性 b2 为给定值 b2
            setattr(upsample_block, "b2", b2)

    # 从 UNet2DConditionModel 中复制的禁用 FreeU 方法
    def disable_freeu(self):
        """禁用 FreeU 机制。"""
        # 定义 FreeU 相关的属性键
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍历上采样块,索引 i 和块对象 upsample_block
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍历 FreeU 属性键
            for k in freeu_keys:
                # 如果上采样块具有该属性或属性值不为 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 将上采样块的该属性设置为 None
                    setattr(upsample_block, k, None)

    # 从 UNet2DConditionModel 中复制的融合 QKV 投影方法
    # 定义一个方法,用于启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        # 提供方法的文档字符串,描述其功能和警告信息
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.
    
        <Tip warning={true}>
    
        This API is 🧪 experimental.
    
        </Tip>
        """
        # 初始化原始注意力处理器为 None
        self.original_attn_processors = None
    
        # 遍历当前对象的注意力处理器
        for _, attn_processor in self.attn_processors.items():
            # 检查处理器类名中是否包含 "Added"
            if "Added" in str(attn_processor.__class__.__name__):
                # 如果包含,抛出异常提示不支持
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
    
        # 保存当前的注意力处理器以备后用
        self.original_attn_processors = self.attn_processors
    
        # 遍历当前对象的所有模块
        for module in self.modules():
            # 检查模块是否为 Attention 类型
            if isinstance(module, Attention):
                # 调用模块的方法,启用融合投影
                module.fuse_projections(fuse=True)
    
        # 设置注意力处理器为 FusedAttnProcessor2_0 的实例
        self.set_attn_processor(FusedAttnProcessor2_0())
    
    # 从 UNet2DConditionModel 复制的方法,用于禁用融合的 QKV 投影
    def unfuse_qkv_projections(self):
        # 提供方法的文档字符串,描述其功能和警告信息
        """Disables the fused QKV projection if enabled.
    
        <Tip warning={true}>
    
        This API is 🧪 experimental.
    
        </Tip>
    
        """
        # 检查原始注意力处理器是否不为 None
        if self.original_attn_processors is not None:
            # 如果不为 None,恢复原始的注意力处理器
            self.set_attn_processor(self.original_attn_processors)
    
    # 定义前向传播方法,接受多个输入参数
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        fps: torch.Tensor,
        image_latents: torch.Tensor,
        image_embeddings: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,

.\diffusers\models\unets\unet_kandinsky3.py

# 版权声明,指明该文件属于 HuggingFace 团队,所有权利保留
# 
# 根据 Apache License 2.0 版(“许可证”)授权;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律要求或书面同意,否则根据许可证分发的软件
# 是在“原样”基础上分发的,不附带任何形式的保证或条件。
# 有关特定语言的许可条款和条件,请参见许可证。

from dataclasses import dataclass  # 从 dataclasses 模块导入 dataclass 装饰器
from typing import Dict, Tuple, Union  # 导入用于类型提示的字典、元组和联合类型

import torch  # 导入 PyTorch 库
import torch.utils.checkpoint  # 导入 PyTorch 的检查点工具
from torch import nn  # 从 PyTorch 导入神经网络模块

from ...configuration_utils import ConfigMixin, register_to_config  # 从配置工具导入混合类和注册函数
from ...utils import BaseOutput, logging  # 从工具模块导入基础输出类和日志功能
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor  # 导入注意力处理器相关类
from ..embeddings import TimestepEmbedding, Timesteps  # 导入时间步嵌入相关类
from ..modeling_utils import ModelMixin  # 导入模型混合类

logger = logging.get_logger(__name__)  # 创建一个记录器,用于当前模块的日志记录

@dataclass  # 将该类标记为数据类,以简化初始化和表示
class Kandinsky3UNetOutput(BaseOutput):  # 定义 Kandinsky3UNetOutput 类,继承自 BaseOutput
    sample: torch.Tensor = None  # 定义输出样本,默认为 None

class Kandinsky3EncoderProj(nn.Module):  # 定义 Kandinsky3EncoderProj 类,继承自 nn.Module
    def __init__(self, encoder_hid_dim, cross_attention_dim):  # 初始化方法,接收隐藏维度和交叉注意力维度
        super().__init__()  # 调用父类的初始化方法
        self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False)  # 定义线性投影层,不使用偏置
        self.projection_norm = nn.LayerNorm(cross_attention_dim)  # 定义层归一化层

    def forward(self, x):  # 定义前向传播方法
        x = self.projection_linear(x)  # 通过线性层处理输入
        x = self.projection_norm(x)  # 通过层归一化处理输出
        return x  # 返回处理后的结果

class Kandinsky3UNet(ModelMixin, ConfigMixin):  # 定义 Kandinsky3UNet 类,继承自 ModelMixin 和 ConfigMixin
    @register_to_config  # 将该方法注册到配置中
    def __init__(  # 初始化方法
        self,
        in_channels: int = 4,  # 输入通道数,默认值为 4
        time_embedding_dim: int = 1536,  # 时间嵌入维度,默认值为 1536
        groups: int = 32,  # 组数,默认值为 32
        attention_head_dim: int = 64,  # 注意力头维度,默认值为 64
        layers_per_block: Union[int, Tuple[int]] = 3,  # 每个块的层数,默认值为 3,可以是整数或元组
        block_out_channels: Tuple[int] = (384, 768, 1536, 3072),  # 块输出通道,默认为指定元组
        cross_attention_dim: Union[int, Tuple[int]] = 4096,  # 交叉注意力维度,默认值为 4096
        encoder_hid_dim: int = 4096,  # 编码器隐藏维度,默认值为 4096
    @property  # 定义一个属性
    def attn_processors(self) -> Dict[str, AttentionProcessor]:  # 返回注意力处理器字典
        r"""  # 文档字符串,描述该方法的功能
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 设置一个空字典以递归存储处理器
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):  # 定义递归函数添加处理器
            if hasattr(module, "set_processor"):  # 检查模块是否具有 set_processor 属性
                processors[f"{name}.processor"] = module.processor  # 将处理器添加到字典中

            for sub_name, child in module.named_children():  # 遍历模块的所有子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)  # 递归调用自身

            return processors  # 返回更新后的处理器字典

        for name, module in self.named_children():  # 遍历当前类的所有子模块
            fn_recursive_add_processors(name, module, processors)  # 调用递归函数

        return processors  # 返回包含所有处理器的字典
    # 定义设置注意力处理器的方法,参数为处理器,可以是 AttentionProcessor 类或其字典形式
        def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
            r"""
            设置用于计算注意力的处理器。
    
            参数:
                processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                    实例化的处理器类或处理器类的字典,将作为 **所有** `Attention` 层的处理器。
    
                    如果 `processor` 是一个字典,键需要定义相应交叉注意力处理器的路径。这在设置可训练注意力处理器时强烈推荐。
    
            """
            # 获取当前注意力处理器的数量
            count = len(self.attn_processors.keys())
    
            # 如果传入的是字典且其长度与注意力层的数量不匹配,则抛出错误
            if isinstance(processor, dict) and len(processor) != count:
                raise ValueError(
                    f"传入了处理器字典,但处理器的数量 {len(processor)} 与"
                    f" 注意力层的数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
                )
    
            # 定义递归设置注意力处理器的方法
            def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
                # 如果模块有设置处理器的方法
                if hasattr(module, "set_processor"):
                    # 如果处理器不是字典,则直接设置
                    if not isinstance(processor, dict):
                        module.set_processor(processor)
                    else:
                        # 从字典中获取对应的处理器并设置
                        module.set_processor(processor.pop(f"{name}.processor"))
    
                # 遍历模块的所有子模块
                for sub_name, child in module.named_children():
                    # 递归调用处理子模块
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
            # 遍历当前对象的所有子模块
            for name, module in self.named_children():
                # 递归设置每个子模块的处理器
                fn_recursive_attn_processor(name, module, processor)
    
        # 定义设置默认注意力处理器的方法
        def set_default_attn_processor(self):
            """
            禁用自定义注意力处理器,并设置默认的注意力实现。
            """
            # 调用设置注意力处理器的方法,使用默认的 AttnProcessor 实例
            self.set_attn_processor(AttnProcessor())
    
        # 定义设置梯度检查点的方法
        def _set_gradient_checkpointing(self, module, value=False):
            # 如果模块有梯度检查点的属性
            if hasattr(module, "gradient_checkpointing"):
                # 设置该属性为指定的值
                module.gradient_checkpointing = value
    # 定义前向传播函数,接收样本、时间步以及可选的编码器隐藏状态和注意力掩码
    def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
        # 如果存在编码器注意力掩码,则进行调整以适应后续计算
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            # 增加一个维度,以便后续处理
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
    
        # 检查时间步是否为张量类型
        if not torch.is_tensor(timestep):
            # 根据时间步类型确定数据类型
            dtype = torch.float32 if isinstance(timestep, float) else torch.int32
            # 将时间步转换为张量并指定设备
            timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
        # 如果时间步为标量,则扩展为一维张量
        elif len(timestep.shape) == 0:
            timestep = timestep[None].to(sample.device)
    
        # 扩展时间步到与批量维度兼容的形状
        timestep = timestep.expand(sample.shape[0])
        # 通过时间投影获取时间嵌入输入并转换为样本的数据类型
        time_embed_input = self.time_proj(timestep).to(sample.dtype)
        # 获取时间嵌入
        time_embed = self.time_embedding(time_embed_input)
    
        # 对编码器隐藏状态进行线性变换
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
    
        # 如果存在编码器隐藏状态,则将时间嵌入与隐藏状态结合
        if encoder_hidden_states is not None:
            time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
    
        # 初始化隐藏状态列表
        hidden_states = []
        # 对输入样本进行初步卷积处理
        sample = self.conv_in(sample)
        # 遍历下采样块
        for level, down_sample in enumerate(self.down_blocks):
            # 通过下采样块处理样本
            sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
            # 如果不是最后一个层级,记录当前样本状态
            if level != self.num_levels - 1:
                hidden_states.append(sample)
    
        # 遍历上采样块
        for level, up_sample in enumerate(self.up_blocks):
            # 如果不是第一个层级,则拼接当前样本与之前的隐藏状态
            if level != 0:
                sample = torch.cat([sample, hidden_states.pop()], dim=1)
            # 通过上采样块处理样本
            sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
    
        # 进行输出卷积规范化
        sample = self.conv_norm_out(sample)
        # 进行输出激活
        sample = self.conv_act_out(sample)
        # 进行最终输出卷积
        sample = self.conv_out(sample)
    
        # 根据返回标志返回相应的结果
        if not return_dict:
            return (sample,)
        # 返回结果对象
        return Kandinsky3UNetOutput(sample=sample)
# 定义 Kandinsky3UpSampleBlock 类,继承自 nn.Module
class Kandinsky3UpSampleBlock(nn.Module):
    # 初始化方法,设置各参数
    def __init__(
        self,
        in_channels,  # 输入通道数
        cat_dim,  # 拼接维度
        out_channels,  # 输出通道数
        time_embed_dim,  # 时间嵌入维度
        context_dim=None,  # 上下文维度,可选
        num_blocks=3,  # 块的数量
        groups=32,  # 分组数
        head_dim=64,  # 头维度
        expansion_ratio=4,  # 扩展比例
        compression_ratio=2,  # 压缩比例
        up_sample=True,  # 是否上采样
        self_attention=True,  # 是否使用自注意力
    ):
        # 调用父类初始化方法
        super().__init__()
        # 设置上采样分辨率
        up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
        # 设置隐藏通道数
        hidden_channels = (
            [(in_channels + cat_dim, in_channels)]  # 第一层的通道
            + [(in_channels, in_channels)] * (num_blocks - 2)  # 中间层的通道
            + [(in_channels, out_channels)]  # 最后一层的通道
        )
        attentions = []  # 用于存储注意力块
        resnets_in = []  # 用于存储输入 ResNet 块
        resnets_out = []  # 用于存储输出 ResNet 块

        # 设置自注意力和上下文维度
        self.self_attention = self_attention
        self.context_dim = context_dim

        # 如果使用自注意力,添加注意力块
        if self_attention:
            attentions.append(
                Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
            )
        else:
            attentions.append(nn.Identity())  # 否则添加身份映射

        # 遍历隐藏通道和上采样分辨率
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            # 添加输入 ResNet 块
            resnets_in.append(
                Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
            )

            # 如果上下文维度不为 None,添加注意力块
            if context_dim is not None:
                attentions.append(
                    Kandinsky3AttentionBlock(
                        in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
                    )
                )
            else:
                attentions.append(nn.Identity())  # 否则添加身份映射

            # 添加输出 ResNet 块
            resnets_out.append(
                Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            )

        # 将注意力块和 ResNet 块转换为模块列表
        self.attentions = nn.ModuleList(attentions)
        self.resnets_in = nn.ModuleList(resnets_in)
        self.resnets_out = nn.ModuleList(resnets_out)

    # 前向传播方法
    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        # 遍历注意力块和 ResNet 块进行前向计算
        for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
            x = resnet_in(x, time_embed)  # 输入经过 ResNet 块
            if self.context_dim is not None:  # 如果上下文维度存在
                x = attention(x, time_embed, context, context_mask, image_mask)  # 应用注意力块
            x = resnet_out(x, time_embed)  # 输出经过 ResNet 块

        # 如果使用自注意力,应用首个注意力块
        if self.self_attention:
            x = self.attentions[0](x, time_embed, image_mask=image_mask)
        return x  # 返回处理后的结果


# 定义 Kandinsky3DownSampleBlock 类,继承自 nn.Module
class Kandinsky3DownSampleBlock(nn.Module):
    # 初始化方法,设置各参数
    def __init__(
        self,
        in_channels,  # 输入通道数
        out_channels,  # 输出通道数
        time_embed_dim,  # 时间嵌入维度
        context_dim=None,  # 上下文维度,可选
        num_blocks=3,  # 块的数量
        groups=32,  # 分组数
        head_dim=64,  # 头维度
        expansion_ratio=4,  # 扩展比例
        compression_ratio=2,  # 压缩比例
        down_sample=True,  # 是否下采样
        self_attention=True,  # 是否使用自注意力
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 初始化注意力模块列表
        attentions = []
        # 初始化输入残差块列表
        resnets_in = []
        # 初始化输出残差块列表
        resnets_out = []

        # 保存自注意力标志
        self.self_attention = self_attention
        # 保存上下文维度
        self.context_dim = context_dim

        # 如果启用自注意力
        if self_attention:
            # 添加 Kandinsky3AttentionBlock 到注意力列表
            attentions.append(
                Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
            )
        else:
            # 否则添加身份层(不改变输入)
            attentions.append(nn.Identity())

        # 生成上采样分辨率列表
        up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
        # 生成隐藏通道的元组列表
        hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
        # 遍历隐藏通道和上采样分辨率
        for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
            # 添加输入残差块到列表
            resnets_in.append(
                Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
            )

            # 如果上下文维度不为 None
            if context_dim is not None:
                # 添加 Kandinsky3AttentionBlock 到注意力列表
                attentions.append(
                    Kandinsky3AttentionBlock(
                        out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
                    )
                )
            else:
                # 否则添加身份层(不改变输入)
                attentions.append(nn.Identity())

            # 添加输出残差块到列表
            resnets_out.append(
                Kandinsky3ResNetBlock(
                    out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
                )
            )

        # 将注意力模块列表转换为 nn.ModuleList 以便管理
        self.attentions = nn.ModuleList(attentions)
        # 将输入残差块列表转换为 nn.ModuleList 以便管理
        self.resnets_in = nn.ModuleList(resnets_in)
        # 将输出残差块列表转换为 nn.ModuleList 以便管理
        self.resnets_out = nn.ModuleList(resnets_out)

    # 定义前向传播方法
    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        # 如果启用自注意力
        if self.self_attention:
            # 使用第一个注意力模块处理输入
            x = self.attentions[0](x, time_embed, image_mask=image_mask)

        # 遍历剩余的注意力模块、输入和输出残差块
        for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
            # 通过输入残差块处理输入
            x = resnet_in(x, time_embed)
            # 如果上下文维度不为 None
            if self.context_dim is not None:
                # 使用当前注意力模块处理输入
                x = attention(x, time_embed, context, context_mask, image_mask)
            # 通过输出残差块处理输入
            x = resnet_out(x, time_embed)
        # 返回处理后的输出
        return x
# 定义 Kandinsky3ConditionalGroupNorm 类,继承自 nn.Module
class Kandinsky3ConditionalGroupNorm(nn.Module):
    # 初始化方法,设置分组数、标准化形状和上下文维度
    def __init__(self, groups, normalized_shape, context_dim):
        # 调用父类构造函数
        super().__init__()
        # 创建分组归一化层,不使用仿射变换
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        # 定义上下文多层感知机,包含 SiLU 激活和线性层
        self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
        # 将线性层的权重初始化为零
        self.context_mlp[1].weight.data.zero_()
        # 将线性层的偏置初始化为零
        self.context_mlp[1].bias.data.zero_()

    # 前向传播方法,接收输入和上下文
    def forward(self, x, context):
        # 通过上下文多层感知机处理上下文
        context = self.context_mlp(context)

        # 为了匹配输入的维度,逐层扩展上下文
        for _ in range(len(x.shape[2:])):
            context = context.unsqueeze(-1)

        # 将上下文分割为缩放和偏移量
        scale, shift = context.chunk(2, dim=1)
        # 应用归一化并进行缩放和偏移
        x = self.norm(x) * (scale + 1.0) + shift
        # 返回处理后的输入
        return x


# 定义 Kandinsky3Block 类,继承自 nn.Module
class Kandinsky3Block(nn.Module):
    # 初始化方法,设置输入通道、输出通道、时间嵌入维度等参数
    def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
        # 调用父类构造函数
        super().__init__()
        # 创建条件分组归一化层
        self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
        # 定义 SiLU 激活函数
        self.activation = nn.SiLU()
        # 如果需要上采样,使用转置卷积进行上采样
        if up_resolution is not None and up_resolution:
            self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        else:
            # 否则使用恒等映射
            self.up_sample = nn.Identity()

        # 根据卷积核大小确定填充
        padding = int(kernel_size > 1)
        # 定义卷积投影层
        self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        # 如果不需要上采样,定义下采样卷积层
        if up_resolution is not None and not up_resolution:
            self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
        else:
            # 否则使用恒等映射
            self.down_sample = nn.Identity()

    # 前向传播方法,接收输入和时间嵌入
    def forward(self, x, time_embed):
        # 通过条件分组归一化处理输入
        x = self.group_norm(x, time_embed)
        # 应用激活函数
        x = self.activation(x)
        # 进行上采样
        x = self.up_sample(x)
        # 通过卷积投影层处理输入
        x = self.projection(x)
        # 进行下采样
        x = self.down_sample(x)
        # 返回处理后的输出
        return x


# 定义 Kandinsky3ResNetBlock 类,继承自 nn.Module
class Kandinsky3ResNetBlock(nn.Module):
    # 初始化方法,设置输入通道、输出通道、时间嵌入维度等参数
    def __init__(
        self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None]
    # 初始化父类
        ):
            super().__init__()
            # 定义卷积核的大小
            kernel_sizes = [1, 3, 3, 1]
            # 计算隐藏通道数
            hidden_channel = max(in_channels, out_channels) // compression_ratio
            # 构建隐藏通道的元组列表
            hidden_channels = (
                [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
            )
            # 创建包含多个 Kandinsky3Block 的模块列表
            self.resnet_blocks = nn.ModuleList(
                [
                    Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
                    # 将隐藏通道、卷积核大小和上采样分辨率组合在一起
                    for (in_channel, out_channel), kernel_size, up_resolution in zip(
                        hidden_channels, kernel_sizes, up_resolutions
                    )
                ]
            )
            # 定义上采样的快捷连接
            self.shortcut_up_sample = (
                nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
                # 如果存在上采样分辨率,则使用反卷积;否则使用恒等映射
                if True in up_resolutions
                else nn.Identity()
            )
            # 定义通道数不同时的投影连接
            self.shortcut_projection = (
                nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
            )
            # 定义下采样的快捷连接
            self.shortcut_down_sample = (
                nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
                # 如果存在下采样分辨率,则使用卷积;否则使用恒等映射
                if False in up_resolutions
                else nn.Identity()
            )
    
        # 前向传播方法
        def forward(self, x, time_embed):
            # 初始化输出为输入
            out = x
            # 依次通过每个 ResNet 块
            for resnet_block in self.resnet_blocks:
                out = resnet_block(out, time_embed)
    
            # 上采样输入
            x = self.shortcut_up_sample(x)
            # 投影输入到输出通道
            x = self.shortcut_projection(x)
            # 下采样输入
            x = self.shortcut_down_sample(x)
            # 将输出与处理后的输入相加
            x = x + out
            # 返回最终输出
            return x
# 定义 Kandinsky3AttentionPooling 类,继承自 nn.Module
class Kandinsky3AttentionPooling(nn.Module):
    # 初始化方法,接受通道数、上下文维度和头维度
    def __init__(self, num_channels, context_dim, head_dim=64):
        # 调用父类构造函数
        super().__init__()
        # 创建注意力机制对象,指定输入和输出维度及其他参数
        self.attention = Attention(
            context_dim,
            context_dim,
            dim_head=head_dim,
            out_dim=num_channels,
            out_bias=False,
        )

    # 前向传播方法
    def forward(self, x, context, context_mask=None):
        # 将上下文掩码转换为与上下文相同的数据类型
        context_mask = context_mask.to(dtype=context.dtype)
        # 使用注意力机制计算上下文与其平均值的加权和
        context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
        # 返回输入与上下文的和
        return x + context.squeeze(1)


# 定义 Kandinsky3AttentionBlock 类,继承自 nn.Module
class Kandinsky3AttentionBlock(nn.Module):
    # 初始化方法,接受多种参数
    def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
        # 调用父类构造函数
        super().__init__()
        # 创建条件组归一化对象,用于输入规范化
        self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        # 创建注意力机制对象,指定输入和输出维度及其他参数
        self.attention = Attention(
            num_channels,
            context_dim or num_channels,
            dim_head=head_dim,
            out_dim=num_channels,
            out_bias=False,
        )

        # 计算隐藏通道数,作为扩展比和通道数的乘积
        hidden_channels = expansion_ratio * num_channels
        # 创建条件组归一化对象,用于输出规范化
        self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
        # 定义前馈网络,包含两个卷积层和激活函数
        self.feed_forward = nn.Sequential(
            nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
            nn.SiLU(),
            nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
        )

    # 前向传播方法
    def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
        # 获取输入的高度和宽度
        height, width = x.shape[-2:]
        # 对输入进行归一化处理
        out = self.in_norm(x, time_embed)
        # 将输出重塑为适合注意力机制的形状
        out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
        # 如果没有上下文,则使用当前的输出作为上下文
        context = context if context is not None else out
        # 如果存在上下文掩码,转换为与上下文相同的数据类型
        if context_mask is not None:
            context_mask = context_mask.to(dtype=context.dtype)

        # 使用注意力机制处理输出和上下文
        out = self.attention(out, context, context_mask)
        # 重塑输出为原始输入形状
        out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
        # 将处理后的输出与原输入相加
        x = x + out

        # 对相加后的结果进行输出归一化
        out = self.out_norm(x, time_embed)
        # 通过前馈网络处理归一化输出
        out = self.feed_forward(out)
        # 将处理后的输出与相加后的输入相加
        x = x + out
        # 返回最终输出
        return x

.\diffusers\models\unets\unet_motion_model.py

# 版权声明,表明该文件的所有权及相关使用条款
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache License, Version 2.0 (“许可证”) 授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,否则根据许可证分发的软件
# 是在“按原样”基础上分发的,不提供任何形式的保证或条件,
# 无论是明示还是暗示。
# 有关许可证所管辖的权限和限制,请参见许可证。
#
# 导入所需的库和模块
from dataclasses import dataclass  # 导入数据类装饰器
from typing import Any, Dict, Optional, Tuple, Union  # 导入类型提示相关的类型

import torch  # 导入 PyTorch 库
import torch.nn as nn  # 导入 PyTorch 的神经网络模块
import torch.nn.functional as F  # 导入 PyTorch 的功能性神经网络模块
import torch.utils.checkpoint  # 导入 PyTorch 的检查点功能

# 导入自定义的配置和加载工具
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, is_torch_version, logging  # 导入常用的工具函数
from ...utils.torch_utils import apply_freeu  # 导入应用 FreeU 的工具函数
from ..attention import BasicTransformerBlock  # 导入基础变换器模块
from ..attention_processor import (  # 导入注意力处理器相关的类
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps  # 导入时间步嵌入相关的类
from ..modeling_utils import ModelMixin  # 导入模型混合工具
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D  # 导入 ResNet 相关的模块
from ..transformers.dual_transformer_2d import DualTransformer2DModel  # 导入双重变换器模型
from ..transformers.transformer_2d import Transformer2DModel  # 导入 2D 变换器模型
from .unet_2d_blocks import UNetMidBlock2DCrossAttn  # 导入 U-Net 中间块
from .unet_2d_condition import UNet2DConditionModel  # 导入条件 U-Net 模型

logger = logging.get_logger(__name__)  # 获取当前模块的日志记录器,便于调试和日志输出

@dataclass
class UNetMotionOutput(BaseOutput):  # 定义 UNetMotionOutput 数据类,继承自 BaseOutput
    """
    [`UNetMotionOutput`] 的输出。

    参数:
        sample (`torch.Tensor` 的形状为 `(batch_size, num_channels, num_frames, height, width)`):
            基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
    """

    sample: torch.Tensor  # 定义 sample 属性,类型为 torch.Tensor


class AnimateDiffTransformer3D(nn.Module):  # 定义 AnimateDiffTransformer3D 类,继承自 nn.Module
    """
    一个用于视频类数据的变换器模型。
    # 参数说明部分,描述初始化函数中每个参数的用途
    Parameters:
        # 多头注意力机制中头的数量,默认为16
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        # 每个头中的通道数,默认为88
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        # 输入和输出的通道数,如果输入是**连续**,则需要指定
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        # Transformer块的层数,默认为1
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        # dropout概率,默认为0.0
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        # 使用的`encoder_hidden_states`维度数
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        # 配置`TransformerBlock`的注意力是否包含偏置参数
        attention_bias (`bool`, *optional*):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        # 潜在图像的宽度,如果输入是**离散**,则需要指定
        sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
            # 该值在训练期间固定,用于学习位置嵌入的数量
            This is fixed during training since it is used to learn a number of position embeddings.
        # 前馈中的激活函数,默认为"geglu"
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
            activation functions.
        # 配置`TransformerBlock`是否使用可学习的逐元素仿射参数进行归一化
        norm_elementwise_affine (`bool`, *optional*):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
        # 配置每个`TransformerBlock`是否包含两个自注意力层
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
        # 应用到序列输入的位置信息嵌入的类型
        positional_embeddings: (`str`, *optional*):
            The type of positional embeddings to apply to the sequence input before passing use.
        # 应用位置嵌入的最大序列长度
        num_positional_embeddings: (`int`, *optional*):
            The maximum length of the sequence over which to apply positional embeddings.
    """

    # 初始化方法定义
    def __init__(
        # 多头注意力机制中头的数量,默认为16
        self,
        num_attention_heads: int = 16,
        # 每个头中的通道数,默认为88
        attention_head_dim: int = 88,
        # 输入通道数,可选
        in_channels: Optional[int] = None,
        # 输出通道数,可选
        out_channels: Optional[int] = None,
        # Transformer块的层数,默认为1
        num_layers: int = 1,
        # dropout概率,默认为0.0
        dropout: float = 0.0,
        # 归一化分组数,默认为32
        norm_num_groups: int = 32,
        # 使用的`encoder_hidden_states`维度数,可选
        cross_attention_dim: Optional[int] = None,
        # 注意力是否包含偏置参数,默认为False
        attention_bias: bool = False,
        # 潜在图像的宽度,可选
        sample_size: Optional[int] = None,
        # 前馈中的激活函数,默认为"geglu"
        activation_fn: str = "geglu",
        # 归一化是否使用可学习的逐元素仿射参数,默认为True
        norm_elementwise_affine: bool = True,
        # 每个`TransformerBlock`是否包含两个自注意力层,默认为True
        double_self_attention: bool = True,
        # 位置信息嵌入的类型,可选
        positional_embeddings: Optional[str] = None,
        # 应用位置嵌入的最大序列长度,可选
        num_positional_embeddings: Optional[int] = None,
    ):
        # 调用父类的构造函数以初始化父类的属性
        super().__init__()
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads
        # 设置每个注意力头的维度
        self.attention_head_dim = attention_head_dim
        # 计算内部维度,等于注意力头数量与每个注意力头维度的乘积
        inner_dim = num_attention_heads * attention_head_dim

        # 设置输入通道数
        self.in_channels = in_channels

        # 定义归一化层,使用组归一化,允许可学习的偏移
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        # 定义输入线性变换层,将输入通道映射到内部维度
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 3. 定义变换器块
        self.transformer_blocks = nn.ModuleList(
            [
                # 创建指定数量的基本变换器块
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=num_positional_embeddings,
                )
                # 遍历创建 num_layers 个基本变换器块
                for _ in range(num_layers)
            ]
        )

        # 定义输出线性变换层,将内部维度映射回输入通道数
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(
        # 定义前向传播方法的输入参数
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        encoder_hidden_states: Optional[torch.LongTensor] = None,  # 编码器的隐藏状态,默认为 None
        timestep: Optional[torch.LongTensor] = None,  # 时间步,默认为 None
        class_labels: Optional[torch.LongTensor] = None,  # 类标签,默认为 None
        num_frames: int = 1,  # 帧数,默认值为 1
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 跨注意力参数,默认为 None
    # 该方法用于 [`AnimateDiffTransformer3D`] 的前向传播
    
        ) -> torch.Tensor:
            """
            方法参数说明:
                hidden_states (`torch.LongTensor`): 输入的隐状态,形状为 `(batch size, num latent pixels)` 或 `(batch size, channel, height, width)` 
                encoder_hidden_states ( `torch.LongTensor`, *可选*): 
                    交叉注意力层的条件嵌入。如果未提供,交叉注意力将默认使用自注意力。
                timestep ( `torch.LongTensor`, *可选*): 
                    用于指示去噪步骤的时间戳。
                class_labels ( `torch.LongTensor`, *可选*): 
                    用于指示类别标签的条件嵌入。
                num_frames (`int`, *可选*, 默认为 1): 
                    每个批次处理的帧数,用于重新形状隐状态。
                cross_attention_kwargs (`dict`, *可选*): 
                    可选的关键字字典,传递给 `AttentionProcessor`。
            返回值:
                torch.Tensor: 
                    输出张量。
            """
            # 1. 输入
            # 获取输入隐状态的形状信息
            batch_frames, channel, height, width = hidden_states.shape
            # 计算批次大小
            batch_size = batch_frames // num_frames
    
            # 将隐状态保留用于残差连接
            residual = hidden_states
    
            # 调整隐状态的形状以适应批次和帧数
            hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
            # 调整维度顺序以便后续处理
            hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
    
            # 对隐状态进行规范化
            hidden_states = self.norm(hidden_states)
            # 再次调整维度顺序并重塑为适当的形状
            hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
    
            # 输入层投影
            hidden_states = self.proj_in(hidden_states)
    
            # 2. 处理块
            # 遍历每个变换块以处理隐状态
            for block in self.transformer_blocks:
                hidden_states = block(
                    hidden_states,  # 当前的隐状态
                    encoder_hidden_states=encoder_hidden_states,  # 可选的编码器隐状态
                    timestep=timestep,  # 可选的时间戳
                    cross_attention_kwargs=cross_attention_kwargs,  # 可选的交叉注意力参数
                    class_labels=class_labels,  # 可选的类标签
                )
    
            # 3. 输出
            # 输出层投影
            hidden_states = self.proj_out(hidden_states)
            # 调整输出张量的形状
            hidden_states = (
                hidden_states[None, None, :]  # 添加维度
                .reshape(batch_size, height, width, num_frames, channel)  # 重塑为适当形状
                .permute(0, 3, 4, 1, 2)  # 调整维度顺序
                .contiguous()  # 确保内存连续性
            )
            # 最终调整输出的形状
            hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
    
            # 将残差添加到输出中以形成最终输出
            output = hidden_states + residual
            # 返回最终的输出张量
            return output
# 定义一个名为 DownBlockMotion 的类,继承自 nn.Module
class DownBlockMotion(nn.Module):
    # 初始化方法,定义多个参数,包括输入输出通道、dropout 率等
    def __init__(
        self,
        in_channels: int,  # 输入通道数量
        out_channels: int,  # 输出通道数量
        temb_channels: int,  # 时间嵌入通道数量
        dropout: float = 0.0,  # dropout 率,默认为 0
        num_layers: int = 1,  # 网络层数,默认为 1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 参数
        resnet_time_scale_shift: str = "default",  # ResNet 时间尺度偏移
        resnet_act_fn: str = "swish",  # ResNet 激活函数,默认为 swish
        resnet_groups: int = 32,  # ResNet 组数,默认为 32
        resnet_pre_norm: bool = True,  # ResNet 是否使用预归一化
        output_scale_factor: float = 1.0,  # 输出缩放因子
        add_downsample: bool = True,  # 是否添加下采样层
        downsample_padding: int = 1,  # 下采样时的填充
        temporal_num_attention_heads: Union[int, Tuple[int]] = 1,  # 时间注意力头数
        temporal_cross_attention_dim: Optional[int] = None,  # 时间交叉注意力维度
        temporal_max_seq_length: int = 32,  # 最大序列长度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每个块的变换器层数
        temporal_double_self_attention: bool = True,  # 是否双重自注意力
    ):
    # 前向传播方法,接收隐藏状态和时间嵌入等参数
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        num_frames: int = 1,  # 帧数,默认为 1
        *args,  # 接受任意位置参数
        **kwargs,  # 接受任意关键字参数
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 返回隐藏状态和输出状态的张量或元组
        # 检查位置参数或关键字参数中的 scale 是否被传递
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定义弃用信息,提示用户 scale 参数将被忽略
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 调用弃用函数,发出警告
            deprecate("scale", "1.0.0", deprecation_message)

        # 初始化输出状态为一个空元组
        output_states = ()

        # 将 ResNet 和运动模块进行配对
        blocks = zip(self.resnets, self.motion_modules)
        # 遍历每对 ResNet 和运动模块
        for resnet, motion_module in blocks:
            # 如果处于训练模式且启用了梯度检查点
            if self.training and self.gradient_checkpointing:
                # 定义一个自定义前向传播函数
                def create_custom_forward(module):
                    def custom_forward(*inputs):  # 自定义前向函数,接受任意输入
                        return module(*inputs)  # 返回模块的输出

                    return custom_forward  # 返回自定义前向函数

                # 如果 PyTorch 版本大于等于 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用检查点机制来节省内存
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 创建的自定义前向函数
                        hidden_states,  # 输入的隐藏状态
                        temb,  # 输入的时间嵌入
                        use_reentrant=False,  # 不使用重入
                    )
                else:
                    # 在较早版本中也使用检查点机制
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )

            else:
                # 如果不是训练模式,直接通过 ResNet 处理隐藏状态
                hidden_states = resnet(hidden_states, temb)

            # 使用运动模块处理当前的隐藏状态
            hidden_states = motion_module(hidden_states, num_frames=num_frames)

            # 将当前隐藏状态添加到输出状态中
            output_states = output_states + (hidden_states,)

        # 如果下采样器不为空
        if self.downsamplers is not None:
            # 遍历每个下采样器
            for downsampler in self.downsamplers:
                # 通过下采样器处理隐藏状态
                hidden_states = downsampler(hidden_states)

            # 将下采样后的隐藏状态添加到输出状态中
            output_states = output_states + (hidden_states,)

        # 返回最终的隐藏状态和输出状态
        return hidden_states, output_states
    # 初始化方法,用于设置网络的参数
        def __init__(
            # 输入通道数量
            self,
            in_channels: int,
            # 输出通道数量
            out_channels: int,
            # 时间嵌入通道数量
            temb_channels: int,
            # dropout 概率,默认为 0.0
            dropout: float = 0.0,
            # 网络层数,默认为 1
            num_layers: int = 1,
            # 每个块中的变换器层数,默认为 1
            transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # ResNet 中的 epsilon 值,默认为 1e-6
            resnet_eps: float = 1e-6,
            # ResNet 时间尺度偏移,默认为 "default"
            resnet_time_scale_shift: str = "default",
            # ResNet 激活函数,默认为 "swish"
            resnet_act_fn: str = "swish",
            # ResNet 中的组数,默认为 32
            resnet_groups: int = 32,
            # 是否在 ResNet 中使用预归一化,默认为 True
            resnet_pre_norm: bool = True,
            # 注意力头的数量,默认为 1
            num_attention_heads: int = 1,
            # 交叉注意力维度,默认为 1280
            cross_attention_dim: int = 1280,
            # 输出缩放因子,默认为 1.0
            output_scale_factor: float = 1.0,
            # 下采样填充,默认为 1
            downsample_padding: int = 1,
            # 是否添加下采样层,默认为 True
            add_downsample: bool = True,
            # 是否使用双交叉注意力,默认为 False
            dual_cross_attention: bool = False,
            # 是否使用线性投影,默认为 False
            use_linear_projection: bool = False,
            # 是否仅使用交叉注意力,默认为 False
            only_cross_attention: bool = False,
            # 是否提升注意力计算精度,默认为 False
            upcast_attention: bool = False,
            # 注意力类型,默认为 "default"
            attention_type: str = "default",
            # 时间交叉注意力维度,可选参数
            temporal_cross_attention_dim: Optional[int] = None,
            # 时间注意力头数量,默认为 8
            temporal_num_attention_heads: int = 8,
            # 时间序列的最大长度,默认为 32
            temporal_max_seq_length: int = 32,
            # 时间变换器块中的层数,默认为 1
            temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # 是否使用双重自注意力,默认为 True
            temporal_double_self_attention: bool = True,
        # 前向传播方法,定义如何通过模型传递数据
        def forward(
            # 隐藏状态张量,输入到模型中的主要数据
            self,
            hidden_states: torch.Tensor,
            # 可选的时间嵌入张量
            temb: Optional[torch.Tensor] = None,
            # 可选的编码器隐藏状态
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 可选的注意力掩码
            attention_mask: Optional[torch.Tensor] = None,
            # 每次处理的帧数,默认为 1
            num_frames: int = 1,
            # 可选的编码器注意力掩码
            encoder_attention_mask: Optional[torch.Tensor] = None,
            # 可选的交叉注意力参数
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可选的额外残差连接
            additional_residuals: Optional[torch.Tensor] = None,
    ):
        # 检查 cross_attention_kwargs 是否不为空
        if cross_attention_kwargs is not None:
            # 检查 scale 参数是否存在,若存在则发出警告
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # 初始化输出状态为空元组
        output_states = ()

        # 将自残差网络、注意力模块和运动模块组合成一个列表
        blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
        # 遍历组合后的模块及其索引
        for i, (resnet, attn, motion_module) in enumerate(blocks):
            # 如果处于训练状态且启用了梯度检查点
            if self.training and self.gradient_checkpointing:

                # 定义自定义前向传播函数
                def create_custom_forward(module, return_dict=None):
                    # 定义实际的前向传播逻辑
                    def custom_forward(*inputs):
                        # 根据 return_dict 的值选择返回方式
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                # 定义检查点参数字典,根据 PyTorch 版本设置
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用检查点机制计算隐藏状态
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
                # 通过注意力模块处理隐藏状态
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                # 在非训练模式下直接通过残差网络处理隐藏状态
                hidden_states = resnet(hidden_states, temb)

                # 通过注意力模块处理隐藏状态
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            # 通过运动模块处理隐藏状态
            hidden_states = motion_module(
                hidden_states,
                num_frames=num_frames,
            )

            # 如果是最后一对模块且有额外残差,则将其应用到隐藏状态
            if i == len(blocks) - 1 and additional_residuals is not None:
                hidden_states = hidden_states + additional_residuals

            # 将当前隐藏状态添加到输出状态中
            output_states = output_states + (hidden_states,)

        # 如果存在下采样模块,则依次应用它们
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            # 将下采样后的隐藏状态添加到输出状态中
            output_states = output_states + (hidden_states,)

        # 返回最终的隐藏状态和输出状态
        return hidden_states, output_states
# 定义一个继承自 nn.Module 的类,用于交叉注意力上采样块
class CrossAttnUpBlockMotion(nn.Module):
    # 初始化方法,设置各层的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        out_channels: int,  # 输出通道数
        prev_output_channel: int,  # 前一层输出的通道数
        temb_channels: int,  # 时间嵌入通道数
        resolution_idx: Optional[int] = None,  # 分辨率索引,默认为 None
        dropout: float = 0.0,  # dropout 概率
        num_layers: int = 1,  # 层数
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每个块的变换器层数
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 时间缩放偏移
        resnet_act_fn: str = "swish",  # ResNet 激活函数
        resnet_groups: int = 32,  # ResNet 组数
        resnet_pre_norm: bool = True,  # 是否在前面进行归一化
        num_attention_heads: int = 1,  # 注意力头的数量
        cross_attention_dim: int = 1280,  # 交叉注意力的维度
        output_scale_factor: float = 1.0,  # 输出缩放因子
        add_upsample: bool = True,  # 是否添加上采样
        dual_cross_attention: bool = False,  # 是否使用双重交叉注意力
        use_linear_projection: bool = False,  # 是否使用线性投影
        only_cross_attention: bool = False,  # 是否仅使用交叉注意力
        upcast_attention: bool = False,  # 是否上浮注意力
        attention_type: str = "default",  # 注意力类型
        temporal_cross_attention_dim: Optional[int] = None,  # 时间交叉注意力维度,默认为 None
        temporal_num_attention_heads: int = 8,  # 时间注意力头数量
        temporal_max_seq_length: int = 32,  # 时间序列的最大长度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 时间块的变换器层数
    # 定义前向传播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 之前隐藏状态的元组
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 交叉注意力的可选参数
        upsample_size: Optional[int] = None,  # 可选的上采样大小
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码
        encoder_attention_mask: Optional[torch.Tensor] = None,  # 可选的编码器注意力掩码
        num_frames: int = 1,  # 帧数,默认为 1
# 定义一个继承自 nn.Module 的类,用于上采样块
class UpBlockMotion(nn.Module):
    # 初始化方法,设置各层的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        prev_output_channel: int,  # 前一层输出的通道数
        out_channels: int,  # 输出通道数
        temb_channels: int,  # 时间嵌入通道数
        resolution_idx: Optional[int] = None,  # 分辨率索引,默认为 None
        dropout: float = 0.0,  # dropout 概率
        num_layers: int = 1,  # 层数
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 时间缩放偏移
        resnet_act_fn: str = "swish",  # ResNet 激活函数
        resnet_groups: int = 32,  # ResNet 组数
        resnet_pre_norm: bool = True,  # 是否在前面进行归一化
        output_scale_factor: float = 1.0,  # 输出缩放因子
        add_upsample: bool = True,  # 是否添加上采样
        temporal_cross_attention_dim: Optional[int] = None,  # 时间交叉注意力维度,默认为 None
        temporal_num_attention_heads: int = 8,  # 时间注意力头数量
        temporal_max_seq_length: int = 32,  # 时间序列的最大长度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 时间块的变换器层数
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 初始化空列表,用于存放 ResNet 模块
        resnets = []
        # 初始化空列表,用于存放运动模块
        motion_modules = []

        # 支持每个时间块的变换层数量为变量
        if isinstance(temporal_transformer_layers_per_block, int):
            # 将单个整数转换为与层数相同的元组
            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
        elif len(temporal_transformer_layers_per_block) != num_layers:
            # 检查传入的层数是否与预期一致
            raise ValueError(
                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
            )

        # 遍历每层,构建 ResNet 和运动模块
        for i in range(num_layers):
            # 设定跳过连接的通道数
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 设定当前层的输入通道数
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 添加 ResNetBlock2D 模块到 resnets 列表
            resnets.append(
                ResnetBlock2D(
                    # 输入通道数为当前层的输入和跳过连接的通道数之和
                    in_channels=resnet_in_channels + res_skip_channels,
                    # 输出通道数设定
                    out_channels=out_channels,
                    # 时间嵌入通道数
                    temb_channels=temb_channels,
                    # 小常数以避免除零
                    eps=resnet_eps,
                    # 组归一化的组数
                    groups=resnet_groups,
                    # Dropout 率
                    dropout=dropout,
                    # 时间嵌入的归一化方式
                    time_embedding_norm=resnet_time_scale_shift,
                    # 激活函数设定
                    non_linearity=resnet_act_fn,
                    # 输出尺度因子
                    output_scale_factor=output_scale_factor,
                    # 是否使用预归一化
                    pre_norm=resnet_pre_norm,
                )
            )

            # 添加 AnimateDiffTransformer3D 模块到 motion_modules 列表
            motion_modules.append(
                AnimateDiffTransformer3D(
                    # 注意力头的数量
                    num_attention_heads=temporal_num_attention_heads,
                    # 输入通道数
                    in_channels=out_channels,
                    # 当前层的变换层数量
                    num_layers=temporal_transformer_layers_per_block[i],
                    # 组归一化的组数
                    norm_num_groups=resnet_groups,
                    # 跨注意力维度
                    cross_attention_dim=temporal_cross_attention_dim,
                    # 是否使用注意力偏置
                    attention_bias=False,
                    # 激活函数类型
                    activation_fn="geglu",
                    # 位置信息嵌入类型
                    positional_embeddings="sinusoidal",
                    # 位置信息嵌入数量
                    num_positional_embeddings=temporal_max_seq_length,
                    # 每个注意力头的维度
                    attention_head_dim=out_channels // temporal_num_attention_heads,
                )
            )

        # 将 ResNet 模块列表转换为 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 将运动模块列表转换为 nn.ModuleList
        self.motion_modules = nn.ModuleList(motion_modules)

        # 如果需要上采样,则初始化上采样模块
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否则,设定为 None
            self.upsamplers = None

        # 设定梯度检查点标志为 False
        self.gradient_checkpointing = False
        # 保存分辨率索引
        self.resolution_idx = resolution_idx

    def forward(
        # 前向传播方法的参数定义
        self,
        hidden_states: torch.Tensor,
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        # 可选的时间嵌入
        temb: Optional[torch.Tensor] = None,
        # 上采样大小
        upsample_size=None,
        # 帧数,默认为 1
        num_frames: int = 1,
        # 额外的参数
        *args,
        **kwargs,
    # 函数返回类型为 torch.Tensor
    ) -> torch.Tensor:
        # 检查传入参数是否存在或 "scale" 参数是否为非 None
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 定义弃用提示信息
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 调用 deprecate 函数记录弃用警告
            deprecate("scale", "1.0.0", deprecation_message)
    
        # 检查 FreeU 是否启用,确保相关属性均不为 None
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )
    
        # 将自定义模块打包成元组,方便遍历
        blocks = zip(self.resnets, self.motion_modules)
    
        # 遍历每一对 resnet 和 motion_module
        for resnet, motion_module in blocks:
            # 从隐藏状态元组中弹出最后一个隐藏状态
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隐藏状态元组,移除最后一个元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    
            # 如果启用 FreeU,则仅对前两个阶段进行操作
            if is_freeu_enabled:
                # 应用 FreeU 函数获取新的隐藏状态
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )
    
            # 将当前隐藏状态和残差隐藏状态在维度 1 上拼接
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
            # 如果在训练模式并且启用了梯度检查点
            if self.training and self.gradient_checkpointing:
                # 定义创建自定义前向传播函数
                def create_custom_forward(module):
                    # 定义自定义前向传播的实现
                    def custom_forward(*inputs):
                        return module(*inputs)
    
                    return custom_forward
    
                # 如果 torch 版本大于等于 1.11.0
                if is_torch_version(">=", "1.11.0"):
                    # 使用检查点机制保存内存
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        use_reentrant=False,
                    )
                else:
                    # 否则使用旧版检查点机制
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet), hidden_states, temb
                    )
            else:
                # 否则直接通过 resnet 计算隐藏状态
                hidden_states = resnet(hidden_states, temb)
    
            # 通过 motion_module 处理隐藏状态,传入帧数
            hidden_states = motion_module(hidden_states, num_frames=num_frames)
    
        # 如果存在上采样器,则对每个上采样器进行处理
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                # 通过上采样器处理隐藏状态,传入上采样大小
                hidden_states = upsampler(hidden_states, upsample_size)
    
        # 返回最终处理后的隐藏状态
        return hidden_states
# 定义 UNetMidBlockCrossAttnMotion 类,继承自 nn.Module
class UNetMidBlockCrossAttnMotion(nn.Module):
    # 初始化方法,定义类的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        temb_channels: int,  # 时间嵌入通道数
        dropout: float = 0.0,  # Dropout 率
        num_layers: int = 1,  # 层数
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每个块的变换层数
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        resnet_time_scale_shift: str = "default",  # ResNet 时间尺度偏移
        resnet_act_fn: str = "swish",  # ResNet 激活函数类型
        resnet_groups: int = 32,  # ResNet 组数
        resnet_pre_norm: bool = True,  # 是否进行前置归一化
        num_attention_heads: int = 1,  # 注意力头数量
        output_scale_factor: float = 1.0,  # 输出缩放因子
        cross_attention_dim: int = 1280,  # 交叉注意力维度
        dual_cross_attention: bool = False,  # 是否使用双重交叉注意力
        use_linear_projection: bool = False,  # 是否使用线性投影
        upcast_attention: bool = False,  # 是否上升注意力精度
        attention_type: str = "default",  # 注意力类型
        temporal_num_attention_heads: int = 1,  # 时间注意力头数量
        temporal_cross_attention_dim: Optional[int] = None,  # 时间交叉注意力维度
        temporal_max_seq_length: int = 32,  # 时间序列最大长度
        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 时间块的变换层数
    # 前向传播方法,定义输入和输出
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隐藏状态的输入张量
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 可选的交叉注意力参数
        encoder_attention_mask: Optional[torch.Tensor] = None,  # 可选的编码器注意力掩码
        num_frames: int = 1,  # 帧数
    # 该函数的返回类型为 torch.Tensor
        ) -> torch.Tensor:
            # 检查交叉注意力参数是否不为 None
            if cross_attention_kwargs is not None:
                # 如果参数中包含 "scale",发出警告,说明该参数已弃用
                if cross_attention_kwargs.get("scale", None) is not None:
                    logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
    
            # 通过第一个残差网络处理隐藏状态
            hidden_states = self.resnets[0](hidden_states, temb)
    
            # 将注意力层、残差网络和运动模块打包在一起
            blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
            # 遍历每个注意力层、残差网络和运动模块
            for attn, resnet, motion_module in blocks:
                # 如果在训练模式下并且启用了梯度检查点
                if self.training and self.gradient_checkpointing:
    
                    # 创建自定义前向函数
                    def create_custom_forward(module, return_dict=None):
                        # 定义自定义前向函数,接受任意输入
                        def custom_forward(*inputs):
                            # 如果返回字典不为 None,使用返回字典调用模块
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                # 否则直接调用模块
                                return module(*inputs)
    
                        return custom_forward
    
                    # 根据 PyTorch 版本设置检查点参数
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    # 调用注意力模块并获取输出的第一个元素
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]
                    # 使用梯度检查点对运动模块进行前向传播
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    # 使用梯度检查点对残差网络进行前向传播
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                else:
                    # 在非训练模式下直接调用注意力模块
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]
                    # 调用运动模块,传入隐藏状态和帧数
                    hidden_states = motion_module(
                        hidden_states,
                        num_frames=num_frames,
                    )
                    # 调用残差网络处理隐藏状态
                    hidden_states = resnet(hidden_states, temb)
    
            # 返回处理后的隐藏状态
            return hidden_states
# 定义一个继承自 nn.Module 的运动模块类
class MotionModules(nn.Module):
    # 初始化方法,接收多个参数配置运动模块
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        layers_per_block: int = 2,  # 每个模块块的层数,默认是 2
        transformer_layers_per_block: Union[int, Tuple[int]] = 8,  # 每个块中的变换层数
        num_attention_heads: Union[int, Tuple[int]] = 8,  # 注意力头的数量
        attention_bias: bool = False,  # 是否使用注意力偏差
        cross_attention_dim: Optional[int] = None,  # 交叉注意力维度
        activation_fn: str = "geglu",  # 激活函数,默认使用 "geglu"
        norm_num_groups: int = 32,  # 归一化组的数量
        max_seq_length: int = 32,  # 最大序列长度
    ):
        # 调用父类初始化方法
        super().__init__()
        # 初始化运动模块列表
        self.motion_modules = nn.ModuleList([])

        # 如果变换层数是整数,重复为每个模块块填充
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = (transformer_layers_per_block,) * layers_per_block
        # 检查变换层数与块层数是否匹配
        elif len(transformer_layers_per_block) != layers_per_block:
            raise ValueError(
                f"The number of transformer layers per block must match the number of layers per block, "
                f"got {layers_per_block} and {len(transformer_layers_per_block)}"
            )

        # 遍历每个模块块
        for i in range(layers_per_block):
            # 向运动模块列表添加 AnimateDiffTransformer3D 实例
            self.motion_modules.append(
                AnimateDiffTransformer3D(
                    in_channels=in_channels,  # 输入通道数
                    num_layers=transformer_layers_per_block[i],  # 当前块的变换层数
                    norm_num_groups=norm_num_groups,  # 归一化组的数量
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力维度
                    activation_fn=activation_fn,  # 激活函数
                    attention_bias=attention_bias,  # 注意力偏差
                    num_attention_heads=num_attention_heads,  # 注意力头数量
                    attention_head_dim=in_channels // num_attention_heads,  # 每个注意力头的维度
                    positional_embeddings="sinusoidal",  # 使用正弦波的位置嵌入
                    num_positional_embeddings=max_seq_length,  # 位置嵌入的数量
                )
            )


# 定义一个运动适配器类,结合多个混合类
class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    @register_to_config
    # 初始化方法,配置多个运动适配器参数
    def __init__(
        self,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),  # 块输出通道
        motion_layers_per_block: Union[int, Tuple[int]] = 2,  # 每个运动块的层数
        motion_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]] = 1,  # 每个运动块中的变换层数
        motion_mid_block_layers_per_block: int = 1,  # 中间块的层数
        motion_transformer_layers_per_mid_block: Union[int, Tuple[int]] = 1,  # 中间块中的变换层数
        motion_num_attention_heads: Union[int, Tuple[int]] = 8,  # 中间块的注意力头数量
        motion_norm_num_groups: int = 32,  # 中间块的归一化组数量
        motion_max_seq_length: int = 32,  # 中间块的最大序列长度
        use_motion_mid_block: bool = True,  # 是否使用中间块
        conv_in_channels: Optional[int] = None,  # 输入通道数
    ):
        pass  # 前向传播方法,尚未实现


# 定义一个修改后的条件 2D UNet 模型
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
    r"""
    一个修改后的条件 2D UNet 模型,接收嘈杂样本、条件状态和时间步,返回形状输出。

    该模型继承自 [`ModelMixin`]。查看超类文档以获取所有模型的通用方法实现(如下载或保存)。
    """

    # 支持梯度检查点
    _supports_gradient_checkpointing = True

    @register_to_config
    # 初始化方法,用于创建类的实例
    def __init__(
        # 可选参数,样本大小,默认为 None
        self,
        sample_size: Optional[int] = None,
        # 输入通道数,默认为 4
        in_channels: int = 4,
        # 输出通道数,默认为 4
        out_channels: int = 4,
        # 下采样块的类型元组
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlockMotion",  # 第一个下采样块类型
            "CrossAttnDownBlockMotion",  # 第二个下采样块类型
            "CrossAttnDownBlockMotion",  # 第三个下采样块类型
            "DownBlockMotion",            # 第四个下采样块类型
        ),
        # 上采样块的类型元组
        up_block_types: Tuple[str, ...] = (
            "UpBlockMotion",              # 第一个上采样块类型
            "CrossAttnUpBlockMotion",    # 第二个上采样块类型
            "CrossAttnUpBlockMotion",    # 第三个上采样块类型
            "CrossAttnUpBlockMotion",    # 第四个上采样块类型
        ),
        # 块的输出通道数元组
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        # 每个块的层数,默认为 2
        layers_per_block: Union[int, Tuple[int]] = 2,
        # 下采样填充,默认为 1
        downsample_padding: int = 1,
        # 中间块的缩放因子,默认为 1
        mid_block_scale_factor: float = 1,
        # 激活函数类型,默认为 "silu"
        act_fn: str = "silu",
        # 归一化的组数,默认为 32
        norm_num_groups: int = 32,
        # 归一化的 epsilon 值,默认为 1e-5
        norm_eps: float = 1e-5,
        # 交叉注意力的维度,默认为 1280
        cross_attention_dim: int = 1280,
        # 每个块的变换器层数,默认为 1
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        # 可选参数,反向变换器层数,默认为 None
        reverse_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
        # 时间变换器的层数,默认为 1
        temporal_transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        # 可选参数,反向时间变换器层数,默认为 None
        reverse_temporal_transformer_layers_per_block: Optional[Union[int, Tuple[int], Tuple[Tuple]]] = None,
        # 每个中间块的变换器层数,默认为 None
        transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
        # 每个中间块的时间变换器层数,默认为 1
        temporal_transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = 1,
        # 是否使用线性投影,默认为 False
        use_linear_projection: bool = False,
        # 注意力头的数量,默认为 8
        num_attention_heads: Union[int, Tuple[int, ...]] = 8,
        # 动作最大序列长度,默认为 32
        motion_max_seq_length: int = 32,
        # 动作注意力头的数量,默认为 8
        motion_num_attention_heads: Union[int, Tuple[int, ...]] = 8,
        # 可选参数,反向动作注意力头的数量,默认为 None
        reverse_motion_num_attention_heads: Optional[Union[int, Tuple[int, ...], Tuple[Tuple[int, ...], ...]]] = None,
        # 是否使用动作中间块,默认为 True
        use_motion_mid_block: bool = True,
        # 中间块的层数,默认为 1
        mid_block_layers: int = 1,
        # 编码器隐藏层维度,默认为 None
        encoder_hid_dim: Optional[int] = None,
        # 编码器隐藏层类型,默认为 None
        encoder_hid_dim_type: Optional[str] = None,
        # 可选参数,附加嵌入类型,默认为 None
        addition_embed_type: Optional[str] = None,
        # 可选参数,附加时间嵌入维度,默认为 None
        addition_time_embed_dim: Optional[int] = None,
        # 可选参数,投影类别嵌入的输入维度,默认为 None
        projection_class_embeddings_input_dim: Optional[int] = None,
        # 可选参数,时间条件投影维度,默认为 None
        time_cond_proj_dim: Optional[int] = None,
    # 类方法,用于从 UNet2DConditionModel 创建对象
    @classmethod
    def from_unet2d(
        cls,
        # UNet2DConditionModel 对象
        unet: UNet2DConditionModel,
        # 可选的运动适配器,默认为 None
        motion_adapter: Optional[MotionAdapter] = None,
        # 是否加载权重,默认为 True
        load_weights: bool = True,
    # 冻结 UNet2DConditionModel 的权重,只保留运动模块可训练,便于微调
    def freeze_unet2d_params(self) -> None:
        """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
        unfrozen for fine tuning.
        """
        # 冻结所有参数
        for param in self.parameters():
            # 将参数的 requires_grad 属性设置为 False,禁止梯度更新
            param.requires_grad = False

        # 解冻运动模块
        for down_block in self.down_blocks:
            # 获取当前下采样块的运动模块
            motion_modules = down_block.motion_modules
            for param in motion_modules.parameters():
                # 将运动模块参数的 requires_grad 属性设置为 True,允许梯度更新
                param.requires_grad = True

        for up_block in self.up_blocks:
            # 获取当前上采样块的运动模块
            motion_modules = up_block.motion_modules
            for param in motion_modules.parameters():
                # 将运动模块参数的 requires_grad 属性设置为 True,允许梯度更新
                param.requires_grad = True

        # 检查中间块是否具有运动模块
        if hasattr(self.mid_block, "motion_modules"):
            # 获取中间块的运动模块
            motion_modules = self.mid_block.motion_modules
            for param in motion_modules.parameters():
                # 将运动模块参数的 requires_grad 属性设置为 True,允许梯度更新
                param.requires_grad = True

    # 加载运动模块的状态字典
    def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
        # 遍历运动适配器的下采样块
        for i, down_block in enumerate(motion_adapter.down_blocks):
            # 加载下采样块的运动模块状态字典
            self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
        # 遍历运动适配器的上采样块
        for i, up_block in enumerate(motion_adapter.up_blocks):
            # 加载上采样块的运动模块状态字典
            self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())

        # 支持没有中间块的旧运动模块
        if hasattr(self.mid_block, "motion_modules"):
            # 加载中间块的运动模块状态字典
            self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())

    # 保存运动模块的状态
    def save_motion_modules(
        self,
        save_directory: str,
        is_main_process: bool = True,
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        # 获取当前模型的状态字典
        state_dict = self.state_dict()

        # 提取所有运动模块的状态
        motion_state_dict = {}
        for k, v in state_dict.items():
            # 筛选出包含 "motion_modules" 的键值对
            if "motion_modules" in k:
                motion_state_dict[k] = v

        # 创建运动适配器实例
        adapter = MotionAdapter(
            block_out_channels=self.config["block_out_channels"],
            motion_layers_per_block=self.config["layers_per_block"],
            motion_norm_num_groups=self.config["norm_num_groups"],
            motion_num_attention_heads=self.config["motion_num_attention_heads"],
            motion_max_seq_length=self.config["motion_max_seq_length"],
            use_motion_mid_block=self.config["use_motion_mid_block"],
        )
        # 加载运动状态字典
        adapter.load_state_dict(motion_state_dict)
        # 保存适配器的预训练状态
        adapter.save_pretrained(
            save_directory=save_directory,
            is_main_process=is_main_process,
            safe_serialization=safe_serialization,
            variant=variant,
            push_to_hub=push_to_hub,
            **kwargs,
        )

    @property
    # 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 复制的属性
    # 定义一个方法,返回模型中所有注意力处理器的字典
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 类型的注意力处理器: 包含模型中所有注意力处理器的字典,
            按照其权重名称索引。
        """
        # 初始化一个空字典,用于存储注意力处理器
        processors = {}

        # 定义一个递归函数,用于添加注意力处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否具有 'get_processor' 方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为处理器名称
                processors[f"{name}.processor"] = module.get_processor()

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,继续添加子模块的处理器
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回处理器字典
            return processors

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数添加所有处理器
            fn_recursive_add_processors(name, module, processors)

        # 返回最终的处理器字典
        return processors

    # 从 diffusers.models.unets.unet_2d_condition 中复制的方法,用于设置注意力处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的注意力处理器。

        参数:
            processor (`dict` of `AttentionProcessor` 或仅 `AttentionProcessor`):
                实例化的处理器类或处理器类的字典,将被设置为**所有** `Attention` 层的处理器。

                如果 `processor` 是字典,键需要定义相应的交叉注意力处理器的路径。
                当设置可训练的注意力处理器时,强烈推荐这样做。

        """
        # 获取当前注意力处理器字典的键数量
        count = len(self.attn_processors.keys())

        # 检查传入的处理器字典长度是否与注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            # 如果不匹配,抛出错误
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与"
                f" 注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
            )

        # 定义一个递归函数,用于设置注意力处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否具有 'set_processor' 方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置处理器
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,继续设置子模块的处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前对象的所有子模块
        for name, module in self.named_children():
            # 调用递归函数设置所有处理器
            fn_recursive_attn_processor(name, module, processor)
    # 定义一个方法以启用前向分块处理
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        设置注意力处理器以使用[前馈分块](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。

        参数:
            chunk_size (`int`, *可选*):
                前馈层的块大小。如果未指定,将单独对维度为`dim`的每个张量运行前馈层。
            dim (`int`, *可选*, 默认为`0`):
                前馈计算应分块的维度。选择dim=0(批次)或dim=1(序列长度)。
        """
        # 检查dim参数是否在有效范围内(0或1)
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # 默认块大小为1
        chunk_size = chunk_size or 1

        # 定义递归前馈函数以设置模块的分块前馈处理
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模块有set_chunk_feed_forward属性,设置块大小和维度
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 遍历模块的子模块
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 遍历当前对象的子模块,应用递归前馈函数
        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # 定义一个方法以禁用前向分块处理
    def disable_forward_chunking(self) -> None:
        # 定义递归前馈函数以设置模块的分块前馈处理为None
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模块有set_chunk_feed_forward属性,设置块大小和维度为None
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            # 遍历模块的子模块
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        # 遍历当前对象的子模块,应用递归前馈函数
        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    # 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor复制的方法
    def set_default_attn_processor(self) -> None:
        """
        禁用自定义注意力处理器并设置默认的注意力实现。
        """
        # 如果所有注意力处理器都是ADDED_KV_ATTENTION_PROCESSORS类型
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 设置处理器为AttnAddedKVProcessor
            processor = AttnAddedKVProcessor()
        # 如果所有注意力处理器都是CROSS_ATTENTION_PROCESSORS类型
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 设置处理器为AttnProcessor
            processor = AttnProcessor()
        else:
            # 抛出错误,表示不能在不匹配的注意力处理器类型下调用该方法
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 设置当前对象的注意力处理器
        self.set_attn_processor(processor)

    # 定义一个方法以设置模块的梯度检查点
    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        # 检查模块是否为特定类型,如果是则设置其梯度检查点属性
        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
            module.gradient_checkpointing = value

    # 从diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu复制的方法
    # 启用 FreeU 机制,接受四个浮点型缩放因子作为参数
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
        # 文档字符串,描述该方法的作用及参数含义
        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
        are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate the "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
        """
        # 遍历上采样块,并为每个块设置缩放因子
        for i, upsample_block in enumerate(self.up_blocks):
            # 为上采样块设置阶段1的缩放因子
            setattr(upsample_block, "s1", s1)
            # 为上采样块设置阶段2的缩放因子
            setattr(upsample_block, "s2", s2)
            # 为上采样块设置阶段1的主干特征缩放因子
            setattr(upsample_block, "b1", b1)
            # 为上采样块设置阶段2的主干特征缩放因子
            setattr(upsample_block, "b2", b2)

    # 禁用 FreeU 机制
    def disable_freeu(self) -> None:
        # 文档字符串,描述该方法的作用
        """Disables the FreeU mechanism."""
        # 定义 FreeU 相关的键名集合
        freeu_keys = {"s1", "s2", "b1", "b2"}
        # 遍历上采样块
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍历 FreeU 键名
            for k in freeu_keys:
                # 检查上采样块是否具有该属性或该属性是否不为 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    # 将上采样块的该属性设置为 None
                    setattr(upsample_block, k, None)

    # 启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        # 文档字符串,描述该方法的作用
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        # 初始化原始注意力处理器为 None
        self.original_attn_processors = None

        # 遍历注意力处理器
        for _, attn_processor in self.attn_processors.items():
            # 检查注意力处理器类名中是否包含 "Added"
            if "Added" in str(attn_processor.__class__.__name__):
                # 抛出异常,说明不支持该操作
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # 保存原始的注意力处理器
        self.original_attn_processors = self.attn_processors

        # 遍历所有模块
        for module in self.modules():
            # 检查模块是否为 Attention 类型
            if isinstance(module, Attention):
                # 融合投影
                module.fuse_projections(fuse=True)

        # 设置融合后的注意力处理器
        self.set_attn_processor(FusedAttnProcessor2_0())

    # 解融合 QKV 投影的方法(省略具体实现)
    # 定义一个禁用融合 QKV 投影的方法
    def unfuse_qkv_projections(self):
        """如果启用了,禁用融合 QKV 投影。
    
        <Tip warning={true}>
        
        此 API 是 🧪 实验性。
        
        </Tip>
    
        """
        # 检查原始注意力处理器是否不为 None
        if self.original_attn_processors is not None:
            # 设置当前注意力处理器为原始的注意力处理器
            self.set_attn_processor(self.original_attn_processors)
    
    # 定义前向传播方法,接收多个参数
    def forward(
        self,
        # 输入样本张量
        sample: torch.Tensor,
        # 时间步,可以是张量、浮点数或整数
        timestep: Union[torch.Tensor, float, int],
        # 编码器隐藏状态张量
        encoder_hidden_states: torch.Tensor,
        # 可选的时间步条件张量
        timestep_cond: Optional[torch.Tensor] = None,
        # 可选的注意力掩码张量
        attention_mask: Optional[torch.Tensor] = None,
        # 可选的交叉注意力参数字典
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # 可选的附加条件参数字典
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        # 可选的下块附加残差元组
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        # 可选的中间块附加残差张量
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        # 是否返回字典格式的结果,默认为 True
        return_dict: bool = True,
posted @ 2024-10-22 12:39  绝不原创的飞龙  阅读(66)  评论(0编辑  收藏  举报