diffusers-源码解析-十四-

diffusers 源码解析(十四)

.\diffusers\models\unets\unet_2d_blocks_flax.py

# 版权声明,说明该文件的版权信息及相关许可协议
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 许可信息,使用 Apache License 2.0 许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件只能在符合许可证的情况下使用
# you may not use this file except in compliance with the License.
# 提供许可证获取链接
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 免责声明,表明不提供任何形式的保证或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 许可证的相关权限及限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 导入 flax.linen 模块,用于构建神经网络
import flax.linen as nn
# 导入 jax.numpy,用于数值计算
import jax.numpy as jnp

# 从其他模块导入特定的类,用于构建模型的各个组件
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D


# 定义 FlaxCrossAttnDownBlock2D 类,表示一个 2D 跨注意力下采样模块
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    跨注意力 2D 下采样块 - 原始架构来自 Unet transformers:
    https://arxiv.org/abs/2103.06104

    参数说明:
        in_channels (:obj:`int`):
            输入通道数
        out_channels (:obj:`int`):
            输出通道数
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout 率
        num_layers (:obj:`int`, *optional*, defaults to 1):
            注意力块层数
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            每个空间变换块的注意力头数
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            是否在每个最终输出之前添加下采样层
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            启用内存高效的注意力 https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            是否将头维度拆分为一个新的轴进行自注意力计算。在大多数情况下,
            启用此标志应加快 Stable Diffusion 2.x 和 Stable Diffusion XL 的计算速度。
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            参数的数据类型
    """

    # 定义输入通道数
    in_channels: int
    # 定义输出通道数
    out_channels: int
    # 定义 Dropout 率,默认为 0.0
    dropout: float = 0.0
    # 定义注意力块的层数,默认为 1
    num_layers: int = 1
    # 定义注意力头数,默认为 1
    num_attention_heads: int = 1
    # 定义是否添加下采样层,默认为 True
    add_downsample: bool = True
    # 定义是否使用线性投影,默认为 False
    use_linear_projection: bool = False
    # 定义是否仅使用跨注意力,默认为 False
    only_cross_attention: bool = False
    # 定义是否启用内存高效注意力,默认为 False
    use_memory_efficient_attention: bool = False
    # 定义是否拆分头维度,默认为 False
    split_head_dim: bool = False
    # 定义参数的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 定义每个块的变换器层数,默认为 1
    transformer_layers_per_block: int = 1
    # 设置模型的各个组成部分,包括残差块和注意力块
        def setup(self):
            # 初始化残差块列表
            resnets = []
            # 初始化注意力块列表
            attentions = []
    
            # 遍历每一层,构建残差块和注意力块
            for i in range(self.num_layers):
                # 第一层的输入通道为 in_channels,其他层为 out_channels
                in_channels = self.in_channels if i == 0 else self.out_channels
    
                # 创建一个 FlaxResnetBlock2D 实例
                res_block = FlaxResnetBlock2D(
                    in_channels=in_channels,  # 输入通道
                    out_channels=self.out_channels,  # 输出通道
                    dropout_prob=self.dropout,  # 丢弃率
                    dtype=self.dtype,  # 数据类型
                )
                # 将残差块添加到列表中
                resnets.append(res_block)
    
                # 创建一个 FlaxTransformer2DModel 实例
                attn_block = FlaxTransformer2DModel(
                    in_channels=self.out_channels,  # 输入通道
                    n_heads=self.num_attention_heads,  # 注意力头数
                    d_head=self.out_channels // self.num_attention_heads,  # 每个头的维度
                    depth=self.transformer_layers_per_block,  # 每个块的层数
                    use_linear_projection=self.use_linear_projection,  # 是否使用线性投影
                    only_cross_attention=self.only_cross_attention,  # 是否只使用交叉注意力
                    use_memory_efficient_attention=self.use_memory_efficient_attention,  # 是否使用内存高效的注意力
                    split_head_dim=self.split_head_dim,  # 是否拆分头的维度
                    dtype=self.dtype,  # 数据类型
                )
                # 将注意力块添加到列表中
                attentions.append(attn_block)
    
            # 将残差块列表赋值给实例变量
            self.resnets = resnets
            # 将注意力块列表赋值给实例变量
            self.attentions = attentions
    
            # 如果需要下采样,则创建下采样层
            if self.add_downsample:
                self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
    
        # 定义前向调用方法,处理隐藏状态和编码器隐藏状态
        def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
            # 初始化输出状态元组
            output_states = ()
    
            # 遍历残差块和注意力块并进行处理
            for resnet, attn in zip(self.resnets, self.attentions):
                # 通过残差块处理隐藏状态
                hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
                # 通过注意力块处理隐藏状态
                hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
                # 将当前隐藏状态添加到输出状态元组中
                output_states += (hidden_states,)
    
            # 如果需要下采样,则进行下采样
            if self.add_downsample:
                hidden_states = self.downsamplers_0(hidden_states)
                # 将下采样后的隐藏状态添加到输出状态元组中
                output_states += (hidden_states,)
    
            # 返回最终的隐藏状态和输出状态元组
            return hidden_states, output_states
# 定义 Flax 2D 降维块类,继承自 nn.Module
class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D downsizing block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """
    
    # 声明输入输出通道和其他参数
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 设置方法,用于初始化模型的层
    def setup(self):
        # 创建空列表以存储残差块
        resnets = []

        # 根据层数创建残差块
        for i in range(self.num_layers):
            # 第一个块的输入通道为 in_channels,其余为 out_channels
            in_channels = self.in_channels if i == 0 else self.out_channels

            # 创建残差块实例
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            # 将残差块添加到列表中
            resnets.append(res_block)
        # 将列表赋值给实例属性
        self.resnets = resnets

        # 如果需要,添加降采样层
        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    # 调用方法,执行前向传播
    def __call__(self, hidden_states, temb, deterministic=True):
        # 创建空元组以存储输出状态
        output_states = ()

        # 遍历所有残差块进行前向传播
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            # 将当前隐藏状态添加到输出状态中
            output_states += (hidden_states,)

        # 如果需要,应用降采样层
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            # 将降采样后的隐藏状态添加到输出状态中
            output_states += (hidden_states,)

        # 返回最终的隐藏状态和输出状态
        return hidden_states, output_states


# 定义 Flax 交叉注意力 2D 上采样块类,继承自 nn.Module
class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
    https://arxiv.org/abs/2103.06104
    # 定义参数的文档字符串,描述各个参数的用途和类型
        Parameters:
            in_channels (:obj:`int`):  # 输入通道数
                Input channels
            out_channels (:obj:`int`):  # 输出通道数
                Output channels
            dropout (:obj:`float`, *optional*, defaults to 0.0):  # Dropout 率,默认值为 0.0
                Dropout rate
            num_layers (:obj:`int`, *optional*, defaults to 1):  # 注意力块的层数,默认值为 1
                Number of attention blocks layers
            num_attention_heads (:obj:`int`, *optional*, defaults to 1):  # 每个空间变换块的注意力头数量,默认值为 1
                Number of attention heads of each spatial transformer block
            add_upsample (:obj:`bool`, *optional*, defaults to `True`):  # 是否在每个最终输出前添加上采样层,默认值为 True
                Whether to add upsampling layer before each final output
            use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 启用内存高效注意力,默认值为 False
                enable memory efficient attention https://arxiv.org/abs/2112.05682
            split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否将头维度拆分为新轴以进行自注意力计算,默认值为 False
                Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
                enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
            dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 数据类型参数,默认值为 jnp.float32
                Parameters `dtype`
        """
    
        in_channels: int  # 输入通道数的声明
        out_channels: int  # 输出通道数的声明
        prev_output_channel: int  # 前一个输出通道数的声明
        dropout: float = 0.0  # Dropout 率的声明,默认值为 0.0
        num_layers: int = 1  # 注意力层数的声明,默认值为 1
        num_attention_heads: int = 1  # 注意力头数量的声明,默认值为 1
        add_upsample: bool = True  # 是否添加上采样的声明,默认值为 True
        use_linear_projection: bool = False  # 是否使用线性投影的声明,默认值为 False
        only_cross_attention: bool = False  # 是否仅使用交叉注意力的声明,默认值为 False
        use_memory_efficient_attention: bool = False  # 是否启用内存高效注意力的声明,默认值为 False
        split_head_dim: bool = False  # 是否拆分头维度的声明,默认值为 False
        dtype: jnp.dtype = jnp.float32  # 数据类型的声明,默认值为 jnp.float32
        transformer_layers_per_block: int = 1  # 每个块的变换层数的声明,默认值为 1
    # 设置方法,初始化网络结构
    def setup(self):
        # 初始化空列表以存储 ResNet 块
        resnets = []
        # 初始化空列表以存储注意力块
        attentions = []
    
        # 遍历每一层以创建相应的 ResNet 和注意力块
        for i in range(self.num_layers):
            # 设置跳跃连接的通道数,最后一层使用输入通道,否则使用输出通道
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # 设置当前 ResNet 块的输入通道,第一层使用前一层的输出通道
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
    
            # 创建 FlaxResnetBlock2D 实例
            res_block = FlaxResnetBlock2D(
                # 设置输入通道为当前 ResNet 块输入通道加跳跃连接通道
                in_channels=resnet_in_channels + res_skip_channels,
                # 设置输出通道为指定的输出通道
                out_channels=self.out_channels,
                # 设置 dropout 概率
                dropout_prob=self.dropout,
                # 设置数据类型
                dtype=self.dtype,
            )
            # 将创建的 ResNet 块添加到列表中
            resnets.append(res_block)
    
            # 创建 FlaxTransformer2DModel 实例
            attn_block = FlaxTransformer2DModel(
                # 设置输入通道为输出通道
                in_channels=self.out_channels,
                # 设置注意力头数
                n_heads=self.num_attention_heads,
                # 设置每个注意力头的维度
                d_head=self.out_channels // self.num_attention_heads,
                # 设置 transformer 块的深度
                depth=self.transformer_layers_per_block,
                # 设置是否使用线性投影
                use_linear_projection=self.use_linear_projection,
                # 设置是否仅使用交叉注意力
                only_cross_attention=self.only_cross_attention,
                # 设置是否使用内存高效的注意力机制
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                # 设置是否分割头部维度
                split_head_dim=self.split_head_dim,
                # 设置数据类型
                dtype=self.dtype,
            )
            # 将创建的注意力块添加到列表中
            attentions.append(attn_block)
    
        # 将 ResNet 列表保存到实例属性
        self.resnets = resnets
        # 将注意力列表保存到实例属性
        self.attentions = attentions
    
        # 如果需要添加上采样层,则创建相应的 FlaxUpsample2D 实例
        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
    
    # 定义调用方法,接受隐藏状态和其他参数
    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        # 遍历 ResNet 和注意力块
        for resnet, attn in zip(self.resnets, self.attentions):
            # 从跳跃连接的隐藏状态元组中取出最后一个状态
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新跳跃连接的隐藏状态元组,去掉最后一个状态
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            # 将隐藏状态与跳跃连接的隐藏状态在最后一个轴上拼接
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
    
            # 使用当前的 ResNet 块处理隐藏状态
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            # 使用当前的注意力块处理隐藏状态
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
    
        # 如果需要添加上采样,则使用上采样层处理隐藏状态
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)
    
        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个 2D 上采样块类,继承自 nn.Module
class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D upsampling block

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        prev_output_channel (:obj:`int`):
            Output channels from the previous block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of attention blocks layers
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsampling layer before each final output
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    # 定义输入输出通道和其他参数
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 设置方法用于初始化块的结构
    def setup(self):
        resnets = []  # 创建一个空列表用于存储 ResNet 块

        # 遍历每一层,创建 ResNet 块
        for i in range(self.num_layers):
            # 计算跳跃连接通道数
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            # 设置输入通道数
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            # 创建一个新的 FlaxResnetBlock2D 实例
            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)  # 将块添加到列表中

        self.resnets = resnets  # 将列表赋值给实例变量

        # 如果需要上采样,初始化上采样层
        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    # 定义前向传播方法
    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        # 遍历每个 ResNet 块进行前向传播
        for resnet in self.resnets:
            # 从元组中弹出最后的残差隐藏状态
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # 更新元组,去掉最后一项
            # 连接当前隐藏状态与残差隐藏状态
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            # 通过 ResNet 块处理隐藏状态
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        # 如果需要上采样,调用上采样层
        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states  # 返回处理后的隐藏状态


# 定义一个 2D 中级交叉注意力块类,继承自 nn.Module
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
    # 定义参数的文档字符串
    Parameters:
        in_channels (:obj:`int`):  # 输入通道数
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):  # Dropout比率,默认为0.0
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):  # 注意力层的数量,默认为1
            Number of attention blocks layers
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):  # 每个空间变换块的注意力头数量,默认为1
            Number of attention heads of each spatial transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):  # 是否启用内存高效的注意力机制,默认为False
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):  # 是否将头维度分割为新的轴以加速计算,默认为False
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 数据类型参数,默认为jnp.float32
            Parameters `dtype`
    """

    in_channels: int  # 输入通道数的类型
    dropout: float = 0.0  # Dropout比率的默认值
    num_layers: int = 1  # 注意力层数量的默认值
    num_attention_heads: int = 1  # 注意力头数量的默认值
    use_linear_projection: bool = False  # 是否使用线性投影的默认值
    use_memory_efficient_attention: bool = False  # 是否使用内存高效注意力的默认值
    split_head_dim: bool = False  # 是否分割头维度的默认值
    dtype: jnp.dtype = jnp.float32  # 数据类型的默认值
    transformer_layers_per_block: int = 1  # 每个块中的变换层数量的默认值

    def setup(self):  # 设置方法,用于初始化
        # 至少会有一个ResNet块
        resnets = [  # 创建ResNet块列表
            FlaxResnetBlock2D(  # 创建一个ResNet块
                in_channels=self.in_channels,  # 输入通道数
                out_channels=self.in_channels,  # 输出通道数
                dropout_prob=self.dropout,  # Dropout概率
                dtype=self.dtype,  # 数据类型
            )
        ]

        attentions = []  # 初始化注意力块列表

        for _ in range(self.num_layers):  # 遍历指定的注意力层数
            attn_block = FlaxTransformer2DModel(  # 创建一个Transformer块
                in_channels=self.in_channels,  # 输入通道数
                n_heads=self.num_attention_heads,  # 注意力头数量
                d_head=self.in_channels // self.num_attention_heads,  # 每个头的维度
                depth=self.transformer_layers_per_block,  # 变换层深度
                use_linear_projection=self.use_linear_projection,  # 是否使用线性投影
                use_memory_efficient_attention=self.use_memory_efficient_attention,  # 是否使用内存高效注意力
                split_head_dim=self.split_head_dim,  # 是否分割头维度
                dtype=self.dtype,  # 数据类型
            )
            attentions.append(attn_block)  # 将注意力块添加到列表中

            res_block = FlaxResnetBlock2D(  # 创建一个ResNet块
                in_channels=self.in_channels,  # 输入通道数
                out_channels=self.in_channels,  # 输出通道数
                dropout_prob=self.dropout,  # Dropout概率
                dtype=self.dtype,  # 数据类型
            )
            resnets.append(res_block)  # 将ResNet块添加到列表中

        self.resnets = resnets  # 将ResNet块列表赋值给实例属性
        self.attentions = attentions  # 将注意力块列表赋值给实例属性

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):  # 调用方法
        hidden_states = self.resnets[0](hidden_states, temb)  # 通过第一个ResNet块处理隐藏状态
        for attn, resnet in zip(self.attentions, self.resnets[1:]):  # 遍历每个注意力块和后续ResNet块
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)  # 处理隐藏状态
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)  # 再次处理隐藏状态

        return hidden_states  # 返回处理后的隐藏状态

.\diffusers\models\unets\unet_2d_condition.py

# 版权声明,标明版权信息和使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache License 2.0 版本进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 你不得在未遵守许可的情况下使用此文件
# you may not use this file except in compliance with the License.
# 你可以在以下网址获得许可证的副本
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,软件以“按原样”方式分发,不提供任何形式的担保或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言所适用的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.

# 从 dataclasses 模块导入 dataclass 装饰器,用于简化类的定义
from dataclasses import dataclass
# 导入所需的类型注释
from typing import Any, Dict, List, Optional, Tuple, Union

# 导入 PyTorch 库和相关模块
import torch
import torch.nn as nn
import torch.utils.checkpoint

# 从配置和加载器模块中导入所需的类和函数
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入与注意力机制相关的处理器
    CROSS_ATTENTION_PROCESSORS,
    Attention,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import (
    GaussianFourierProjection,  # 导入多种嵌入方法
    GLIGENTextBoundingboxProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from ..modeling_utils import ModelMixin  # 导入模型混合类
from .unet_2d_blocks import (
    get_down_block,  # 导入下采样块的构造函数
    get_mid_block,   # 导入中间块的构造函数
    get_up_block,    # 导入上采样块的构造函数
)

# 创建一个日志记录器,用于记录模型相关信息
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义 UNet2DConditionOutput 数据类,用于存储 UNet2DConditionModel 的输出
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """
    UNet2DConditionModel 的输出。

    参数:
        sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`):
            基于 `encoder_hidden_states` 输入的隐藏状态输出,模型最后一层的输出。
    """

    sample: torch.Tensor = None  # 定义一个样本属性,默认为 None

# 定义 UNet2DConditionModel 类,表示一个条件 2D UNet 模型
class UNet2DConditionModel(
    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
    r"""
    一个条件 2D UNet 模型,接受一个噪声样本、条件状态和时间步,并返回样本形状的输出。

    该模型继承自 [`ModelMixin`]。查看超类文档以获取其为所有模型实现的通用方法
    (例如下载或保存)。
    """

    _supports_gradient_checkpointing = True  # 表示该模型支持梯度检查点
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]  # 不进行拆分的模块列表

    @register_to_config  # 将该方法注册到配置中
    # 初始化方法,设置类的基本属性
        def __init__(
            # 样本大小,默认为 None
            self,
            sample_size: Optional[int] = None,
            # 输入通道数,默认为 4
            in_channels: int = 4,
            # 输出通道数,默认为 4
            out_channels: int = 4,
            # 是否将输入样本中心化,默认为 False
            center_input_sample: bool = False,
            # 是否将正弦函数翻转为余弦函数,默认为 True
            flip_sin_to_cos: bool = True,
            # 频率偏移量,默认为 0
            freq_shift: int = 0,
            # 向下采样的块类型,包含多种块类型
            down_block_types: Tuple[str] = (
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D",
            ),
            # 中间块的类型,默认为 UNet 的中间块类型
            mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
            # 向上采样的块类型,包含多种块类型
            up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
            # 是否仅使用交叉注意力,默认为 False
            only_cross_attention: Union[bool, Tuple[bool]] = False,
            # 每个块的输出通道数
            block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
            # 每个块的层数,默认为 2
            layers_per_block: Union[int, Tuple[int]] = 2,
            # 下采样时的填充大小,默认为 1
            downsample_padding: int = 1,
            # 中间块的缩放因子,默认为 1
            mid_block_scale_factor: float = 1,
            # dropout 概率,默认为 0.0
            dropout: float = 0.0,
            # 激活函数类型,默认为 "silu"
            act_fn: str = "silu",
            # 归一化的组数,默认为 32
            norm_num_groups: Optional[int] = 32,
            # 归一化的 epsilon 值,默认为 1e-5
            norm_eps: float = 1e-5,
            # 交叉注意力的维度,默认为 1280
            cross_attention_dim: Union[int, Tuple[int]] = 1280,
            # 每个块的变换层数,默认为 1
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
            # 反向变换层的块数,默认为 None
            reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
            # 编码器隐藏层的维度,默认为 None
            encoder_hid_dim: Optional[int] = None,
            # 编码器隐藏层类型,默认为 None
            encoder_hid_dim_type: Optional[str] = None,
            # 注意力头的维度,默认为 8
            attention_head_dim: Union[int, Tuple[int]] = 8,
            # 注意力头的数量,默认为 None
            num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
            # 是否使用双交叉注意力,默认为 False
            dual_cross_attention: bool = False,
            # 是否使用线性投影,默认为 False
            use_linear_projection: bool = False,
            # 类嵌入类型,默认为 None
            class_embed_type: Optional[str] = None,
            # 附加嵌入类型,默认为 None
            addition_embed_type: Optional[str] = None,
            # 附加时间嵌入维度,默认为 None
            addition_time_embed_dim: Optional[int] = None,
            # 类嵌入数量,默认为 None
            num_class_embeds: Optional[int] = None,
            # 是否上溯注意力,默认为 False
            upcast_attention: bool = False,
            # ResNet 时间缩放偏移类型,默认为 "default"
            resnet_time_scale_shift: str = "default",
            # ResNet 是否跳过时间激活,默认为 False
            resnet_skip_time_act: bool = False,
            # ResNet 输出缩放因子,默认为 1.0
            resnet_out_scale_factor: float = 1.0,
            # 时间嵌入类型,默认为 "positional"
            time_embedding_type: str = "positional",
            # 时间嵌入维度,默认为 None
            time_embedding_dim: Optional[int] = None,
            # 时间嵌入激活函数,默认为 None
            time_embedding_act_fn: Optional[str] = None,
            # 时间步后激活函数,默认为 None
            timestep_post_act: Optional[str] = None,
            # 时间条件投影维度,默认为 None
            time_cond_proj_dim: Optional[int] = None,
            # 输入卷积核大小,默认为 3
            conv_in_kernel: int = 3,
            # 输出卷积核大小,默认为 3
            conv_out_kernel: int = 3,
            # 投影类嵌入输入维度,默认为 None
            projection_class_embeddings_input_dim: Optional[int] = None,
            # 注意力类型,默认为 "default"
            attention_type: str = "default",
            # 类嵌入是否拼接,默认为 False
            class_embeddings_concat: bool = False,
            # 中间块是否仅使用交叉注意力,默认为 None
            mid_block_only_cross_attention: Optional[bool] = None,
            # 交叉注意力归一化类型,默认为 None
            cross_attention_norm: Optional[str] = None,
            # 附加嵌入类型的头数量,默认为 64
            addition_embed_type_num_heads: int = 64,
    # 定义一个私有方法,用于检查配置参数
        def _check_config(
            self,
            # 定义下行块类型的元组,表示模型的结构
            down_block_types: Tuple[str],
            # 定义上行块类型的元组,表示模型的结构
            up_block_types: Tuple[str],
            # 定义仅使用交叉注意力的标志,可以是布尔值或布尔值的元组
            only_cross_attention: Union[bool, Tuple[bool]],
            # 定义每个块的输出通道数的元组,表示层的宽度
            block_out_channels: Tuple[int],
            # 定义每个块的层数,可以是整数或整数的元组
            layers_per_block: Union[int, Tuple[int]],
            # 定义交叉注意力维度,可以是整数或整数的元组
            cross_attention_dim: Union[int, Tuple[int]],
            # 定义每个块的变换器层数,可以是整数、整数的元组或元组的元组
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
            # 定义是否反转变换器层的布尔值
            reverse_transformer_layers_per_block: bool,
            # 定义注意力头的维度,表示注意力的分辨率
            attention_head_dim: int,
            # 定义注意力头的数量,可以是可选的整数或整数的元组
            num_attention_heads: Optional[Union[int, Tuple[int]],
    ):
        # 检查 down_block_types 和 up_block_types 的长度是否相同
        if len(down_block_types) != len(up_block_types):
            # 如果不同,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        # 检查 block_out_channels 和 down_block_types 的长度是否相同
        if len(block_out_channels) != len(down_block_types):
            # 如果不同,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # 检查 only_cross_attention 是否为布尔值且长度与 down_block_types 相同
        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        # 检查 num_attention_heads 是否为整数且长度与 down_block_types 相同
        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # 检查 attention_head_dim 是否为整数且长度与 down_block_types 相同
        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        # 检查 cross_attention_dim 是否为列表且长度与 down_block_types 相同
        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        # 检查 layers_per_block 是否为整数且长度与 down_block_types 相同
        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            # 如果不满足条件,抛出值错误并提供详细信息
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )
        # 检查 transformer_layers_per_block 是否为列表且 reverse_transformer_layers_per_block 为 None
        if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
            # 遍历 transformer_layers_per_block 中的每个层
            for layer_number_per_block in transformer_layers_per_block:
                # 检查每个层是否为列表
                if isinstance(layer_number_per_block, list):
                    # 如果是,则抛出值错误,提示需要提供 reverse_transformer_layers_per_block
                    raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

    # 定义设置时间投影的私有方法
    def _set_time_proj(
        self,
        # 时间嵌入类型
        time_embedding_type: str,
        # 块输出通道数
        block_out_channels: int,
        # 是否翻转正弦和余弦
        flip_sin_to_cos: bool,
        # 频率偏移
        freq_shift: float,
        # 时间嵌入维度
        time_embedding_dim: int,
    # 返回时间嵌入维度和时间步输入维度的元组
    ) -> Tuple[int, int]:
        # 判断时间嵌入类型是否为傅里叶
        if time_embedding_type == "fourier":
            # 计算时间嵌入维度,默认为 block_out_channels[0] * 2
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            # 确保时间嵌入维度为偶数
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            # 初始化高斯傅里叶投影,设定相关参数
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            # 设置时间步输入维度为时间嵌入维度
            timestep_input_dim = time_embed_dim
        # 判断时间嵌入类型是否为位置编码
        elif time_embedding_type == "positional":
            # 计算时间嵌入维度,默认为 block_out_channels[0] * 4
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
            # 初始化时间步对象,设定相关参数
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            # 设置时间步输入维度为 block_out_channels[0]
            timestep_input_dim = block_out_channels[0]
        # 如果时间嵌入类型不合法,抛出错误
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )
    
        # 返回时间嵌入维度和时间步输入维度
        return time_embed_dim, timestep_input_dim
    
    # 定义设置编码器隐藏投影的方法
    def _set_encoder_hid_proj(
        self,
        encoder_hid_dim_type: Optional[str],
        cross_attention_dim: Union[int, Tuple[int]],
        encoder_hid_dim: Optional[int],
    ):
        # 如果编码器隐藏维度类型为空且隐藏维度已定义
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            # 默认将编码器隐藏维度类型设为'text_proj'
            encoder_hid_dim_type = "text_proj"
            # 注册编码器隐藏维度类型到配置中
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            # 记录信息日志
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
    
        # 如果编码器隐藏维度为空且隐藏维度类型已定义,抛出错误
        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )
    
        # 判断编码器隐藏维度类型是否为'text_proj'
        if encoder_hid_dim_type == "text_proj":
            # 初始化线性投影层,输入维度为encoder_hid_dim,输出维度为cross_attention_dim
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        # 判断编码器隐藏维度类型是否为'text_image_proj'
        elif encoder_hid_dim_type == "text_image_proj":
            # 初始化文本-图像投影对象,设定相关参数
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        # 判断编码器隐藏维度类型是否为'image_proj'
        elif encoder_hid_dim_type == "image_proj":
            # 初始化图像投影对象,设定相关参数
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        # 如果编码器隐藏维度类型不合法,抛出错误
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
            )
        # 如果都不符合,将编码器隐藏投影设为None
        else:
            self.encoder_hid_proj = None
    # 设置类嵌入的私有方法
        def _set_class_embedding(
            self,
            class_embed_type: Optional[str],  # 嵌入类型,可能为 None 或特定字符串
            act_fn: str,  # 激活函数的名称
            num_class_embeds: Optional[int],  # 类嵌入数量,可能为 None
            projection_class_embeddings_input_dim: Optional[int],  # 投影类嵌入输入维度,可能为 None
            time_embed_dim: int,  # 时间嵌入的维度
            timestep_input_dim: int,  # 时间步输入的维度
        ):
            # 如果嵌入类型为 None 且类嵌入数量不为 None
            if class_embed_type is None and num_class_embeds is not None:
                # 创建嵌入层,大小为类嵌入数量和时间嵌入维度
                self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
            # 如果嵌入类型为 "timestep"
            elif class_embed_type == "timestep":
                # 创建时间步嵌入对象
                self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
            # 如果嵌入类型为 "identity"
            elif class_embed_type == "identity":
                # 创建恒等层,输入和输出维度相同
                self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
            # 如果嵌入类型为 "projection"
            elif class_embed_type == "projection":
                # 如果投影类嵌入输入维度为 None,抛出错误
                if projection_class_embeddings_input_dim is None:
                    raise ValueError(
                        "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                    )
                # 创建投影时间步嵌入对象
                self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
            # 如果嵌入类型为 "simple_projection"
            elif class_embed_type == "simple_projection":
                # 如果投影类嵌入输入维度为 None,抛出错误
                if projection_class_embeddings_input_dim is None:
                    raise ValueError(
                        "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                    )
                # 创建线性层作为简单投影
                self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
            # 如果没有匹配的嵌入类型
            else:
                # 将类嵌入设置为 None
                self.class_embedding = None
    
        # 设置附加嵌入的私有方法
        def _set_add_embedding(
            self,
            addition_embed_type: str,  # 附加嵌入类型
            addition_embed_type_num_heads: int,  # 附加嵌入类型的头数
            addition_time_embed_dim: Optional[int],  # 附加时间嵌入维度,可能为 None
            flip_sin_to_cos: bool,  # 是否翻转正弦到余弦
            freq_shift: float,  # 频率偏移量
            cross_attention_dim: Optional[int],  # 交叉注意力维度,可能为 None
            encoder_hid_dim: Optional[int],  # 编码器隐藏维度,可能为 None
            projection_class_embeddings_input_dim: Optional[int],  # 投影类嵌入输入维度,可能为 None
            time_embed_dim: int,  # 时间嵌入维度
    ):
        # 检查附加嵌入类型是否为 "text"
        if addition_embed_type == "text":
            # 如果编码器隐藏维度不为 None,则使用该维度
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            # 否则使用交叉注意力维度
            else:
                text_time_embedding_from_dim = cross_attention_dim

            # 创建文本时间嵌入对象
            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        # 检查附加嵌入类型是否为 "text_image"
        elif addition_embed_type == "text_image":
            # text_embed_dim 和 image_embed_dim 不必是 `cross_attention_dim`,为了避免 __init__ 过于繁杂
            # 在这里设置为 `cross_attention_dim`,因为这是当前唯一使用情况的所需维度 (Kandinsky 2.1)
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        # 检查附加嵌入类型是否为 "text_time"
        elif addition_embed_type == "text_time":
            # 创建时间投影对象
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            # 创建时间嵌入对象
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        # 检查附加嵌入类型是否为 "image"
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            # 创建图像时间嵌入对象
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        # 检查附加嵌入类型是否为 "image_hint"
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            # 创建图像提示时间嵌入对象
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        # 检查附加嵌入类型是否为 None 以外的值
        elif addition_embed_type is not None:
            # 抛出值错误,提示无效的附加嵌入类型
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

    # 定义一个属性方法,用于设置位置网络
    def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
        # 检查注意力类型是否为 "gated" 或 "gated-text-image"
        if attention_type in ["gated", "gated-text-image"]:
            positive_len = 768  # 默认的正向长度
            # 如果交叉注意力维度是整数,则使用该值
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            # 如果交叉注意力维度是列表或元组,则使用第一个值
            elif isinstance(cross_attention_dim, (list, tuple)):
                positive_len = cross_attention_dim[0]

            # 根据注意力类型确定特征类型
            feature_type = "text-only" if attention_type == "gated" else "text-image"
            # 创建 GLIGEN 文本边界框投影对象
            self.position_net = GLIGENTextBoundingboxProjection(
                positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
            )

    # 定义一个属性
    @property
    # 定义一个方法,返回一个字典,包含模型中所有的注意力处理器
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # 初始化一个空字典,用于存储注意力处理器
        processors = {}

        # 定义一个递归函数,用于添加处理器到字典
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 检查模块是否有获取处理器的方法
            if hasattr(module, "get_processor"):
                # 将处理器添加到字典中,键为名称,值为处理器
                processors[f"{name}.processor"] = module.get_processor()

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,处理子模块
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            # 返回更新后的处理器字典
            return processors

        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,添加处理器
            fn_recursive_add_processors(name, module, processors)

        # 返回包含所有处理器的字典
        return processors

    # 定义一个方法,设置用于计算注意力的处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        # 获取当前处理器的数量
        count = len(self.attn_processors.keys())

        # 如果传入的是字典,且字典长度与注意力层数量不匹配,抛出错误
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        # 定义一个递归函数,用于设置处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 检查模块是否有设置处理器的方法
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中弹出对应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))

            # 遍历模块的所有子模块
            for sub_name, child in module.named_children():
                # 递归调用,设置子模块的处理器
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        # 遍历当前模块的所有子模块
        for name, module in self.named_children():
            # 调用递归函数,设置处理器
            fn_recursive_attn_processor(name, module, processor)
    # 定义设置默认注意力处理器的方法
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器并设置默认的注意力实现。
        """
        # 检查所有注意力处理器是否属于添加的键值注意力处理器类
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建添加键值注意力处理器的实例
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否属于交叉注意力处理器类
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建标准注意力处理器的实例
            processor = AttnProcessor()
        else:
            # 如果注意力处理器类型不匹配,则抛出错误
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        # 设置选定的注意力处理器
        self.set_attn_processor(processor)

    # 定义设置梯度检查点的方法
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模块具有梯度检查点属性,则设置其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # 定义启用 FreeU 机制的方法
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""启用 FreeU 机制,详细信息请见 https://arxiv.org/abs/2309.11497。

        在缩放因子后面的后缀表示它们被应用的阶段块。

        请参考 [官方仓库](https://github.com/ChenyangSi/FreeU) 以获取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中效果良好的值组合。

        参数:
            s1 (`float`):
                阶段 1 的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            s2 (`float`):
                阶段 2 的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
            b1 (`float`): 阶段 1 的缩放因子,用于增强骨干特征的贡献。
            b2 (`float`): 阶段 2 的缩放因子,用于增强骨干特征的贡献。
        """
        # 遍历上采样块并设置相应的缩放因子
        for i, upsample_block in enumerate(self.up_blocks):
            setattr(upsample_block, "s1", s1)  # 设置阶段 1 的缩放因子
            setattr(upsample_block, "s2", s2)  # 设置阶段 2 的缩放因子
            setattr(upsample_block, "b1", b1)  # 设置阶段 1 的骨干缩放因子
            setattr(upsample_block, "b2", b2)  # 设置阶段 2 的骨干缩放因子

    # 定义禁用 FreeU 机制的方法
    def disable_freeu(self):
        """禁用 FreeU 机制。"""
        freeu_keys = {"s1", "s2", "b1", "b2"}  # 定义 FreeU 相关的键
        # 遍历上采样块
        for i, upsample_block in enumerate(self.up_blocks):
            # 遍历每个 FreeU 键
            for k in freeu_keys:
                # 如果上采样块具有该键的属性或其值不为 None,则将其值设置为 None
                if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                    setattr(upsample_block, k, None)
    # 定义一个方法,用于启用融合的 QKV 投影
    def fuse_qkv_projections(self):
        """
        启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。
        对于交叉注意力模块,键和值的投影矩阵被融合。

        <Tip warning={true}>

        此 API 是 🧪 实验性的。

        </Tip>
        """
        # 初始化原始注意力处理器为 None
        self.original_attn_processors = None

        # 遍历注意力处理器,检查是否包含“Added”字样
        for _, attn_processor in self.attn_processors.items():
            # 如果发现添加的 KV 投影,抛出错误
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # 保存当前的注意力处理器
        self.original_attn_processors = self.attn_processors

        # 遍历模块,查找类型为 Attention 的模块
        for module in self.modules():
            if isinstance(module, Attention):
                # 启用投影融合
                module.fuse_projections(fuse=True)

        # 设置注意力处理器为融合的处理器
        self.set_attn_processor(FusedAttnProcessor2_0())

    # 定义一个方法,用于禁用已启用的融合 QKV 投影
    def unfuse_qkv_projections(self):
        """禁用已启用的融合 QKV 投影。

        <Tip warning={true}>

        此 API 是 🧪 实验性的。

        </Tip>

        """
        # 如果原始注意力处理器不为 None,则恢复到原始处理器
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    # 定义一个方法,用于获取时间嵌入
    def get_time_embed(
        self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
    ) -> Optional[torch.Tensor]:
        # 将时间步长赋值给 timesteps
        timesteps = timestep
        # 如果 timesteps 不是张量
        if not torch.is_tensor(timesteps):
            # TODO: 这需要在 CPU 和 GPU 之间同步。因此,如果可以的话,尽量将 timesteps 作为张量传递
            # 这将是使用 `match` 语句的好例子(Python 3.10+)
            is_mps = sample.device.type == "mps"  # 检查设备类型是否为 MPS
            # 根据时间步长类型设置数据类型
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64  # 浮点数类型
            else:
                dtype = torch.int32 if is_mps else torch.int64  # 整数类型
            # 将 timesteps 转换为张量
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        # 如果 timesteps 是标量(零维张量),则扩展维度
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)  # 增加一个维度并转移到样本设备

        # 将 timesteps 广播到与样本批次维度兼容的方式
        timesteps = timesteps.expand(sample.shape[0])  # 扩展到批次大小

        # 通过时间投影获得时间嵌入
        t_emb = self.time_proj(timesteps)
        # `Timesteps` 不包含任何权重,总是返回 f32 张量
        # 但时间嵌入可能实际在 fp16 中运行,因此需要进行类型转换。
        # 可能有更好的方法来封装这一点。
        t_emb = t_emb.to(dtype=sample.dtype)  # 转换 t_emb 的数据类型
        # 返回时间嵌入
        return t_emb
    # 获取类嵌入的方法,接受样本张量和可选的类标签
        def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
            # 初始化类嵌入为 None
            class_emb = None
            # 检查类嵌入是否存在
            if self.class_embedding is not None:
                # 如果类标签为 None,抛出错误
                if class_labels is None:
                    raise ValueError("class_labels should be provided when num_class_embeds > 0")
    
                # 检查类嵌入类型是否为时间步
                if self.config.class_embed_type == "timestep":
                    # 将类标签通过时间投影处理
                    class_labels = self.time_proj(class_labels)
    
                    # `Timesteps` 不包含权重,总是返回 f32 张量
                    # 可能有更好的方式来封装这一点
                    class_labels = class_labels.to(dtype=sample.dtype)
    
                # 获取类嵌入并转换为与样本相同的数据类型
                class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
            # 返回类嵌入
            return class_emb
    
        # 获取增强嵌入的方法,接受嵌入张量、编码器隐藏状态和额外条件参数
        def get_aug_embed(
            self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
        # 处理编码器隐藏状态的方法,接受编码器隐藏状态和额外条件参数
        def process_encoder_hidden_states(
            self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
    # 定义返回类型为 torch.Tensor
        ) -> torch.Tensor:
            # 检查是否存在隐藏层投影,并且配置为 "text_proj"
            if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
                # 使用文本投影对编码隐藏状态进行转换
                encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
            # 检查是否存在隐藏层投影,并且配置为 "text_image_proj"
            elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
                # 检查条件中是否包含 "image_embeds"
                if "image_embeds" not in added_cond_kwargs:
                    # 抛出错误提示缺少必要参数
                    raise ValueError(
                        f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                    )
    
                # 获取传入的图像嵌入
                image_embeds = added_cond_kwargs.get("image_embeds")
                # 对编码隐藏状态和图像嵌入进行投影转换
                encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
            # 检查是否存在隐藏层投影,并且配置为 "image_proj"
            elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
                # 检查条件中是否包含 "image_embeds"
                if "image_embeds" not in added_cond_kwargs:
                    # 抛出错误提示缺少必要参数
                    raise ValueError(
                        f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                    )
                # 获取传入的图像嵌入
                image_embeds = added_cond_kwargs.get("image_embeds")
                # 使用图像嵌入对编码隐藏状态进行投影转换
                encoder_hidden_states = self.encoder_hid_proj(image_embeds)
            # 检查是否存在隐藏层投影,并且配置为 "ip_image_proj"
            elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
                # 检查条件中是否包含 "image_embeds"
                if "image_embeds" not in added_cond_kwargs:
                    # 抛出错误提示缺少必要参数
                    raise ValueError(
                        f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
                    )
    
                # 如果存在文本编码器的隐藏层投影,则对编码隐藏状态进行投影转换
                if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
                    encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
    
                # 获取传入的图像嵌入
                image_embeds = added_cond_kwargs.get("image_embeds")
                # 对图像嵌入进行投影转换
                image_embeds = self.encoder_hid_proj(image_embeds)
                # 将编码隐藏状态和图像嵌入打包成元组
                encoder_hidden_states = (encoder_hidden_states, image_embeds)
            # 返回最终的编码隐藏状态
            return encoder_hidden_states
    # 定义前向传播函数
    def forward(
            # 输入的样本数据,类型为 PyTorch 张量
            sample: torch.Tensor,
            # 当前时间步,类型可以是张量、浮点数或整数
            timestep: Union[torch.Tensor, float, int],
            # 编码器的隐藏状态,类型为 PyTorch 张量
            encoder_hidden_states: torch.Tensor,
            # 可选的类别标签,类型为 PyTorch 张量
            class_labels: Optional[torch.Tensor] = None,
            # 可选的时间步条件,类型为 PyTorch 张量
            timestep_cond: Optional[torch.Tensor] = None,
            # 可选的注意力掩码,类型为 PyTorch 张量
            attention_mask: Optional[torch.Tensor] = None,
            # 可选的交叉注意力参数,类型为字典,包含额外的关键字参数
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            # 可选的附加条件参数,类型为字典,键为字符串,值为 PyTorch 张量
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
            # 可选的下层块附加残差,类型为元组,包含 PyTorch 张量
            down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            # 可选的中间块附加残差,类型为 PyTorch 张量
            mid_block_additional_residual: Optional[torch.Tensor] = None,
            # 可选的下层内部块附加残差,类型为元组,包含 PyTorch 张量
            down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
            # 可选的编码器注意力掩码,类型为 PyTorch 张量
            encoder_attention_mask: Optional[torch.Tensor] = None,
            # 返回结果的标志,布尔值,默认值为 True
            return_dict: bool = True,

.\diffusers\models\unets\unet_2d_condition_flax.py

# 版权声明,表明该文件的版权所有者及相关信息
# 
# 根据 Apache License 2.0 版本的许可协议
# 除非遵守该许可协议,否则不得使用本文件
# 可以在以下地址获取许可证副本
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,软件在“按现状”基础上分发,
# 不提供任何明示或暗示的保证或条件
# 请参阅许可证了解管理权限和限制的具体条款
from typing import Dict, Optional, Tuple, Union  # 从 typing 模块导入类型注释工具

import flax  # 导入 flax 库用于构建神经网络
import flax.linen as nn  # 从 flax 中导入 linen 模块,方便定义神经网络层
import jax  # 导入 jax 库用于高效数值计算
import jax.numpy as jnp  # 导入 jax 的 numpy 模块,提供张量操作功能
from flax.core.frozen_dict import FrozenDict  # 从 flax 导入 FrozenDict,用于不可变字典

from ...configuration_utils import ConfigMixin, flax_register_to_config  # 导入配置相关工具
from ...utils import BaseOutput  # 导入基础输出类
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps  # 导入时间步嵌入相关类
from ..modeling_flax_utils import FlaxModelMixin  # 导入模型混合类
from .unet_2d_blocks_flax import (  # 导入 UNet 的不同构建块
    FlaxCrossAttnDownBlock2D,  # 导入交叉注意力下采样块
    FlaxCrossAttnUpBlock2D,  # 导入交叉注意力上采样块
    FlaxDownBlock2D,  # 导入下采样块
    FlaxUNetMidBlock2DCrossAttn,  # 导入中间块,带有交叉注意力
    FlaxUpBlock2D,  # 导入上采样块
)


@flax.struct.dataclass  # 使用 flax 的数据类装饰器
class FlaxUNet2DConditionOutput(BaseOutput):  # 定义 UNet 条件输出类,继承自基础输出类
    """
    [`FlaxUNet2DConditionModel`] 的输出。

    参数:
        sample (`jnp.ndarray` 的形状为 `(batch_size, num_channels, height, width)`):
            基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
    """

    sample: jnp.ndarray  # 定义输出样本,数据类型为 jnp.ndarray


@flax_register_to_config  # 使用装饰器将模型注册到配置中
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):  # 定义条件 UNet 模型类,继承多个混合类
    r"""
    一个条件 2D UNet 模型,接收噪声样本、条件状态和时间步,并返回样本形状的输出。

    此模型继承自 [`FlaxModelMixin`]。请查看超类文档以了解其通用方法
    (例如下载或保存)。

    此模型也是 Flax Linen 的 [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    子类。将其作为常规 Flax Linen 模块使用,具体使用和行为请参阅 Flax 文档。

    支持以下 JAX 特性:
    - [即时编译 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [自动微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [向量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [并行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    # 参数说明部分
    Parameters:
        # 输入样本的大小,类型为整型,选填参数
        sample_size (`int`, *optional*):
            The size of the input sample.
        # 输入样本的通道数,类型为整型,默认为4
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        # 输出的通道数,类型为整型,默认为4
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        # 使用的下采样块的元组,类型为字符串元组,默认为特定的下采样块
        down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
            The tuple of downsample blocks to use.
        # 使用的上采样块的元组,类型为字符串元组,默认为特定的上采样块
        up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
            The tuple of upsample blocks to use.
        # UNet中间块的类型,类型为字符串,默认为"UNetMidBlock2DCrossAttn"
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
            is skipped.
        # 每个块的输出通道的元组,类型为整型元组,默认为特定的输出通道
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        # 每个块的层数,类型为整型,默认为2
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        # 注意力头的维度,可以是整型或整型元组,默认为8
        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
            The dimension of the attention heads.
        # 注意力头的数量,可以是整型或整型元组,选填参数
        num_attention_heads (`int` or `Tuple[int]`, *optional*):
            The number of attention heads.
        # 交叉注意力特征的维度,类型为整型,默认为768
        cross_attention_dim (`int`, *optional*, defaults to 768):
            The dimension of the cross attention features.
        # dropout的概率,类型为浮点数,默认为0
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        # 是否在时间嵌入中将正弦转换为余弦,类型为布尔值,默认为True
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        # 应用于时间嵌入的频率偏移,类型为整型,默认为0
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        # 是否启用内存高效的注意力机制,类型为布尔值,默认为False
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
        # 是否将头维度拆分为新的轴进行自注意力计算,类型为布尔值,默认为False
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
    """
    
    # 定义样本大小,默认为32
    sample_size: int = 32
    # 定义输入通道数,默认为4
    in_channels: int = 4
    # 定义输出通道数,默认为4
    out_channels: int = 4
    # 定义下采样块的类型元组
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlock2D",  # 第一个下采样块
        "CrossAttnDownBlock2D",  # 第二个下采样块
        "CrossAttnDownBlock2D",  # 第三个下采样块
        "DownBlock2D",           # 第四个下采样块
    )
    # 定义上采样块的类型元组
    up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
    # 定义中间块类型,默认为"UNetMidBlock2DCrossAttn"
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
    # 定义是否只使用交叉注意力,默认为False
    only_cross_attention: Union[bool, Tuple[bool]] = False
    # 定义每个块的输出通道元组
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
    # 每个块的层数设为 2
    layers_per_block: int = 2
    # 注意力头的维度设为 8
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    # 可选的注意力头数量,默认为 None
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
    # 跨注意力的维度设为 1280
    cross_attention_dim: int = 1280
    # dropout 比率设为 0.0
    dropout: float = 0.0
    # 是否使用线性投影,默认为 False
    use_linear_projection: bool = False
    # 数据类型设为 float32
    dtype: jnp.dtype = jnp.float32
    # flip_sin_to_cos 设为 True
    flip_sin_to_cos: bool = True
    # 频移设为 0
    freq_shift: int = 0
    # 是否使用内存高效的注意力,默认为 False
    use_memory_efficient_attention: bool = False
    # 是否拆分头维度,默认为 False
    split_head_dim: bool = False
    # 每个块的变换层数设为 1
    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
    # 可选的附加嵌入类型,默认为 None
    addition_embed_type: Optional[str] = None
    # 可选的附加时间嵌入维度,默认为 None
    addition_time_embed_dim: Optional[int] = None
    # 附加嵌入类型的头数量设为 64
    addition_embed_type_num_heads: int = 64
    # 可选的投影类嵌入输入维度,默认为 None
    projection_class_embeddings_input_dim: Optional[int] = None

    # 初始化权重函数,接受随机数生成器作为参数
    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # 初始化输入张量的形状
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        # 创建全零的输入样本
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        # 创建全一的时间步张量
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        # 初始化编码器的隐藏状态
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        # 分割随机数生成器,用于参数和 dropout
        params_rng, dropout_rng = jax.random.split(rng)
        # 创建随机数字典
        rngs = {"params": params_rng, "dropout": dropout_rng}

        # 初始化附加条件关键字参数
        added_cond_kwargs = None
        # 判断嵌入类型是否为 "text_time"
        if self.addition_embed_type == "text_time":
            # 通过反向计算获取期望的文本嵌入维度
            is_refiner = (
                5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
                == self.config.projection_class_embeddings_input_dim
            )
            # 确定微条件的数量
            num_micro_conditions = 5 if is_refiner else 6

            # 计算文本嵌入维度
            text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
                num_micro_conditions * self.config.addition_time_embed_dim
            )

            # 计算时间 ID 的通道数和维度
            time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
            time_ids_dims = time_ids_channels // self.addition_time_embed_dim
            # 创建附加条件关键字参数字典
            added_cond_kwargs = {
                "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
                "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
            }
        # 返回初始化后的参数字典
        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

    # 定义调用函数,接收多个输入参数
    def __call__(
        self,
        sample: jnp.ndarray,
        timesteps: Union[jnp.ndarray, float, int],
        encoder_hidden_states: jnp.ndarray,
        # 可选的附加条件关键字参数
        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
        # 可选的下块附加残差
        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
        # 可选的中块附加残差
        mid_block_additional_residual: Optional[jnp.ndarray] = None,
        # 是否返回字典,默认为 True
        return_dict: bool = True,
        # 是否为训练模式,默认为 False
        train: bool = False,

.\diffusers\models\unets\unet_3d_blocks.py

# 版权所有 2024 HuggingFace 团队。所有权利保留。
#
# 根据 Apache 许可证 2.0 版本("许可证")授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下位置获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件是按“原样”基础分发,
# 不提供任何形式的保证或条件,无论是明示或暗示。
# 有关许可证的具体权限和限制,请参阅许可证。

# 导入类型提示中的任何类型
from typing import Any, Dict, Optional, Tuple, Union

# 导入 PyTorch 库
import torch
# 从 PyTorch 导入神经网络模块
from torch import nn

# 导入实用工具函数,包括弃用和日志记录
from ...utils import deprecate, is_torch_version, logging
# 导入 PyTorch 相关的工具函数
from ...utils.torch_utils import apply_freeu
# 导入注意力机制相关的类
from ..attention import Attention
# 导入 ResNet 相关的类
from ..resnet import (
    Downsample2D,  # 导入 2D 下采样模块
    ResnetBlock2D,  # 导入 2D ResNet 块
    SpatioTemporalResBlock,  # 导入时空 ResNet 块
    TemporalConvLayer,  # 导入时间卷积层
    Upsample2D,  # 导入 2D 上采样模块
)
# 导入 2D 变换器模型
from ..transformers.transformer_2d import Transformer2DModel
# 导入时间相关的变换器模型
from ..transformers.transformer_temporal import (
    TransformerSpatioTemporalModel,  # 导入时空变换器模型
    TransformerTemporalModel,  # 导入时间变换器模型
)
# 导入运动模型的 UNet 相关类
from .unet_motion_model import (
    CrossAttnDownBlockMotion,  # 导入交叉注意力下块运动类
    CrossAttnUpBlockMotion,  # 导入交叉注意力上块运动类
    DownBlockMotion,  # 导入下块运动类
    UNetMidBlockCrossAttnMotion,  # 导入中间块交叉注意力运动类
    UpBlockMotion,  # 导入上块运动类
)

# 创建一个日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义 DownBlockMotion 类,继承自 DownBlockMotion
class DownBlockMotion(DownBlockMotion):
    # 初始化方法,接受任意参数和关键字参数
    def __init__(self, *args, **kwargs):
        # 设置弃用消息,提醒用户变更
        deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
        # 调用弃用函数,记录弃用信息
        deprecate("DownBlockMotion", "1.0.0", deprecation_message)
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

# 定义 CrossAttnDownBlockMotion 类,继承自 CrossAttnDownBlockMotion
class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
    # 初始化方法,接受任意参数和关键字参数
    def __init__(self, *args, **kwargs):
        # 设置弃用消息,提醒用户变更
        deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
        # 调用弃用函数,记录弃用信息
        deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

# 定义 UpBlockMotion 类,继承自 UpBlockMotion
class UpBlockMotion(UpBlockMotion):
    # 初始化方法,接受任意参数和关键字参数
    def __init__(self, *args, **kwargs):
        # 设置弃用消息,提醒用户变更
        deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
        # 调用弃用函数,记录弃用信息
        deprecate("UpBlockMotion", "1.0.0", deprecation_message)
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)

# 定义 CrossAttnUpBlockMotion 类,继承自 CrossAttnUpBlockMotion
class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
    # 初始化方法,用于创建类的实例
        def __init__(self, *args, **kwargs):
            # 定义一个关于导入的弃用警告信息
            deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
            # 调用弃用警告函数,记录该功能的弃用信息及版本
            deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
            # 调用父类的初始化方法,传递参数以初始化父类部分
            super().__init__(*args, **kwargs)
# 定义一个名为 UNetMidBlockCrossAttnMotion 的类,继承自同名父类
class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
    # 初始化方法,接收可变参数和关键字参数
    def __init__(self, *args, **kwargs):
        # 定义弃用警告信息,提示用户更新导入路径
        deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
        # 触发弃用警告
        deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
        # 调用父类的初始化方法
        super().__init__(*args, **kwargs)


# 定义一个函数,返回不同类型的下采样块
def get_down_block(
    # 定义参数,类型和含义
    down_block_type: str,  # 下采样块的类型
    num_layers: int,  # 层数
    in_channels: int,  # 输入通道数
    out_channels: int,  # 输出通道数
    temb_channels: int,  # 时间嵌入通道数
    add_downsample: bool,  # 是否添加下采样
    resnet_eps: float,  # ResNet 的 epsilon 参数
    resnet_act_fn: str,  # ResNet 的激活函数
    num_attention_heads: int,  # 注意力头数
    resnet_groups: Optional[int] = None,  # ResNet 的分组数,可选
    cross_attention_dim: Optional[int] = None,  # 交叉注意力维度,可选
    downsample_padding: Optional[int] = None,  # 下采样填充,可选
    dual_cross_attention: bool = False,  # 是否使用双重交叉注意力
    use_linear_projection: bool = True,  # 是否使用线性投影
    only_cross_attention: bool = False,  # 是否仅使用交叉注意力
    upcast_attention: bool = False,  # 是否提升注意力精度
    resnet_time_scale_shift: str = "default",  # ResNet 时间尺度偏移
    temporal_num_attention_heads: int = 8,  # 时间注意力头数
    temporal_max_seq_length: int = 32,  # 时间序列最大长度
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每个块的变换器层数
    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 时间变换器每块层数
    dropout: float = 0.0,  # dropout 概率
) -> Union[
    "DownBlock3D",  # 返回的可能类型之一:3D 下采样块
    "CrossAttnDownBlock3D",  # 返回的可能类型之二:交叉注意力下采样块
    "DownBlockSpatioTemporal",  # 返回的可能类型之三:时空下采样块
    "CrossAttnDownBlockSpatioTemporal",  # 返回的可能类型之四:时空交叉注意力下采样块
]:
    # 检查下采样块类型是否为 DownBlock3D
    if down_block_type == "DownBlock3D":
        # 创建并返回 DownBlock3D 实例
        return DownBlock3D(
            num_layers=num_layers,  # 传入层数
            in_channels=in_channels,  # 传入输入通道数
            out_channels=out_channels,  # 传入输出通道数
            temb_channels=temb_channels,  # 传入时间嵌入通道数
            add_downsample=add_downsample,  # 传入是否添加下采样
            resnet_eps=resnet_eps,  # 传入 ResNet 的 epsilon 参数
            resnet_act_fn=resnet_act_fn,  # 传入激活函数
            resnet_groups=resnet_groups,  # 传入分组数
            downsample_padding=downsample_padding,  # 传入下采样填充
            resnet_time_scale_shift=resnet_time_scale_shift,  # 传入时间尺度偏移
            dropout=dropout,  # 传入 dropout 概率
        )
    # 检查下采样块类型是否为 CrossAttnDownBlock3D
    elif down_block_type == "CrossAttnDownBlock3D":
        # 如果交叉注意力维度未指定,抛出错误
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        # 创建并返回 CrossAttnDownBlock3D 实例
        return CrossAttnDownBlock3D(
            num_layers=num_layers,  # 传入层数
            in_channels=in_channels,  # 传入输入通道数
            out_channels=out_channels,  # 传入输出通道数
            temb_channels=temb_channels,  # 传入时间嵌入通道数
            add_downsample=add_downsample,  # 传入是否添加下采样
            resnet_eps=resnet_eps,  # 传入 ResNet 的 epsilon 参数
            resnet_act_fn=resnet_act_fn,  # 传入激活函数
            resnet_groups=resnet_groups,  # 传入分组数
            downsample_padding=downsample_padding,  # 传入下采样填充
            cross_attention_dim=cross_attention_dim,  # 传入交叉注意力维度
            num_attention_heads=num_attention_heads,  # 传入注意力头数
            dual_cross_attention=dual_cross_attention,  # 传入是否使用双重交叉注意力
            use_linear_projection=use_linear_projection,  # 传入是否使用线性投影
            only_cross_attention=only_cross_attention,  # 传入是否仅使用交叉注意力
            upcast_attention=upcast_attention,  # 传入是否提升注意力精度
            resnet_time_scale_shift=resnet_time_scale_shift,  # 传入时间尺度偏移
            dropout=dropout,  # 传入 dropout 概率
        )
    # 检查下一个块的类型是否为时空下采样块
    elif down_block_type == "DownBlockSpatioTemporal":
        # 为 SDV 进行了添加
        # 返回一个时空下采样块的实例
        return DownBlockSpatioTemporal(
            # 设置层数
            num_layers=num_layers,
            # 输入通道数
            in_channels=in_channels,
            # 输出通道数
            out_channels=out_channels,
            # 时间嵌入通道数
            temb_channels=temb_channels,
            # 是否添加下采样
            add_downsample=add_downsample,
        )
    # 检查下一个块的类型是否为交叉注意力时空下采样块
    elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
        # 为 SDV 进行了添加
        # 如果没有指定交叉注意力维度,抛出错误
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
        # 返回一个交叉注意力时空下采样块的实例
        return CrossAttnDownBlockSpatioTemporal(
            # 输入通道数
            in_channels=in_channels,
            # 输出通道数
            out_channels=out_channels,
            # 时间嵌入通道数
            temb_channels=temb_channels,
            # 设置层数
            num_layers=num_layers,
            # 每个块的变换层数
            transformer_layers_per_block=transformer_layers_per_block,
            # 是否添加下采样
            add_downsample=add_downsample,
            # 设置交叉注意力维度
            cross_attention_dim=cross_attention_dim,
            # 注意力头数
            num_attention_heads=num_attention_heads,
        )

    # 如果块类型不匹配,则抛出错误
    raise ValueError(f"{down_block_type} does not exist.")
# 定义函数 get_up_block,返回不同类型的上采样模块
def get_up_block(
    # 上采样块类型
    up_block_type: str,
    # 层数
    num_layers: int,
    # 输入通道数
    in_channels: int,
    # 输出通道数
    out_channels: int,
    # 上一层的输出通道数
    prev_output_channel: int,
    # 时间嵌入通道数
    temb_channels: int,
    # 是否添加上采样
    add_upsample: bool,
    # ResNet 的 epsilon 值
    resnet_eps: float,
    # ResNet 的激活函数类型
    resnet_act_fn: str,
    # 注意力头的数量
    num_attention_heads: int,
    # 分辨率索引(可选)
    resolution_idx: Optional[int] = None,
    # ResNet 组数(可选)
    resnet_groups: Optional[int] = None,
    # 跨注意力维度(可选)
    cross_attention_dim: Optional[int] = None,
    # 是否使用双重跨注意力
    dual_cross_attention: bool = False,
    # 是否使用线性投影
    use_linear_projection: bool = True,
    # 是否仅使用跨注意力
    only_cross_attention: bool = False,
    # 是否提升注意力计算
    upcast_attention: bool = False,
    # ResNet 时间尺度移位的设置
    resnet_time_scale_shift: str = "default",
    # 时间上的注意力头数量
    temporal_num_attention_heads: int = 8,
    # 时间上的跨注意力维度(可选)
    temporal_cross_attention_dim: Optional[int] = None,
    # 时间序列最大长度
    temporal_max_seq_length: int = 32,
    # 每个块的变换器层数
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    # 每个块的时间变换器层数
    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    # dropout 概率
    dropout: float = 0.0,
) -> Union[
    # 返回类型为不同的上采样块
    "UpBlock3D",
    "CrossAttnUpBlock3D",
    "UpBlockSpatioTemporal",
    "CrossAttnUpBlockSpatioTemporal",
]:
    # 判断上采样块类型是否为 UpBlock3D
    if up_block_type == "UpBlock3D":
        # 创建并返回 UpBlock3D 实例
        return UpBlock3D(
            # 传入参数设置
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resolution_idx=resolution_idx,
            dropout=dropout,
        )
    # 判断上采样块类型是否为 CrossAttnUpBlock3D
    elif up_block_type == "CrossAttnUpBlock3D":
        # 检查是否提供跨注意力维度
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        # 创建并返回 CrossAttnUpBlock3D 实例
        return CrossAttnUpBlock3D(
            # 传入参数设置
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resolution_idx=resolution_idx,
            dropout=dropout,
        )
    # 检查上升块类型是否为 "UpBlockSpatioTemporal"
    elif up_block_type == "UpBlockSpatioTemporal":
        # 为 SDV 添加的内容
        # 返回 UpBlockSpatioTemporal 实例,使用指定的参数
        return UpBlockSpatioTemporal(
            # 层数参数
            num_layers=num_layers,
            # 输入通道数
            in_channels=in_channels,
            # 输出通道数
            out_channels=out_channels,
            # 前一个输出通道数
            prev_output_channel=prev_output_channel,
            # 时间嵌入通道数
            temb_channels=temb_channels,
            # 分辨率索引
            resolution_idx=resolution_idx,
            # 是否添加上采样
            add_upsample=add_upsample,
        )
    # 检查上升块类型是否为 "CrossAttnUpBlockSpatioTemporal"
    elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
        # 为 SDV 添加的内容
        # 如果没有指定交叉注意力维度,抛出错误
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
        # 返回 CrossAttnUpBlockSpatioTemporal 实例,使用指定的参数
        return CrossAttnUpBlockSpatioTemporal(
            # 输入通道数
            in_channels=in_channels,
            # 输出通道数
            out_channels=out_channels,
            # 前一个输出通道数
            prev_output_channel=prev_output_channel,
            # 时间嵌入通道数
            temb_channels=temb_channels,
            # 层数参数
            num_layers=num_layers,
            # 每个块的变换层数
            transformer_layers_per_block=transformer_layers_per_block,
            # 是否添加上采样
            add_upsample=add_upsample,
            # 交叉注意力维度
            cross_attention_dim=cross_attention_dim,
            # 注意力头数
            num_attention_heads=num_attention_heads,
            # 分辨率索引
            resolution_idx=resolution_idx,
        )

    # 如果上升块类型不符合任何已知类型,抛出错误
    raise ValueError(f"{up_block_type} does not exist.")
# 定义一个名为 UNetMidBlock3DCrossAttn 的类,继承自 nn.Module
class UNetMidBlock3DCrossAttn(nn.Module):
    # 初始化方法,设置类的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        temb_channels: int,  # 时间嵌入通道数
        dropout: float = 0.0,  # dropout 概率
        num_layers: int = 1,  # 层数
        resnet_eps: float = 1e-6,  # ResNet 中的小常数,避免除零
        resnet_time_scale_shift: str = "default",  # ResNet 时间缩放偏移
        resnet_act_fn: str = "swish",  # ResNet 激活函数类型
        resnet_groups: int = 32,  # ResNet 分组数
        resnet_pre_norm: bool = True,  # 是否在 ResNet 前进行归一化
        num_attention_heads: int = 1,  # 注意力头数
        output_scale_factor: float = 1.0,  # 输出缩放因子
        cross_attention_dim: int = 1280,  # 交叉注意力维度
        dual_cross_attention: bool = False,  # 是否使用双交叉注意力
        use_linear_projection: bool = True,  # 是否使用线性投影
        upcast_attention: bool = False,  # 是否使用上采样注意力
    ):
        # 省略具体初始化代码,通常会在这里初始化各个层和参数

    # 定义前向传播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态
        attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码
        num_frames: int = 1,  # 帧数
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,  # 交叉注意力的额外参数
    ) -> torch.Tensor:  # 返回一个张量
        # 通过第一个 ResNet 层处理隐藏状态
        hidden_states = self.resnets[0](hidden_states, temb)
        # 通过第一个时间卷积层处理隐藏状态
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
        # 遍历所有的注意力层、时间注意力层、ResNet 层和时间卷积层
        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            # 通过当前注意力层处理隐藏状态
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]  # 只取返回的第一个元素
            # 通过当前时间注意力层处理隐藏状态
            hidden_states = temp_attn(
                hidden_states,
                num_frames=num_frames,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]  # 只取返回的第一个元素
            # 通过当前 ResNet 层处理隐藏状态
            hidden_states = resnet(hidden_states, temb)
            # 通过当前时间卷积层处理隐藏状态
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        # 返回最终的隐藏状态
        return hidden_states


# 定义一个名为 CrossAttnDownBlock3D 的类,继承自 nn.Module
class CrossAttnDownBlock3D(nn.Module):
    # 初始化方法,设置类的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        out_channels: int,  # 输出通道数
        temb_channels: int,  # 时间嵌入通道数
        dropout: float = 0.0,  # dropout 概率
        num_layers: int = 1,  # 层数
        resnet_eps: float = 1e-6,  # ResNet 中的小常数,避免除零
        resnet_time_scale_shift: str = "default",  # ResNet 时间缩放偏移
        resnet_act_fn: str = "swish",  # ResNet 激活函数类型
        resnet_groups: int = 32,  # ResNet 分组数
        resnet_pre_norm: bool = True,  # 是否在 ResNet 前进行归一化
        num_attention_heads: int = 1,  # 注意力头数
        cross_attention_dim: int = 1280,  # 交叉注意力维度
        output_scale_factor: float = 1.0,  # 输出缩放因子
        downsample_padding: int = 1,  # 下采样的填充大小
        add_downsample: bool = True,  # 是否添加下采样层
        dual_cross_attention: bool = False,  # 是否使用双交叉注意力
        use_linear_projection: bool = False,  # 是否使用线性投影
        only_cross_attention: bool = False,  # 是否只使用交叉注意力
        upcast_attention: bool = False,  # 是否使用上采样注意力
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 初始化残差块列表
        resnets = []
        # 初始化注意力层列表
        attentions = []
        # 初始化临时注意力层列表
        temp_attentions = []
        # 初始化临时卷积层列表
        temp_convs = []

        # 设置是否使用交叉注意力
        self.has_cross_attention = True
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads

        # 根据层数创建各层模块
        for i in range(num_layers):
            # 设置输入通道数,第一层使用传入的输入通道数,后续层使用输出通道数
            in_channels = in_channels if i == 0 else out_channels
            # 创建残差块并添加到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 输入通道数
                    out_channels=out_channels,  # 输出通道数
                    temb_channels=temb_channels,  # 时间嵌入通道数
                    eps=resnet_eps,  # 残差块的 epsilon 值
                    groups=resnet_groups,  # 组卷积的组数
                    dropout=dropout,  # Dropout 比例
                    time_embedding_norm=resnet_time_scale_shift,  # 时间嵌入的归一化方式
                    non_linearity=resnet_act_fn,  # 非线性激活函数
                    output_scale_factor=output_scale_factor,  # 输出缩放因子
                    pre_norm=resnet_pre_norm,  # 是否进行预归一化
                )
            )
            # 创建时间卷积层并添加到列表中
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,  # 输入通道数
                    out_channels,  # 输出通道数
                    dropout=0.1,  # Dropout 比例
                    norm_num_groups=resnet_groups,  # 归一化的组数
                )
            )
            # 创建二维变换器模型并添加到列表中
            attentions.append(
                Transformer2DModel(
                    out_channels // num_attention_heads,  # 每个注意力头的通道数
                    num_attention_heads,  # 注意力头的数量
                    in_channels=out_channels,  # 输入通道数
                    num_layers=1,  # 变换器层数
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力维度
                    norm_num_groups=resnet_groups,  # 归一化的组数
                    use_linear_projection=use_linear_projection,  # 是否使用线性映射
                    only_cross_attention=only_cross_attention,  # 是否只使用交叉注意力
                    upcast_attention=upcast_attention,  # 是否上溢注意力
                )
            )
            # 创建时间变换器模型并添加到列表中
            temp_attentions.append(
                TransformerTemporalModel(
                    out_channels // num_attention_heads,  # 每个注意力头的通道数
                    num_attention_heads,  # 注意力头的数量
                    in_channels=out_channels,  # 输入通道数
                    num_layers=1,  # 变换器层数
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力维度
                    norm_num_groups=resnet_groups,  # 归一化的组数
                )
            )
        # 将残差块列表转换为模块列表
        self.resnets = nn.ModuleList(resnets)
        # 将临时卷积层列表转换为模块列表
        self.temp_convs = nn.ModuleList(temp_convs)
        # 将注意力层列表转换为模块列表
        self.attentions = nn.ModuleList(attentions)
        # 将临时注意力层列表转换为模块列表
        self.temp_attentions = nn.ModuleList(temp_attentions)

        # 如果需要添加下采样层
        if add_downsample:
            # 创建下采样模块列表
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,  # 输出通道数
                        use_conv=True,  # 是否使用卷积进行下采样
                        out_channels=out_channels,  # 输出通道数
                        padding=downsample_padding,  # 下采样时的填充
                        name="op",  # 模块名称
                    )
                ]
            )
        else:
            # 如果不需要下采样,设置为 None
            self.downsamplers = None

        # 初始化梯度检查点设置为 False
        self.gradient_checkpointing = False
    # 定义前向传播方法,接受多个输入参数并返回张量或元组
    def forward(
            self,
            hidden_states: torch.Tensor,
            temb: Optional[torch.Tensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            num_frames: int = 1,
            cross_attention_kwargs: Dict[str, Any] = None,
        ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
            # TODO(Patrick, William) - 注意力掩码未使用
            output_states = ()  # 初始化输出状态为元组
    
            # 遍历所有的残差网络、临时卷积、注意力和临时注意力层
            for resnet, temp_conv, attn, temp_attn in zip(
                self.resnets, self.temp_convs, self.attentions, self.temp_attentions
            ):
                # 使用残差网络处理隐状态和时间嵌入
                hidden_states = resnet(hidden_states, temb)
                # 使用临时卷积处理隐状态,考虑帧数
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)
                # 使用注意力层处理隐状态,返回字典设为 False
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]  # 取返回的第一个元素
                # 使用临时注意力层处理隐状态,返回字典设为 False
                hidden_states = temp_attn(
                    hidden_states,
                    num_frames=num_frames,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]  # 取返回的第一个元素
    
                # 将当前隐状态添加到输出状态中
                output_states += (hidden_states,)
    
            # 如果存在下采样器,则逐个应用
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)  # 应用下采样器
    
                # 将下采样后的隐状态添加到输出状态中
                output_states += (hidden_states,)
    
            # 返回最终的隐状态和所有输出状态
            return hidden_states, output_states
# 定义一个 3D 下采样模块,继承自 nn.Module
class DownBlock3D(nn.Module):
    # 初始化方法,设置各个参数及模块
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        out_channels: int,  # 输出通道数
        temb_channels: int,  # 时间嵌入通道数
        dropout: float = 0.0,  # dropout 概率
        num_layers: int = 1,  # 层数
        resnet_eps: float = 1e-6,  # ResNet 中的 epsilon 值
        resnet_time_scale_shift: str = "default",  # 时间尺度偏移方式
        resnet_act_fn: str = "swish",  # ResNet 激活函数类型
        resnet_groups: int = 32,  # ResNet 中的组数
        resnet_pre_norm: bool = True,  # 是否使用预归一化
        output_scale_factor: float = 1.0,  # 输出缩放因子
        add_downsample: bool = True,  # 是否添加下采样层
        downsample_padding: int = 1,  # 下采样的填充大小
    ):
        # 调用父类构造函数
        super().__init__()
        # 初始化 ResNet 模块列表
        resnets = []
        # 初始化时间卷积层列表
        temp_convs = []

        # 遍历层数,构建 ResNet 模块和时间卷积层
        for i in range(num_layers):
            # 第一层使用输入通道数,后续层使用输出通道数
            in_channels = in_channels if i == 0 else out_channels
            # 添加 ResNet 模块到列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,  # 输入通道数
                    out_channels=out_channels,  # 输出通道数
                    temb_channels=temb_channels,  # 时间嵌入通道数
                    eps=resnet_eps,  # ResNet 中的 epsilon 值
                    groups=resnet_groups,  # 组数
                    dropout=dropout,  # dropout 概率
                    time_embedding_norm=resnet_time_scale_shift,  # 时间嵌入归一化方式
                    non_linearity=resnet_act_fn,  # 激活函数
                    output_scale_factor=output_scale_factor,  # 输出缩放因子
                    pre_norm=resnet_pre_norm,  # 是否预归一化
                )
            )
            # 添加时间卷积层到列表中
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,  # 输入通道数
                    out_channels,  # 输出通道数
                    dropout=0.1,  # dropout 概率
                    norm_num_groups=resnet_groups,  # 组数
                )
            )

        # 将 ResNet 模块列表转换为 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 将时间卷积层列表转换为 nn.ModuleList
        self.temp_convs = nn.ModuleList(temp_convs)

        # 如果需要添加下采样层
        if add_downsample:
            # 初始化下采样模块列表
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,  # 输出通道数
                        use_conv=True,  # 使用卷积
                        out_channels=out_channels,  # 输出通道数
                        padding=downsample_padding,  # 填充大小
                        name="op",  # 模块名称
                    )
                ]
            )
        else:
            # 不添加下采样层
            self.downsamplers = None

        # 初始化梯度检查点开关为 False
        self.gradient_checkpointing = False

    # 前向传播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        num_frames: int = 1,  # 帧数
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        # 初始化输出状态元组
        output_states = ()

        # 遍历 ResNet 模块和时间卷积层
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # 通过 ResNet 模块处理隐藏状态
            hidden_states = resnet(hidden_states, temb)
            # 通过时间卷积层处理隐藏状态
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

            # 将当前的隐藏状态添加到输出状态元组中
            output_states += (hidden_states,)

        # 如果存在下采样层
        if self.downsamplers is not None:
            # 遍历下采样层
            for downsampler in self.downsamplers:
                # 通过下采样层处理隐藏状态
                hidden_states = downsampler(hidden_states)

            # 将当前的隐藏状态添加到输出状态元组中
            output_states += (hidden_states,)

        # 返回最终的隐藏状态和输出状态元组
        return hidden_states, output_states


# 定义一个 3D 交叉注意力上采样模块,继承自 nn.Module
class CrossAttnUpBlock3D(nn.Module):
    # 初始化方法,用于设置类的属性
        def __init__(
            # 输入通道数
            self,
            in_channels: int,
            # 输出通道数
            out_channels: int,
            # 前一个输出通道数
            prev_output_channel: int,
            # 时间嵌入通道数
            temb_channels: int,
            # dropout比率,默认为0.0
            dropout: float = 0.0,
            # 网络层数,默认为1
            num_layers: int = 1,
            # ResNet的epsilon值,默认为1e-6
            resnet_eps: float = 1e-6,
            # ResNet的时间尺度偏移,默认为"default"
            resnet_time_scale_shift: str = "default",
            # ResNet激活函数,默认为"swish"
            resnet_act_fn: str = "swish",
            # ResNet的组数,默认为32
            resnet_groups: int = 32,
            # ResNet是否使用预归一化,默认为True
            resnet_pre_norm: bool = True,
            # 注意力头的数量,默认为1
            num_attention_heads: int = 1,
            # 交叉注意力的维度,默认为1280
            cross_attention_dim: int = 1280,
            # 输出缩放因子,默认为1.0
            output_scale_factor: float = 1.0,
            # 是否添加上采样,默认为True
            add_upsample: bool = True,
            # 双重交叉注意力的开关,默认为False
            dual_cross_attention: bool = False,
            # 是否使用线性投影,默认为False
            use_linear_projection: bool = False,
            # 仅使用交叉注意力的开关,默认为False
            only_cross_attention: bool = False,
            # 是否上采样注意力,默认为False
            upcast_attention: bool = False,
            # 分辨率索引,默认为None
            resolution_idx: Optional[int] = None,
    ):
        # 调用父类的构造函数进行初始化
        super().__init__()
        # 初始化用于存储 ResNet 块的列表
        resnets = []
        # 初始化用于存储时间卷积层的列表
        temp_convs = []
        # 初始化用于存储注意力模型的列表
        attentions = []
        # 初始化用于存储时间注意力模型的列表
        temp_attentions = []

        # 设置是否使用交叉注意力
        self.has_cross_attention = True
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads

        # 遍历每一层以构建网络结构
        for i in range(num_layers):
            # 确定残差跳过通道数
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 确定 ResNet 的输入通道数
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 添加一个 ResNet 块到列表中
            resnets.append(
                ResnetBlock2D(
                    # 设置输入通道数,包括残差跳过通道
                    in_channels=resnet_in_channels + res_skip_channels,
                    # 设置输出通道数
                    out_channels=out_channels,
                    # 设置时间嵌入通道数
                    temb_channels=temb_channels,
                    # 设置 epsilon 值
                    eps=resnet_eps,
                    # 设置组数
                    groups=resnet_groups,
                    # 设置 dropout 概率
                    dropout=dropout,
                    # 设置时间嵌入归一化方式
                    time_embedding_norm=resnet_time_scale_shift,
                    # 设置非线性激活函数
                    non_linearity=resnet_act_fn,
                    # 设置输出缩放因子
                    output_scale_factor=output_scale_factor,
                    # 设置是否在预归一化
                    pre_norm=resnet_pre_norm,
                )
            )
            # 添加一个时间卷积层到列表中
            temp_convs.append(
                TemporalConvLayer(
                    # 设置输出通道数
                    out_channels,
                    # 设置输入通道数
                    out_channels,
                    # 设置 dropout 概率
                    dropout=0.1,
                    # 设置组数
                    norm_num_groups=resnet_groups,
                )
            )
            # 添加一个 2D 转换器模型到列表中
            attentions.append(
                Transformer2DModel(
                    # 设置每个注意力头的通道数
                    out_channels // num_attention_heads,
                    # 设置注意力头的数量
                    num_attention_heads,
                    # 设置输入通道数
                    in_channels=out_channels,
                    # 设置层数
                    num_layers=1,
                    # 设置交叉注意力维度
                    cross_attention_dim=cross_attention_dim,
                    # 设置组数
                    norm_num_groups=resnet_groups,
                    # 是否使用线性投影
                    use_linear_projection=use_linear_projection,
                    # 是否仅使用交叉注意力
                    only_cross_attention=only_cross_attention,
                    # 是否提升注意力精度
                    upcast_attention=upcast_attention,
                )
            )
            # 添加一个时间转换器模型到列表中
            temp_attentions.append(
                TransformerTemporalModel(
                    # 设置每个注意力头的通道数
                    out_channels // num_attention_heads,
                    # 设置注意力头的数量
                    num_attention_heads,
                    # 设置输入通道数
                    in_channels=out_channels,
                    # 设置层数
                    num_layers=1,
                    # 设置交叉注意力维度
                    cross_attention_dim=cross_attention_dim,
                    # 设置组数
                    norm_num_groups=resnet_groups,
                )
            )
        # 将 ResNet 块列表转换为 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)
        # 将时间卷积层列表转换为 nn.ModuleList
        self.temp_convs = nn.ModuleList(temp_convs)
        # 将注意力模型列表转换为 nn.ModuleList
        self.attentions = nn.ModuleList(attentions)
        # 将时间注意力模型列表转换为 nn.ModuleList
        self.temp_attentions = nn.ModuleList(temp_attentions)

        # 如果需要上采样,则初始化上采样层
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否则将上采样层设置为 None
            self.upsamplers = None

        # 设置梯度检查点标志
        self.gradient_checkpointing = False
        # 设置分辨率索引
        self.resolution_idx = resolution_idx
    # 定义前向传播函数,接受多个参数并返回张量
        def forward(
            self,
            hidden_states: torch.Tensor,  # 当前层的隐藏状态
            res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 以前层的隐藏状态元组
            temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
            encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态
            upsample_size: Optional[int] = None,  # 可选的上采样大小
            attention_mask: Optional[torch.Tensor] = None,  # 可选的注意力掩码
            num_frames: int = 1,  # 帧数,默认值为1
            cross_attention_kwargs: Dict[str, Any] = None,  # 可选的交叉注意力参数
        ) -> torch.Tensor:  # 返回一个张量
            # 检查 FreeU 是否启用,基于多个属性的存在性
            is_freeu_enabled = (
                getattr(self, "s1", None)  # 获取属性 s1
                and getattr(self, "s2", None)  # 获取属性 s2
                and getattr(self, "b1", None)  # 获取属性 b1
                and getattr(self, "b2", None)  # 获取属性 b2
            )
    
            # TODO(Patrick, William) - 注意力掩码尚未使用
            for resnet, temp_conv, attn, temp_attn in zip(  # 遍历网络模块
                self.resnets, self.temp_convs, self.attentions, self.temp_attentions  # 从各模块提取
            ):
                # 从元组中弹出最后一个残差隐藏状态
                res_hidden_states = res_hidden_states_tuple[-1]
                # 更新元组,去掉最后一个隐藏状态
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    
                # FreeU:仅对前两个阶段操作
                if is_freeu_enabled:
                    hidden_states, res_hidden_states = apply_freeu(  # 应用 FreeU 操作
                        self.resolution_idx,  # 当前分辨率索引
                        hidden_states,  # 当前隐藏状态
                        res_hidden_states,  # 残差隐藏状态
                        s1=self.s1,  # 属性 s1
                        s2=self.s2,  # 属性 s2
                        b1=self.b1,  # 属性 b1
                        b2=self.b2,  # 属性 b2
                    )
    
                # 将当前隐藏状态与残差隐藏状态在维度1上拼接
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
                # 通过 ResNet 模块处理隐藏状态
                hidden_states = resnet(hidden_states, temb)
                # 通过临时卷积模块处理隐藏状态
                hidden_states = temp_conv(hidden_states, num_frames=num_frames)
                # 通过注意力模块处理隐藏状态,并提取第一个返回值
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,  # 传递编码器隐藏状态
                    cross_attention_kwargs=cross_attention_kwargs,  # 传递交叉注意力参数
                    return_dict=False,  # 不返回字典形式的结果
                )[0]  # 提取第一个返回值
                # 通过临时注意力模块处理隐藏状态,并提取第一个返回值
                hidden_states = temp_attn(
                    hidden_states,
                    num_frames=num_frames,  # 传递帧数
                    cross_attention_kwargs=cross_attention_kwargs,  # 传递交叉注意力参数
                    return_dict=False,  # 不返回字典形式的结果
                )[0]  # 提取第一个返回值
    
            # 如果存在上采样模块
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:  # 遍历上采样模块
                    hidden_states = upsampler(hidden_states, upsample_size)  # 应用上采样模块
    
            # 返回最终的隐藏状态
            return hidden_states
# 定义一个名为 UpBlock3D 的类,继承自 nn.Module
class UpBlock3D(nn.Module):
    # 初始化函数,接受多个参数以配置网络层
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        prev_output_channel: int,  # 前一层的输出通道数
        out_channels: int,  # 当前层的输出通道数
        temb_channels: int,  # 时间嵌入通道数
        dropout: float = 0.0,  # dropout 概率
        num_layers: int = 1,  # 层数
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 参数,用于数值稳定
        resnet_time_scale_shift: str = "default",  # ResNet 时间缩放偏移设置
        resnet_act_fn: str = "swish",  # ResNet 激活函数
        resnet_groups: int = 32,  # ResNet 的组数
        resnet_pre_norm: bool = True,  # 是否使用前置归一化
        output_scale_factor: float = 1.0,  # 输出缩放因子
        add_upsample: bool = True,  # 是否添加上采样层
        resolution_idx: Optional[int] = None,  # 分辨率索引,默认为 None
    ):
        # 调用父类构造函数
        super().__init__()
        # 创建一个空列表,用于存储 ResNet 层
        resnets = []
        # 创建一个空列表,用于存储时间卷积层
        temp_convs = []

        # 根据层数创建 ResNet 和时间卷积层
        for i in range(num_layers):
            # 确定跳跃连接的通道数
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 确定当前 ResNet 层的输入通道数
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 创建 ResNet 层,并添加到 resnets 列表中
            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,  # 输入通道数
                    out_channels=out_channels,  # 输出通道数
                    temb_channels=temb_channels,  # 时间嵌入通道数
                    eps=resnet_eps,  # epsilon 参数
                    groups=resnet_groups,  # 组数
                    dropout=dropout,  # dropout 概率
                    time_embedding_norm=resnet_time_scale_shift,  # 时间嵌入归一化
                    non_linearity=resnet_act_fn,  # 激活函数
                    output_scale_factor=output_scale_factor,  # 输出缩放因子
                    pre_norm=resnet_pre_norm,  # 前置归一化
                )
            )
            # 创建时间卷积层,并添加到 temp_convs 列表中
            temp_convs.append(
                TemporalConvLayer(
                    out_channels,  # 输入通道数
                    out_channels,  # 输出通道数
                    dropout=0.1,  # dropout 概率
                    norm_num_groups=resnet_groups,  # 归一化组数
                )
            )

        # 将 ResNet 层的列表转为 nn.ModuleList,以便于管理
        self.resnets = nn.ModuleList(resnets)
        # 将时间卷积层的列表转为 nn.ModuleList,以便于管理
        self.temp_convs = nn.ModuleList(temp_convs)

        # 如果需要添加上采样层,则创建并添加
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 如果不需要上采样层,则设置为 None
            self.upsamplers = None

        # 初始化梯度检查点标志为 False
        self.gradient_checkpointing = False
        # 设置分辨率索引
        self.resolution_idx = resolution_idx

    # 定义前向传播函数
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 跳跃连接的隐藏状态元组
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        upsample_size: Optional[int] = None,  # 可选的上采样尺寸
        num_frames: int = 1,  # 帧数
    ) -> torch.Tensor:
        # 判断是否启用 FreeU,检查相关属性是否存在且不为 None
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )
        # 遍历自定义的 resnets 和 temp_convs,进行逐对处理
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # 从 res_hidden_states_tuple 中弹出最后一个隐藏状态
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新 res_hidden_states_tuple,去掉最后一个隐藏状态
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            # 如果启用了 FreeU,则仅对前两个阶段进行操作
            if is_freeu_enabled:
                # 调用 apply_freeu 函数,处理隐藏状态
                hidden_states, res_hidden_states = apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_hidden_states,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )

            # 将当前的隐藏状态与残差隐藏状态在维度 1 上连接
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            # 通过当前的 resnet 处理隐藏状态和 temb
            hidden_states = resnet(hidden_states, temb)
            # 通过当前的 temp_conv 处理隐藏状态,传入 num_frames 参数
            hidden_states = temp_conv(hidden_states, num_frames=num_frames)

        # 如果存在上采样器,则对每个上采样器进行处理
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                # 通过当前的 upsampler 处理隐藏状态,传入 upsample_size 参数
                hidden_states = upsampler(hidden_states, upsample_size)

        # 返回最终的隐藏状态
        return hidden_states
# 定义一个中间块时间解码器类,继承自 nn.Module
class MidBlockTemporalDecoder(nn.Module):
    # 初始化函数,定义输入输出通道、注意力头维度、层数等参数
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        attention_head_dim: int = 512,
        num_layers: int = 1,
        upcast_attention: bool = False,
    ):
        # 调用父类的初始化函数
        super().__init__()

        # 初始化 ResNet 和 Attention 列表
        resnets = []
        attentions = []
        # 根据层数创建相应数量的 ResNet
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            # 将 SpatioTemporalResBlock 实例添加到 ResNet 列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
            )

        # 添加 Attention 实例到 Attention 列表中
        attentions.append(
            Attention(
                query_dim=in_channels,
                heads=in_channels // attention_head_dim,
                dim_head=attention_head_dim,
                eps=1e-6,
                upcast_attention=upcast_attention,
                norm_num_groups=32,
                bias=True,
                residual_connection=True,
            )
        )

        # 将 Attention 和 ResNet 列表转换为 ModuleList
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    # 前向传播函数,定义输入的隐藏状态和图像指示器的处理
    def forward(
        self,
        hidden_states: torch.Tensor,
        image_only_indicator: torch.Tensor,
    ):
        # 处理第一个 ResNet 的输出
        hidden_states = self.resnets[0](
            hidden_states,
            image_only_indicator=image_only_indicator,
        )
        # 遍历剩余的 ResNet 和 Attention,交替处理
        for resnet, attn in zip(self.resnets[1:], self.attentions):
            hidden_states = attn(hidden_states)  # 应用注意力机制
            # 处理 ResNet 的输出
            hidden_states = resnet(
                hidden_states,
                image_only_indicator=image_only_indicator,
            )

        # 返回最终的隐藏状态
        return hidden_states


# 定义一个上采样块时间解码器类,继承自 nn.Module
class UpBlockTemporalDecoder(nn.Module):
    # 初始化函数,定义输入输出通道、层数等参数
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        add_upsample: bool = True,
    ):
        # 调用父类的初始化函数
        super().__init__()
        # 初始化 ResNet 列表
        resnets = []
        # 根据层数创建相应数量的 ResNet
        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels
            # 将 SpatioTemporalResBlock 实例添加到 ResNet 列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=1e-6,
                    temporal_eps=1e-5,
                    merge_factor=0.0,
                    merge_strategy="learned",
                    switch_spatial_to_temporal_mix=True,
                )
            )
        # 将 ResNet 列表转换为 ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 如果需要,初始化上采样层
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None
    # 定义前向传播函数,接收隐藏状态和图像指示器作为输入,返回处理后的张量
        def forward(
            self,
            hidden_states: torch.Tensor,
            image_only_indicator: torch.Tensor,
        ) -> torch.Tensor:
            # 遍历每个 ResNet 模块,更新隐藏状态
            for resnet in self.resnets:
                hidden_states = resnet(
                    hidden_states,
                    image_only_indicator=image_only_indicator,
                )
    
            # 如果存在上采样模块,则对隐藏状态进行上采样处理
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states)
    
            # 返回最终的隐藏状态
            return hidden_states
# 定义一个名为 UNetMidBlockSpatioTemporal 的类,继承自 nn.Module
class UNetMidBlockSpatioTemporal(nn.Module):
    # 初始化方法,接收多个参数以配置该模块
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        temb_channels: int,  # 时间嵌入通道数
        num_layers: int = 1,  # 层数,默认为 1
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,  # 每个块的变换层数,默认为 1
        num_attention_heads: int = 1,  # 注意力头数,默认为 1
        cross_attention_dim: int = 1280,  # 交叉注意力维度,默认为 1280
    ):
        # 调用父类的初始化方法
        super().__init__()

        # 设置是否使用交叉注意力标志
        self.has_cross_attention = True
        # 存储注意力头的数量
        self.num_attention_heads = num_attention_heads

        # 支持每个块的变换层数为可变的
        if isinstance(transformer_layers_per_block, int):
            # 如果是整数,则将其转换为包含 num_layers 个相同元素的列表
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # 至少有一个 ResNet 块
        resnets = [
            # 创建第一个时空残差块
            SpatioTemporalResBlock(
                in_channels=in_channels,  # 输入通道数
                out_channels=in_channels,  # 输出通道数与输入相同
                temb_channels=temb_channels,  # 时间嵌入通道数
                eps=1e-5,  # 小常数用于数值稳定性
            )
        ]
        # 初始化注意力模块列表
        attentions = []

        # 遍历层数以添加注意力和残差块
        for i in range(num_layers):
            # 添加时空变换模型到注意力列表
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,  # 注意力头数
                    in_channels // num_attention_heads,  # 每个头的通道数
                    in_channels=in_channels,  # 输入通道数
                    num_layers=transformer_layers_per_block[i],  # 当前层的变换层数
                    cross_attention_dim=cross_attention_dim,  # 交叉注意力维度
                )
            )

            # 添加另一个时空残差块到残差列表
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=in_channels,  # 输入通道数
                    out_channels=in_channels,  # 输出通道数与输入相同
                    temb_channels=temb_channels,  # 时间嵌入通道数
                    eps=1e-5,  # 小常数用于数值稳定性
                )
            )

        # 将注意力模块列表转换为 nn.ModuleList
        self.attentions = nn.ModuleList(attentions)
        # 将残差模块列表转换为 nn.ModuleList
        self.resnets = nn.ModuleList(resnets)

        # 初始化梯度检查点标志为 False
        self.gradient_checkpointing = False

    # 前向传播方法,接收多个输入参数
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隐藏状态张量
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        encoder_hidden_states: Optional[torch.Tensor] = None,  # 可选的编码器隐藏状态
        image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示张量
    # 返回类型为 torch.Tensor 的函数
    ) -> torch.Tensor:
        # 使用第一个残差网络处理隐藏状态,传入时间嵌入和图像指示器
        hidden_states = self.resnets[0](
            hidden_states,
            temb,
            image_only_indicator=image_only_indicator,
        )
    
        # 遍历注意力层和后续残差网络的组合
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # 检查是否在训练中且开启了梯度检查点
            if self.training and self.gradient_checkpointing:  # TODO
    
                # 创建自定义前向传播函数
                def create_custom_forward(module, return_dict=None):
                    # 定义自定义前向传播内部函数
                    def custom_forward(*inputs):
                        # 根据返回字典参数决定是否使用返回字典
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
    
                    return custom_forward
    
                # 设置检查点参数,取决于 PyTorch 版本
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用注意力层处理隐藏状态,传入编码器的隐藏状态和图像指示器
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
                # 使用检查点进行残差网络的前向传播
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
            else:
                # 使用注意力层处理隐藏状态,传入编码器的隐藏状态和图像指示器
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
                # 使用残差网络处理隐藏状态
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
    
        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个下采样的时空块,继承自 nn.Module
class DownBlockSpatioTemporal(nn.Module):
    # 初始化方法,设置输入输出通道和层数等参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        out_channels: int,  # 输出通道数
        temb_channels: int,  # 时间嵌入通道数
        num_layers: int = 1,  # 层数,默认为1
        add_downsample: bool = True,  # 是否添加下采样
    ):
        super().__init__()  # 调用父类的初始化方法
        resnets = []  # 初始化一个空列表以存储残差块

        # 根据层数创建相应数量的 SpatioTemporalResBlock
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels  # 确定当前层的输入通道数
            resnets.append(
                SpatioTemporalResBlock(  # 添加一个新的时空残差块
                    in_channels=in_channels,  # 设置输入通道
                    out_channels=out_channels,  # 设置输出通道
                    temb_channels=temb_channels,  # 设置时间嵌入通道
                    eps=1e-5,  # 设置 epsilon 值
                )
            )

        self.resnets = nn.ModuleList(resnets)  # 将残差块列表转化为 ModuleList

        # 如果需要下采样,创建下采样模块
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(  # 添加一个下采样层
                        out_channels,  # 设置输入通道
                        use_conv=True,  # 是否使用卷积进行下采样
                        out_channels=out_channels,  # 设置输出通道
                        name="op",  # 下采样层名称
                    )
                ]
            )
        else:
            self.downsamplers = None  # 不添加下采样模块

        self.gradient_checkpointing = False  # 初始化梯度检查点为 False

    # 前向传播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 隐藏状态输入
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入
        image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示器
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        output_states = ()  # 初始化输出状态元组
        for resnet in self.resnets:  # 遍历每个残差块
            if self.training and self.gradient_checkpointing:  # 如果在训练且启用了梯度检查点

                # 定义一个创建自定义前向传播的方法
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)  # 返回模块的前向输出

                    return custom_forward

                # 根据 PyTorch 版本进行检查点操作
                if is_torch_version(">=", "1.11.0"):
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定义前向传播
                        hidden_states,  # 输入隐藏状态
                        temb,  # 输入时间嵌入
                        image_only_indicator,  # 输入图像指示器
                        use_reentrant=False,  # 不使用重入
                    )
                else:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),  # 使用自定义前向传播
                        hidden_states,  # 输入隐藏状态
                        temb,  # 输入时间嵌入
                        image_only_indicator,  # 输入图像指示器
                    )
            else:
                hidden_states = resnet(  # 直接通过残差块进行前向传播
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )

            output_states = output_states + (hidden_states,)  # 将当前隐藏状态添加到输出状态中

        if self.downsamplers is not None:  # 如果存在下采样模块
            for downsampler in self.downsamplers:  # 遍历下采样模块
                hidden_states = downsampler(hidden_states)  # 对隐藏状态进行下采样

            output_states = output_states + (hidden_states,)  # 将下采样后的状态添加到输出状态中

        return hidden_states, output_states  # 返回最终的隐藏状态和输出状态元组


# 定义一个交叉注意力下采样时空块,继承自 nn.Module
class CrossAttnDownBlockSpatioTemporal(nn.Module):
    # 初始化方法,设置模型的基本参数
        def __init__(
            # 输入通道数
            self,
            in_channels: int,
            # 输出通道数
            out_channels: int,
            # 时间嵌入通道数
            temb_channels: int,
            # 层数,默认为1
            num_layers: int = 1,
            # 每个块的变换层数,可以是整数或元组,默认为1
            transformer_layers_per_block: Union[int, Tuple[int]] = 1,
            # 注意力头数,默认为1
            num_attention_heads: int = 1,
            # 交叉注意力的维度,默认为1280
            cross_attention_dim: int = 1280,
            # 是否添加下采样,默认为True
            add_downsample: bool = True,
        ):
            # 调用父类初始化方法
            super().__init__()
            # 初始化残差网络列表
            resnets = []
            # 初始化注意力层列表
            attentions = []
    
            # 设置是否使用交叉注意力为True
            self.has_cross_attention = True
            # 设置注意力头数
            self.num_attention_heads = num_attention_heads
            # 如果变换层数是整数,则扩展为列表
            if isinstance(transformer_layers_per_block, int):
                transformer_layers_per_block = [transformer_layers_per_block] * num_layers
    
            # 遍历层数以创建残差块和注意力层
            for i in range(num_layers):
                # 如果是第一层,使用输入通道数,否则使用输出通道数
                in_channels = in_channels if i == 0 else out_channels
                # 添加残差块到列表
                resnets.append(
                    SpatioTemporalResBlock(
                        # 输入通道数
                        in_channels=in_channels,
                        # 输出通道数
                        out_channels=out_channels,
                        # 时间嵌入通道数
                        temb_channels=temb_channels,
                        # 防止除零的微小值
                        eps=1e-6,
                    )
                )
                # 添加注意力模型到列表
                attentions.append(
                    TransformerSpatioTemporalModel(
                        # 注意力头数
                        num_attention_heads,
                        # 每个头的输出通道数
                        out_channels // num_attention_heads,
                        # 输入通道数
                        in_channels=out_channels,
                        # 该层的变换层数
                        num_layers=transformer_layers_per_block[i],
                        # 交叉注意力的维度
                        cross_attention_dim=cross_attention_dim,
                    )
                )
    
            # 将注意力层转换为nn.ModuleList以支持PyTorch模型
            self.attentions = nn.ModuleList(attentions)
            # 将残差层转换为nn.ModuleList以支持PyTorch模型
            self.resnets = nn.ModuleList(resnets)
    
            # 如果需要添加下采样层
            if add_downsample:
                # 添加下采样层到nn.ModuleList
                self.downsamplers = nn.ModuleList(
                    [
                        Downsample2D(
                            # 输出通道数
                            out_channels,
                            # 是否使用卷积
                            use_conv=True,
                            # 输出通道数
                            out_channels=out_channels,
                            # 填充大小
                            padding=1,
                            # 操作名称
                            name="op",
                        )
                    ]
                )
            else:
                # 如果不需要下采样层,则设置为None
                self.downsamplers = None
    
            # 初始化梯度检查点为False
            self.gradient_checkpointing = False
    
        # 前向传播方法
        def forward(
            # 隐藏状态的张量
            hidden_states: torch.Tensor,
            # 时间嵌入的可选张量
            temb: Optional[torch.Tensor] = None,
            # 编码器隐藏状态的可选张量
            encoder_hidden_states: Optional[torch.Tensor] = None,
            # 仅图像指示的可选张量
            image_only_indicator: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:  # 定义返回类型为一个元组,包含一个张量和多个张量的元组
        output_states = ()  # 初始化一个空元组,用于存储输出状态

        blocks = list(zip(self.resnets, self.attentions))  # 将自定义的残差网络和注意力模块打包成一个列表
        for resnet, attn in blocks:  # 遍历每个残差网络和对应的注意力模块
            if self.training and self.gradient_checkpointing:  # 如果处于训练模式并且启用了梯度检查点

                def create_custom_forward(module, return_dict=None):  # 定义一个创建自定义前向传播函数的辅助函数
                    def custom_forward(*inputs):  # 定义自定义前向传播函数,接受任意数量的输入
                        if return_dict is not None:  # 如果提供了返回字典参数
                            return module(*inputs, return_dict=return_dict)  # 使用返回字典参数调用模块
                        else:  # 如果没有提供返回字典
                            return module(*inputs)  # 直接调用模块并返回结果

                    return custom_forward  # 返回自定义前向传播函数

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}  # 根据 PyTorch 版本设置检查点参数
                hidden_states = torch.utils.checkpoint.checkpoint(  # 使用检查点功能计算隐藏状态以节省内存
                    create_custom_forward(resnet),  # 将残差网络传入自定义前向函数
                    hidden_states,  # 将当前隐藏状态作为输入
                    temb,  # 传递时间嵌入
                    image_only_indicator,  # 传递图像指示器
                    **ckpt_kwargs,  # 解包检查点参数
                )

                hidden_states = attn(  # 使用注意力模块处理隐藏状态
                    hidden_states,  # 输入隐藏状态
                    encoder_hidden_states=encoder_hidden_states,  # 传递编码器隐藏状态
                    image_only_indicator=image_only_indicator,  # 传递图像指示器
                    return_dict=False,  # 不返回字典形式的结果
                )[0]  # 获取输出的第一个元素
            else:  # 如果不处于训练模式或未启用梯度检查点
                hidden_states = resnet(  # 直接使用残差网络处理隐藏状态
                    hidden_states,  # 输入当前隐藏状态
                    temb,  # 传递时间嵌入
                    image_only_indicator=image_only_indicator,  # 传递图像指示器
                )
                hidden_states = attn(  # 使用注意力模块处理隐藏状态
                    hidden_states,  # 输入隐藏状态
                    encoder_hidden_states=encoder_hidden_states,  # 传递编码器隐藏状态
                    image_only_indicator=image_only_indicator,  # 传递图像指示器
                    return_dict=False,  # 不返回字典形式的结果
                )[0]  # 获取输出的第一个元素

            output_states = output_states + (hidden_states,)  # 将当前的隐藏状态添加到输出状态元组中

        if self.downsamplers is not None:  # 如果存在下采样模块
            for downsampler in self.downsamplers:  # 遍历每个下采样模块
                hidden_states = downsampler(hidden_states)  # 使用下采样模块处理隐藏状态

            output_states = output_states + (hidden_states,)  # 将处理后的隐藏状态添加到输出状态元组中

        return hidden_states, output_states  # 返回当前的隐藏状态和所有输出状态
# 定义一个上采样的时空块类,继承自 nn.Module
class UpBlockSpatioTemporal(nn.Module):
    # 初始化方法,定义类的参数
    def __init__(
        self,
        in_channels: int,  # 输入通道数
        prev_output_channel: int,  # 前一层输出通道数
        out_channels: int,  # 当前层输出通道数
        temb_channels: int,  # 时间嵌入通道数
        resolution_idx: Optional[int] = None,  # 可选的分辨率索引
        num_layers: int = 1,  # 层数,默认为1
        resnet_eps: float = 1e-6,  # ResNet 的 epsilon 值
        add_upsample: bool = True,  # 是否添加上采样层
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 初始化一个空的列表,用于存储 ResNet 模块
        resnets = []

        # 根据层数创建对应的 ResNet 块
        for i in range(num_layers):
            # 如果是最后一层,使用输入通道数,否则使用输出通道数
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 第一层的输入通道为前一层输出通道,后续层为输出通道
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 创建时空 ResNet 块并添加到列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,  # 输入通道数
                    out_channels=out_channels,  # 输出通道数
                    temb_channels=temb_channels,  # 时间嵌入通道数
                    eps=resnet_eps,  # ResNet 的 epsilon 值
                )
            )

        # 将 ResNet 模块列表转换为 nn.ModuleList,以便在模型中管理
        self.resnets = nn.ModuleList(resnets)

        # 如果需要添加上采样层,则创建对应的 nn.ModuleList
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 如果不添加上采样层,设置为 None
            self.upsamplers = None

        # 初始化梯度检查点标志
        self.gradient_checkpointing = False
        # 设置分辨率索引
        self.resolution_idx = resolution_idx

    # 定义前向传播方法
    def forward(
        self,
        hidden_states: torch.Tensor,  # 输入的隐藏状态张量
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],  # 之前层的隐藏状态元组
        temb: Optional[torch.Tensor] = None,  # 可选的时间嵌入张量
        image_only_indicator: Optional[torch.Tensor] = None,  # 可选的图像指示器张量
    ) -> torch.Tensor:  # 定义函数返回类型为 PyTorch 的张量
        for resnet in self.resnets:  # 遍历当前对象中的所有 ResNet 模型
            # pop res hidden states  # 从隐藏状态元组中提取最后一个隐藏状态
            res_hidden_states = res_hidden_states_tuple[-1]  # 获取最后的 ResNet 隐藏状态
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]  # 更新元组,移除最后一个隐藏状态

            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)  # 将当前隐藏状态与 ResNet 隐藏状态在维度 1 上拼接

            if self.training and self.gradient_checkpointing:  # 如果处于训练状态并启用了梯度检查点
                def create_custom_forward(module):  # 定义用于创建自定义前向传播函数的内部函数
                    def custom_forward(*inputs):  # 自定义前向传播,接收任意数量的输入
                        return module(*inputs)  # 调用原始模块的前向传播

                    return custom_forward  # 返回自定义前向传播函数

                if is_torch_version(">=", "1.11.0"):  # 检查当前 PyTorch 版本是否大于等于 1.11.0
                    hidden_states = torch.utils.checkpoint.checkpoint(  # 使用检查点机制保存内存
                        create_custom_forward(resnet),  # 传入自定义前向传播函数
                        hidden_states,  # 传入当前的隐藏状态
                        temb,  # 传入时间嵌入
                        image_only_indicator,  # 传入图像指示器
                        use_reentrant=False,  # 禁用重入
                    )
                else:  # 如果 PyTorch 版本小于 1.11.0
                    hidden_states = torch.utils.checkpoint.checkpoint(  # 使用检查点机制保存内存
                        create_custom_forward(resnet),  # 传入自定义前向传播函数
                        hidden_states,  # 传入当前的隐藏状态
                        temb,  # 传入时间嵌入
                        image_only_indicator,  # 传入图像指示器
                    )
            else:  # 如果不是训练状态或没有启用梯度检查点
                hidden_states = resnet(  # 直接调用 ResNet 模型处理隐藏状态
                    hidden_states,  # 传入当前的隐藏状态
                    temb,  # 传入时间嵌入
                    image_only_indicator=image_only_indicator,  # 传入图像指示器
                )

        if self.upsamplers is not None:  # 如果存在上采样模块
            for upsampler in self.upsamplers:  # 遍历所有上采样模块
                hidden_states = upsampler(hidden_states)  # 调用上采样模块处理隐藏状态

        return hidden_states  # 返回处理后的隐藏状态
# 定义一个时空交叉注意力上采样块类,继承自 nn.Module
class CrossAttnUpBlockSpatioTemporal(nn.Module):
    # 初始化方法,设置网络的各个参数
    def __init__(
        # 输入通道数
        in_channels: int,
        # 输出通道数
        out_channels: int,
        # 前一层输出通道数
        prev_output_channel: int,
        # 时间嵌入通道数
        temb_channels: int,
        # 分辨率索引,可选
        resolution_idx: Optional[int] = None,
        # 层数
        num_layers: int = 1,
        # 每个块的变换器层数,支持单个整数或元组
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        # ResNet 的 epsilon 值,防止除零错误
        resnet_eps: float = 1e-6,
        # 注意力头的数量
        num_attention_heads: int = 1,
        # 交叉注意力维度
        cross_attention_dim: int = 1280,
        # 是否添加上采样层
        add_upsample: bool = True,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 存储 ResNet 层的列表
        resnets = []
        # 存储注意力层的列表
        attentions = []

        # 指示是否使用交叉注意力
        self.has_cross_attention = True
        # 设置注意力头的数量
        self.num_attention_heads = num_attention_heads

        # 如果是整数,将其转换为列表,包含 num_layers 个相同的元素
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        # 遍历每一层,构建 ResNet 和注意力层
        for i in range(num_layers):
            # 根据当前层是否是最后一层来决定跳过连接的通道数
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # 根据当前层决定 ResNet 的输入通道数
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # 添加时空 ResNet 块到列表中
            resnets.append(
                SpatioTemporalResBlock(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                )
            )
            # 添加时空变换器模型到列表中
            attentions.append(
                TransformerSpatioTemporalModel(
                    num_attention_heads,
                    out_channels // num_attention_heads,
                    in_channels=out_channels,
                    num_layers=transformer_layers_per_block[i],
                    cross_attention_dim=cross_attention_dim,
                )
            )

        # 将注意力层列表转换为 nn.ModuleList,以便于管理
        self.attentions = nn.ModuleList(attentions)
        # 将 ResNet 层列表转换为 nn.ModuleList,以便于管理
        self.resnets = nn.ModuleList(resnets)

        # 如果需要,添加上采样层
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            # 否则上采样层设置为 None
            self.upsamplers = None

        # 设置梯度检查点标志为 False
        self.gradient_checkpointing = False
        # 存储分辨率索引
        self.resolution_idx = resolution_idx

    # 前向传播方法,定义输入和输出
    def forward(
        # 隐藏状态张量
        hidden_states: torch.Tensor,
        # 上一层隐藏状态的元组
        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
        # 可选的时间嵌入张量
        temb: Optional[torch.Tensor] = None,
        # 可选的编码器隐藏状态张量
        encoder_hidden_states: Optional[torch.Tensor] = None,
        # 可选的图像指示器张量
        image_only_indicator: Optional[torch.Tensor] = None,
    # 返回一个 torch.Tensor 类型的结果
    ) -> torch.Tensor:
        # 遍历每个 resnet 和 attention 模块的组合
        for resnet, attn in zip(self.resnets, self.attentions):
            # 从隐藏状态元组中弹出最后一个 res 隐藏状态
            res_hidden_states = res_hidden_states_tuple[-1]
            # 更新隐藏状态元组,去掉最后一个元素
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
    
            # 在指定维度连接当前的 hidden_states 和 res_hidden_states
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
    
            # 如果处于训练模式且开启了梯度检查点
            if self.training and self.gradient_checkpointing:  # TODO
                # 定义一个用于创建自定义前向传播函数的函数
                def create_custom_forward(module, return_dict=None):
                    # 定义自定义前向传播逻辑
                    def custom_forward(*inputs):
                        # 根据是否返回字典,调用模块的前向传播
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)
    
                    return custom_forward
    
                # 根据 PyTorch 版本选择 checkpoint 的参数
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # 使用检查点保存内存,调用自定义前向传播
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    image_only_indicator,
                    **ckpt_kwargs,
                )
                # 通过 attention 模块处理 hidden_states
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
            else:
                # 如果不使用检查点,直接通过 resnet 模块处理 hidden_states
                hidden_states = resnet(
                    hidden_states,
                    temb,
                    image_only_indicator=image_only_indicator,
                )
                # 通过 attention 模块处理 hidden_states
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                    return_dict=False,
                )[0]
    
        # 如果存在上采样模块,逐个应用于 hidden_states
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)
    
        # 返回处理后的 hidden_states
        return hidden_states
posted @ 2024-10-22 12:39  绝不原创的飞龙  阅读(34)  评论(0编辑  收藏  举报