diffusers-源码解析-十六-

diffusers 源码解析(十六)

.\diffusers\models\unets\unet_spatio_temporal_condition.py

# 从数据类模块导入数据类装饰器
from dataclasses import dataclass
# 导入字典、可选、元组和联合类型的类型注解
from typing import Dict, Optional, Tuple, Union

# 导入 PyTorch 库
import torch
import torch.nn as nn

# 导入配置相关的工具
from ...configuration_utils import ConfigMixin, register_to_config
# 导入用于加载 UNet2D 条件模型的混合类
from ...loaders import UNet2DConditionLoadersMixin
# 导入基本输出类和日志工具
from ...utils import BaseOutput, logging
# 导入注意力处理器相关内容
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
# 导入时间步嵌入和时间步类
from ..embeddings import TimestepEmbedding, Timesteps
# 导入模型混合类
from ..modeling_utils import ModelMixin
# 导入 UNet 3D 块的相关功能
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block

# 获取当前模块的日志记录器
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# 定义一个数据类,用于存储 UNet 空间时间条件模型的输出
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    [`UNetSpatioTemporalConditionModel`] 的输出。

    参数:
        sample (`torch.Tensor` 形状为 `(batch_size, num_frames, num_channels, height, width)`):
            根据 `encoder_hidden_states` 输入条件的隐藏状态输出。模型最后一层的输出。
    """

    # 定义一个可选的张量,默认为 None
    sample: torch.Tensor = None

# 定义一个条件空间时间 UNet 模型类
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    一个条件空间时间 UNet 模型,接受噪声视频帧、条件状态和时间步,返回指定形状的样本输出。

    此模型继承自 [`ModelMixin`]。请查看超类文档以了解为所有模型实现的通用方法
    (例如下载或保存)。
    # 函数参数说明
    Parameters:
        # 输入/输出样本的高度和宽度,类型为整型或整型元组,默认为 None
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        # 输入样本的通道数,默认为 8
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        # 输出样本的通道数,默认为 4
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        # 用于下采样的块的元组,默认为指定的四个下采样块
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        # 用于上采样的块的元组,默认为指定的四个上采样块
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        # 每个块的输出通道数的元组,默认为 (320, 640, 1280, 1280)
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        # 用于编码附加时间 ID 的维度,默认为 256
        addition_time_embed_dim: (`int`, defaults to 256):
            Dimension to encode the additional time ids.
        # 编码 `added_time_ids` 的投影维度,默认为 768
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        # 每个块的层数,默认为 2
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        # 交叉注意力特征的维度,类型为整型或整型元组,默认为 1280
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        # 变换器块的数量,相关于特定类型的下/上块,默认为 1
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        # 注意力头的数量,默认为 (5, 10, 10, 20)
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
            The number of attention heads.
        # 使用的 dropout 概率,默认为 0.0
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
    """

    # 启用梯度检查点
    _supports_gradient_checkpointing = True

    # 注册到配置中
    @register_to_config
    # 初始化方法,用于创建类的实例并设置其属性
        def __init__(
            self,
            # 可选的样本大小参数,默认为 None
            sample_size: Optional[int] = None,
            # 输入通道数量,默认为 8
            in_channels: int = 8,
            # 输出通道数量,默认为 4
            out_channels: int = 4,
            # 各下采样块的类型,默认为指定的类型元组
            down_block_types: Tuple[str] = (
                "CrossAttnDownBlockSpatioTemporal",
                "CrossAttnDownBlockSpatioTemporal",
                "CrossAttnDownBlockSpatioTemporal",
                "DownBlockSpatioTemporal",
            ),
            # 各上采样块的类型,默认为指定的类型元组
            up_block_types: Tuple[str] = (
                "UpBlockSpatioTemporal",
                "CrossAttnUpBlockSpatioTemporal",
                "CrossAttnUpBlockSpatioTemporal",
                "CrossAttnUpBlockSpatioTemporal",
            ),
            # 各块输出通道数量,默认为指定的整数元组
            block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
            # 附加时间嵌入维度,默认为 256
            addition_time_embed_dim: int = 256,
            # 投影类嵌入输入维度,默认为 768
            projection_class_embeddings_input_dim: int = 768,
            # 每个块的层数,可以是整数或整数元组,默认为 2
            layers_per_block: Union[int, Tuple[int]] = 2,
            # 交叉注意力维度,可以是整数或整数元组,默认为 1024
            cross_attention_dim: Union[int, Tuple[int]] = 1024,
            # 每个块的变换器层数,可以是整数或元组,默认为 1
            transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
            # 注意力头的数量,可以是整数或整数元组,默认为 (5, 10, 20, 20)
            num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
            # 帧的数量,默认为 25
            num_frames: int = 25,
        # 定义属性装饰器,返回注意力处理器字典
        @property
        def attn_processors(self) -> Dict[str, AttentionProcessor]:
            r"""
            Returns:
                `dict` of attention processors: A dictionary containing all attention processors used in the model with
                indexed by its weight name.
            """
            # 创建一个空字典用于存储注意力处理器
            processors = {}
    
            # 定义递归函数以添加处理器
            def fn_recursive_add_processors(
                # 当前模块名称
                name: str,
                # 当前模块对象
                module: torch.nn.Module,
                # 存储处理器的字典
                processors: Dict[str, AttentionProcessor],
            ):
                # 检查当前模块是否有获取处理器的方法
                if hasattr(module, "get_processor"):
                    # 将处理器添加到字典中,键为模块名称
                    processors[f"{name}.processor"] = module.get_processor()
    
                # 遍历当前模块的所有子模块
                for sub_name, child in module.named_children():
                    # 递归调用,将子模块的处理器添加到字典中
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
                # 返回处理器字典
                return processors
    
            # 遍历当前实例的所有子模块
            for name, module in self.named_children():
                # 调用递归函数以添加所有处理器
                fn_recursive_add_processors(name, module, processors)
    
            # 返回所有注意力处理器的字典
            return processors
    # 设置用于计算注意力的处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r""" 
        设置要用于计算注意力的处理器。
    
        参数:
            processor(`dict` 或 `AttentionProcessor`):
                实例化的处理器类或将被设置为 **所有** `Attention` 层的处理器类字典。
    
                如果 `processor` 是字典,则键需要定义相应的交叉注意力处理器的路径。 
                强烈建议在设置可训练的注意力处理器时使用此方法。
    
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 检查传入的处理器字典长度是否与注意力层数量匹配
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与注意力层数量 {count} 不匹配。"
                f" 请确保传入 {count} 个处理器类。"
            )
    
        # 定义递归设置注意力处理器的函数
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块具有设置处理器的属性
            if hasattr(module, "set_processor"):
                # 如果处理器不是字典,直接设置
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # 从字典中取出相应的处理器并设置
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历模块的子模块
            for sub_name, child in module.named_children():
                # 递归调用设置处理器的函数
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的子模块
        for name, module in self.named_children():
            # 递归设置注意力处理器
            fn_recursive_attn_processor(name, module, processor)
    
    # 设置默认的注意力处理器
    def set_default_attn_processor(self):
        """
        禁用自定义注意力处理器,并设置默认的注意力实现。
        """
        # 检查当前的处理器是否都在交叉注意力处理器列表中
        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 创建默认的注意力处理器实例
            processor = AttnProcessor()
        else:
            # 抛出错误,不能设置默认处理器
            raise ValueError(
                f"当注意力处理器类型为 {next(iter(self.attn_processors.values()))} 时,无法调用 `set_default_attn_processor`"
            )
    
        # 设置默认的注意力处理器
        self.set_attn_processor(processor)
    
    # 设置模块的梯度检查点
    def _set_gradient_checkpointing(self, module, value=False):
        # 如果模块具有梯度检查点属性,则设置其值
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
    
    # 从 diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 复制的内容
    # 定义一个方法以启用前馈层的分块处理,参数为分块大小和维度
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        设置注意力处理器以使用 [前馈分块处理](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)。
    
        参数:
            chunk_size (`int`, *可选*):
                前馈层的分块大小。如果未指定,将单独在每个维度为 `dim` 的张量上运行前馈层。
            dim (`int`, *可选*, 默认值为 `0`):
                前馈计算应该在哪个维度上进行分块。选择 dim=0(批量)或 dim=1(序列长度)。
        """
        # 如果 dim 不是 0 或 1,则引发值错误
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
    
        # 默认的分块大小为 1
        chunk_size = chunk_size or 1
    
        # 定义一个递归函数以设置每个子模块的前馈分块处理
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            # 如果模块有设置前馈分块的方法,则调用该方法
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
    
            # 遍历模块的所有子模块并递归调用
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)
    
        # 遍历当前对象的所有子模块并调用递归函数
        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)
    
    # 定义前馈方法,接收样本、时间步、编码器隐藏状态、额外时间ID和返回字典参数
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,

.\diffusers\models\unets\unet_stable_cascade.py

# 版权声明,指明该文件的版权归 HuggingFace 团队所有
# 
# 根据 Apache 2.0 许可协议进行许可;
# 除非符合许可,否则您不能使用此文件。
# 您可以在以下网址获得许可的副本:
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# 除非适用法律或书面协议另有约定,
# 根据该许可分发的软件是按“原样”基础提供的,
# 不提供任何明示或暗示的保证或条件。
# 请参阅许可以获取特定语言的权限和
# 限制条款。

# 导入数学模块,用于数学计算
import math
# 从数据类模块导入数据类装饰器,用于简化类的定义
from dataclasses import dataclass
# 导入可选类型、元组和联合类型的类型注解
from typing import Optional, Tuple, Union

# 导入 NumPy 库,用于数组和矩阵操作
import numpy as np
# 导入 PyTorch 库及其子模块,用于构建和训练神经网络
import torch
import torch.nn as nn

# 从配置工具导入配置混合类和注册配置的函数
from ...configuration_utils import ConfigMixin, register_to_config
# 从加载器模块导入原始模型混合类
from ...loaders import FromOriginalModelMixin
# 从实用工具导入基础输出类
from ...utils import BaseOutput
# 从注意力处理器模块导入注意力类
from ..attention_processor import Attention
# 从建模工具模块导入模型混合类
from ..modeling_utils import ModelMixin

# 定义一个层归一化类,继承自 nn.LayerNorm
# 从 diffusers.pipelines.wuerstchen.modeling_wuerstchen_common 中复制,并重命名为 SDCascadeLayerNorm
class SDCascadeLayerNorm(nn.LayerNorm):
    # 初始化方法,接受可变参数并调用父类构造函数
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # 前向传播方法,接受输入 x
    def forward(self, x):
        # 重新排列 x 的维度,将其形状变为 (batch_size, height, width, channels)
        x = x.permute(0, 2, 3, 1)
        # 调用父类的前向传播方法进行层归一化
        x = super().forward(x)
        # 再次排列 x 的维度,返回到 (batch_size, channels, height, width) 形状
        return x.permute(0, 3, 1, 2)

# 定义时间步块类,继承自 nn.Module
class SDCascadeTimestepBlock(nn.Module):
    # 初始化方法,接受参数 c, c_timestep 和条件列表 conds
    def __init__(self, c, c_timestep, conds=[]):
        super().__init__()

        # 创建一个线性映射层,将时间步的输入转换为两倍的通道数
        self.mapper = nn.Linear(c_timestep, c * 2)
        # 保存条件列表
        self.conds = conds
        # 为每个条件创建一个线性映射层
        for cname in conds:
            setattr(self, f"mapper_{cname}", nn.Linear(c_timestep, c * 2))

    # 前向传播方法,接受输入 x 和时间步 t
    def forward(self, x, t):
        # 将时间步 t 拆分为多个部分
        t = t.chunk(len(self.conds) + 1, dim=1)
        # 使用 mapper 对第一个时间步进行线性映射,并拆分为 a 和 b
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        # 遍历条件列表
        for i, c in enumerate(self.conds):
            # 获取条件的映射结果,并拆分为 ac 和 bc
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            # 将映射结果加到 a 和 b 上
            a, b = a + ac, b + bc
        # 返回经过变换后的 x
        return x * (1 + a) + b

# 定义残差块类,继承自 nn.Module
class SDCascadeResBlock(nn.Module):
    # 初始化方法,接受多个参数定义残差块的结构
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        super().__init__()
        # 创建深度可分离卷积层
        self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        # 创建自定义的层归一化
        self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
        # 创建一个包含多个层的顺序模块
        self.channelwise = nn.Sequential(
            nn.Linear(c + c_skip, c * 4),  # 线性变换
            nn.GELU(),                      # 激活函数
            GlobalResponseNorm(c * 4),      # 全局响应归一化
            nn.Dropout(dropout),            # Dropout 层
            nn.Linear(c * 4, c),            # 输出层
        )

    # 前向传播方法,接受输入 x 和可选的跳跃连接 x_skip
    def forward(self, x, x_skip=None):
        # 保存输入 x 的副本,用于残差连接
        x_res = x
        # 经过深度卷积和归一化
        x = self.norm(self.depthwise(x))
        # 如果提供了跳跃连接,则将其与 x 拼接
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        # 对 x 进行通道变换并返回到原始形状
        x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # 返回残差连接的结果
        return x + x_res

# 定义全局响应归一化类,继承自 nn.Module
# 从 Facebook Research 的 ConvNeXt-V2 项目中获取代码
class GlobalResponseNorm(nn.Module):
    # 初始化方法,接收一个维度参数
        def __init__(self, dim):
            # 调用父类的初始化方法
            super().__init__()
            # 创建一个可学习的参数 gamma,形状为 (1, 1, 1, dim),初始化为零
            self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
            # 创建一个可学习的参数 beta,形状为 (1, 1, 1, dim),初始化为零
            self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
    
    # 前向传播方法,定义如何计算输出
        def forward(self, x):
            # 计算输入 x 的 L2 范数,维度为 (1, 2),保持维度不变
            agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
            # 将范数归一化,通过均值进行标准化,防止除以零
            stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
            # 返回标准化后的 x 乘以 gamma,加上 beta 和原始 x,形成最终输出
            return self.gamma * (x * stand_div_norm) + self.beta + x
# 定义一个名为 SDCascadeAttnBlock 的类,继承自 nn.Module
class SDCascadeAttnBlock(nn.Module):
    # 初始化函数,接收多个参数设置
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        # 调用父类的初始化函数
        super().__init__()

        # 设置自注意力标志
        self.self_attn = self_attn
        # 创建归一化层,使用 SDCascadeLayerNorm
        self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6)
        # 创建注意力机制实例
        self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
        # 创建键值映射层,由 SiLU 激活和线性层组成
        self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))

    # 前向传播函数,接收输入和键值对
    def forward(self, x, kv):
        # 使用键值映射层处理 kv
        kv = self.kv_mapper(kv)
        # 对输入 x 进行归一化处理
        norm_x = self.norm(x)
        # 如果启用自注意力机制
        if self.self_attn:
            # 获取输入的批大小和通道数
            batch_size, channel, _, _ = x.shape
            # 将归一化后的输入和 kv 连接
            kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
        # 将注意力输出与原输入相加
        x = x + self.attention(norm_x, encoder_hidden_states=kv)
        # 返回处理后的输入
        return x


# 定义一个名为 UpDownBlock2d 的类,继承自 nn.Module
class UpDownBlock2d(nn.Module):
    # 初始化函数,接收输入和输出通道数、模式和启用标志
    def __init__(self, in_channels, out_channels, mode, enabled=True):
        # 调用父类的初始化函数
        super().__init__()
        # 如果模式不支持,抛出异常
        if mode not in ["up", "down"]:
            raise ValueError(f"{mode} not supported")
        # 根据模式创建上采样或下采样的插值层
        interpolation = (
            nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True)
            if enabled
            else nn.Identity()
        )
        # 创建卷积映射层
        mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        # 根据模式将插值层和卷积层组合成模块列表
        self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])

    # 前向传播函数,接收输入 x
    def forward(self, x):
        # 遍历块并依次处理输入
        for block in self.blocks:
            x = block(x)
        # 返回处理后的输入
        return x


# 定义一个数据类 StableCascadeUNetOutput,继承自 BaseOutput
@dataclass
class StableCascadeUNetOutput(BaseOutput):
    # 初始化输出样本,默认值为 None
    sample: torch.Tensor = None


# 定义一个名为 StableCascadeUNet 的类,继承多个混入类
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    # 设置支持梯度检查点标志
    _supports_gradient_checkpointing = True

    # 注册配置装饰器
    @register_to_config
    # 初始化方法,用于设置模型的参数
        def __init__(
            # 输入通道数,默认为16
            in_channels: int = 16,
            # 输出通道数,默认为16
            out_channels: int = 16,
            # 时间步比率嵌入维度,默认为64
            timestep_ratio_embedding_dim: int = 64,
            # 每个补丁的大小,默认为1
            patch_size: int = 1,
            # 条件维度,默认为2048
            conditioning_dim: int = 2048,
            # 每个块的输出通道数,默认为(2048, 2048)
            block_out_channels: Tuple[int] = (2048, 2048),
            # 每层的注意力头数,默认为(32, 32)
            num_attention_heads: Tuple[int] = (32, 32),
            # 每个块的下采样层数,默认为(8, 24)
            down_num_layers_per_block: Tuple[int] = (8, 24),
            # 每个块的上采样层数,默认为(24, 8)
            up_num_layers_per_block: Tuple[int] = (24, 8),
            # 下采样块的重复映射器,默认为(1, 1)
            down_blocks_repeat_mappers: Optional[Tuple[int]] = (
                1,
                1,
            ),
            # 上采样块的重复映射器,默认为(1, 1)
            up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1),
            # 每层的块类型,默认为两个层的不同块类型
            block_types_per_layer: Tuple[Tuple[str]] = (
                ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
                ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
            ),
            # 文本输入通道数,可选
            clip_text_in_channels: Optional[int] = None,
            # 文本池化输入通道数,默认为1280
            clip_text_pooled_in_channels=1280,
            # 图像输入通道数,可选
            clip_image_in_channels: Optional[int] = None,
            # 序列长度,默认为4
            clip_seq=4,
            # EfficientNet输入通道数,可选
            effnet_in_channels: Optional[int] = None,
            # 像素映射器输入通道数,可选
            pixel_mapper_in_channels: Optional[int] = None,
            # 卷积核大小,默认为3
            kernel_size=3,
            # dropout率,默认为(0.1, 0.1)
            dropout: Union[float, Tuple[float]] = (0.1, 0.1),
            # 自注意力标志,默认为True
            self_attn: Union[bool, Tuple[bool]] = True,
            # 时间步条件类型,默认为("sca", "crp")
            timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
            # 切换级别,可选
            switch_level: Optional[Tuple[bool]] = None,
        # 设置梯度检查点的方法,默认为False
        def _set_gradient_checkpointing(self, value=False):
            # 存储梯度检查点的布尔值
            self.gradient_checkpointing = value
    
        # 初始化权重的方法
        def _init_weights(self, m):
            # 如果是卷积层或线性层
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # 使用Xavier均匀分布初始化权重
                torch.nn.init.xavier_uniform_(m.weight)
                # 如果有偏置,则将偏置初始化为0
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    
            # 对文本池化映射器的权重进行正态分布初始化,标准差为0.02
            nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)
            # 如果有文本映射器,则对其权重进行正态分布初始化
            nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) if hasattr(self, "clip_txt_mapper") else None
            # 如果有图像映射器,则对其权重进行正态分布初始化
            nn.init.normal_(self.clip_img_mapper.weight, std=0.02) if hasattr(self, "clip_img_mapper") else None
    
            # 如果有EfficientNet映射器,则对其权重进行初始化
            if hasattr(self, "effnet_mapper"):
                nn.init.normal_(self.effnet_mapper[0].weight, std=0.02)  # 条件层
                nn.init.normal_(self.effnet_mapper[2].weight, std=0.02)  # 条件层
    
            # 如果有像素映射器,则对其权重进行初始化
            if hasattr(self, "pixels_mapper"):
                nn.init.normal_(self.pixels_mapper[0].weight, std=0.02)  # 条件层
                nn.init.normal_(self.pixels_mapper[2].weight, std=0.02)  # 条件层
    
            # 对嵌入层的权重进行Xavier均匀分布初始化
            torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # 输入层
            # 将分类器的权重初始化为0
            nn.init.constant_(self.clf[1].weight, 0)  # 输出层
    
            # 初始化块的权重
            for level_block in self.down_blocks + self.up_blocks:
                # 遍历每个块
                for block in level_block:
                    # 如果是SDCascadeResBlock类型
                    if isinstance(block, SDCascadeResBlock):
                        # 对最后一层的权重进行调整
                        block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks[0]))
                    # 如果是SDCascadeTimestepBlock类型
                    elif isinstance(block, SDCascadeTimestepBlock):
                        # 将映射器的权重初始化为0
                        nn.init.constant_(block.mapper.weight, 0)
    # 定义获取时间步比率嵌入的方法,输入时间步比率和最大位置数
        def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000):
            # 计算时间步比率与最大位置数的乘积
            r = timestep_ratio * max_positions
            # 计算嵌入维度的一半
            half_dim = self.config.timestep_ratio_embedding_dim // 2
    
            # 根据最大位置数和一半维度计算嵌入的基础值
            emb = math.log(max_positions) / (half_dim - 1)
            # 生成从0到half_dim的张量,乘以负的基础值并取指数
            emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
            # 将时间步比率和嵌入结合,扩展维度
            emb = r[:, None] * emb[None, :]
            # 将正弦和余弦值拼接在一起
            emb = torch.cat([emb.sin(), emb.cos()], dim=1)
    
            # 如果嵌入维度为奇数,则进行零填充
            if self.config.timestep_ratio_embedding_dim % 2 == 1:  # zero pad
                emb = nn.functional.pad(emb, (0, 1), mode="constant")
    
            # 将嵌入转换为与r相同的数据类型并返回
            return emb.to(dtype=r.dtype)
    
        # 定义获取CLIP嵌入的方法,输入文本和图像的池化结果
        def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None):
            # 如果文本池的形状为二维,增加一个维度
            if len(clip_txt_pooled.shape) == 2:
                clip_txt_pool = clip_txt_pooled.unsqueeze(1)
            # 将文本池通过映射器转换并调整维度
            clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
                clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.config.clip_seq, -1
            )
            # 如果提供了文本和图像,进行相应的映射和拼接
            if clip_txt is not None and clip_img is not None:
                clip_txt = self.clip_txt_mapper(clip_txt)
                # 如果图像的形状为二维,增加一个维度
                if len(clip_img.shape) == 2:
                    clip_img = clip_img.unsqueeze(1)
                # 将图像通过映射器转换并调整维度
                clip_img = self.clip_img_mapper(clip_img).view(
                    clip_img.size(0), clip_img.size(1) * self.config.clip_seq, -1
                )
                # 将文本、文本池和图像拼接在一起
                clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
            else:
                # 如果没有图像,只返回文本池
                clip = clip_txt_pool
            # 对最终的CLIP嵌入进行归一化并返回
            return self.clip_norm(clip)
    # 定义一个私有方法 _down_encode,接受输入 x,r_embed 和 clip
        def _down_encode(self, x, r_embed, clip):
            # 初始化一个空列表,用于存储每个层的输出
            level_outputs = []
            # 将 down_blocks、down_downscalers 和 down_repeat_mappers 组合成一个可迭代的元组
            block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
    
            # 如果处于训练模式并启用梯度检查点
            if self.training and self.gradient_checkpointing:
    
                # 定义一个用于创建自定义前向传播的方法
                def create_custom_forward(module):
                    # 定义一个自定义前向传播,接受任意输入
                    def custom_forward(*inputs):
                        return module(*inputs)
    
                    return custom_forward
    
                # 遍历 block_group 中的每一组 down_block、downscaler 和 repmap
                for down_block, downscaler, repmap in block_group:
                    # 使用 downscaler 对输入 x 进行下采样
                    x = downscaler(x)
                    # 遍历 repmap 的长度加一
                    for i in range(len(repmap) + 1):
                        # 遍历 down_block 中的每个块
                        for block in down_block:
                            # 如果块是 SDCascadeResBlock 类型
                            if isinstance(block, SDCascadeResBlock):
                                # 使用梯度检查点进行前向传播
                                x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
                            # 如果块是 SDCascadeAttnBlock 类型
                            elif isinstance(block, SDCascadeAttnBlock):
                                # 使用梯度检查点进行前向传播,传入 clip
                                x = torch.utils.checkpoint.checkpoint(
                                    create_custom_forward(block), x, clip, use_reentrant=False
                                )
                            # 如果块是 SDCascadeTimestepBlock 类型
                            elif isinstance(block, SDCascadeTimestepBlock):
                                # 使用梯度检查点进行前向传播,传入 r_embed
                                x = torch.utils.checkpoint.checkpoint(
                                    create_custom_forward(block), x, r_embed, use_reentrant=False
                                )
                            # 其他块类型
                            else:
                                # 使用梯度检查点进行前向传播
                                x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), use_reentrant=False)
                        # 如果 i 小于 repmap 的长度
                        if i < len(repmap):
                            # 使用当前的 repmap 对 x 进行处理
                            x = repmap[i](x)
                    # 将当前层的输出插入到 level_outputs 的开头
                    level_outputs.insert(0, x)
            # 如果不是训练模式或未启用梯度检查点
            else:
                # 遍历 block_group 中的每一组 down_block、downscaler 和 repmap
                for down_block, downscaler, repmap in block_group:
                    # 使用 downscaler 对输入 x 进行下采样
                    x = downscaler(x)
                    # 遍历 repmap 的长度加一
                    for i in range(len(repmap) + 1):
                        # 遍历 down_block 中的每个块
                        for block in down_block:
                            # 如果块是 SDCascadeResBlock 类型
                            if isinstance(block, SDCascadeResBlock):
                                # 直接对 x 进行前向传播
                                x = block(x)
                            # 如果块是 SDCascadeAttnBlock 类型
                            elif isinstance(block, SDCascadeAttnBlock):
                                # 直接对 x 进行前向传播,传入 clip
                                x = block(x, clip)
                            # 如果块是 SDCascadeTimestepBlock 类型
                            elif isinstance(block, SDCascadeTimestepBlock):
                                # 直接对 x 进行前向传播,传入 r_embed
                                x = block(x, r_embed)
                            # 其他块类型
                            else:
                                # 直接对 x 进行前向传播
                                x = block(x)
                        # 如果 i 小于 repmap 的长度
                        if i < len(repmap):
                            # 使用当前的 repmap 对 x 进行处理
                            x = repmap[i](x)
                    # 将当前层的输出插入到 level_outputs 的开头
                    level_outputs.insert(0, x)
            # 返回所有层的输出
            return level_outputs
    
        # 定义前向传播方法,接受多个参数
        def forward(
            self,
            sample,
            timestep_ratio,
            clip_text_pooled,
            clip_text=None,
            clip_img=None,
            effnet=None,
            pixels=None,
            sca=None,
            crp=None,
            return_dict=True,
    ):
        # 如果 pixels 参数为 None,则初始化为一个全零的张量,尺寸为 (3, 8, 8)
        if pixels is None:
            pixels = sample.new_zeros(sample.size(0), 3, 8, 8)

        # 处理时间步比率嵌入
        timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio)
        # 遍历配置中的时间步条件类型
        for c in self.config.timestep_conditioning_type:
            # 如果条件类型是 "sca",则使用 sca 作为条件
            if c == "sca":
                cond = sca
            # 如果条件类型是 "crp",则使用 crp 作为条件
            elif c == "crp":
                cond = crp
            # 否则条件为 None
            else:
                cond = None
            # 如果 cond 为 None,则使用与 timestep_ratio 同形状的零张量
            t_cond = cond or torch.zeros_like(timestep_ratio)
            # 将时间步比率嵌入与条件嵌入进行拼接
            timestep_ratio_embed = torch.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], dim=1)
        # 获取 CLIP 嵌入
        clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img)

        # 模型块
        # 对样本进行嵌入
        x = self.embedding(sample)
        # 如果存在 effnet_mapper 且 effnet 不为 None,则进行映射
        if hasattr(self, "effnet_mapper") and effnet is not None:
            x = x + self.effnet_mapper(
                # 对 effnet 进行上采样,调整到与 x 相同的空间尺寸
                nn.functional.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
            )
        # 如果存在 pixels_mapper,则进行映射
        if hasattr(self, "pixels_mapper"):
            x = x + nn.functional.interpolate(
                # 对 pixels 进行映射并上采样,调整到与 x 相同的空间尺寸
                self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True
            )
        # 通过下采样编码器处理 x 和其他嵌入
        level_outputs = self._down_encode(x, timestep_ratio_embed, clip)
        # 通过上采样解码器处理 level_outputs
        x = self._up_decode(level_outputs, timestep_ratio_embed, clip)
        # 使用分类器生成最终样本
        sample = self.clf(x)

        # 如果不需要返回字典格式的结果,则返回单个样本元组
        if not return_dict:
            return (sample,)
        # 返回 StableCascadeUNetOutput 对象,包含样本
        return StableCascadeUNetOutput(sample=sample)

.\diffusers\models\unets\uvit_2d.py

# coding=utf-8  # 指定文件编码为 UTF-8
# Copyright 2024 The HuggingFace Inc. team.  # 版权声明,标识文件归 HuggingFace Inc. 团队所有
#
# Licensed under the Apache License, Version 2.0 (the "License");  # 表明文件受 Apache 2.0 许可证保护
# you may not use this file except in compliance with the License.  # 指明用户必须遵循许可证的使用条款
# You may obtain a copy of the License at  # 提供获取许可证的方式
#
#     http://www.apache.org/licenses/LICENSE-2.0  # 许可证的具体链接
#
# Unless required by applicable law or agreed to in writing, software  # 表明软件是按 "AS IS" 基础分发
# distributed under the License is distributed on an "AS IS" BASIS,  # 不提供任何明示或暗示的担保
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  # 用户须自行承担使用风险
# See the License for the specific language governing permissions and  # 引导用户查看许可证了解权限
# limitations under the License.  # 许可证的限制条款

from typing import Dict, Union  # 从 typing 模块导入字典和联合类型
import torch  # 导入 PyTorch 库
import torch.nn.functional as F  # 导入 PyTorch 的功能性神经网络模块
from torch import nn  # 从 PyTorch 导入神经网络模块
from torch.utils.checkpoint import checkpoint  # 从 PyTorch 导入检查点功能,用于节省内存

from ...configuration_utils import ConfigMixin, register_to_config  # 从配置工具导入混合类和注册配置函数
from ...loaders import PeftAdapterMixin  # 从加载器导入适配器混合类
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock  # 从注意力模块导入基本变换块和跳过前馈变换块
from ..attention_processor import (  # 从注意力处理器导入相关组件
    ADDED_KV_ATTENTION_PROCESSORS,  # 导入增加的键值注意力处理器
    CROSS_ATTENTION_PROCESSORS,  # 导入交叉注意力处理器
    AttentionProcessor,  # 导入注意力处理器类
    AttnAddedKVProcessor,  # 导入增加键值注意力处理器类
    AttnProcessor,  # 导入注意力处理器基类
)
from ..embeddings import TimestepEmbedding, get_timestep_embedding  # 从嵌入模块导入时间步嵌入及其获取函数
from ..modeling_utils import ModelMixin  # 从建模工具导入模型混合类
from ..normalization import GlobalResponseNorm, RMSNorm  # 从归一化模块导入全局响应归一化和 RMS 归一化
from ..resnet import Downsample2D, Upsample2D  # 从 ResNet 模块导入二维下采样和上采样

class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):  # 定义 UVit2DModel 类,继承多个混合类
    _supports_gradient_checkpointing = True  # 声明支持梯度检查点,节省内存

    @register_to_config  # 注册到配置的装饰器
    def __init__(  # 初始化方法
        self,  # 实例自身
        # global config  # 全局配置说明
        hidden_size: int = 1024,  # 隐藏层大小,默认为 1024
        use_bias: bool = False,  # 是否使用偏置,默认为 False
        hidden_dropout: float = 0.0,  # 隐藏层 dropout 概率,默认为 0
        # conditioning dimensions  # 条件维度说明
        cond_embed_dim: int = 768,  # 条件嵌入维度,默认为 768
        micro_cond_encode_dim: int = 256,  # 微条件编码维度,默认为 256
        micro_cond_embed_dim: int = 1280,  # 微条件嵌入维度,默认为 1280
        encoder_hidden_size: int = 768,  # 编码器隐藏层大小,默认为 768
        # num tokens  # 令牌数量说明
        vocab_size: int = 8256,  # 词汇表大小,默认为 8256(包括掩码令牌)
        codebook_size: int = 8192,  # 代码本大小,默认为 8192
        # `UVit2DConvEmbed`  # UVit2D 卷积嵌入说明
        in_channels: int = 768,  # 输入通道数,默认为 768
        block_out_channels: int = 768,  # 块输出通道数,默认为 768
        num_res_blocks: int = 3,  # 残差块数量,默认为 3
        downsample: bool = False,  # 是否进行下采样,默认为 False
        upsample: bool = False,  # 是否进行上采样,默认为 False
        block_num_heads: int = 12,  # 块头数,默认为 12
        # `TransformerLayer`  # 变换层说明
        num_hidden_layers: int = 22,  # 隐藏层数量,默认为 22
        num_attention_heads: int = 16,  # 注意力头数量,默认为 16
        # `Attention`  # 注意力说明
        attention_dropout: float = 0.0,  # 注意力层 dropout 概率,默认为 0
        # `FeedForward`  # 前馈层说明
        intermediate_size: int = 2816,  # 前馈层中间大小,默认为 2816
        # `Norm`  # 归一化说明
        layer_norm_eps: float = 1e-6,  # 层归一化的 epsilon 值,默认为 1e-6
        ln_elementwise_affine: bool = True,  # 是否使用元素级仿射,默认为 True
        sample_size: int = 64,  # 采样大小,默认为 64
    # 初始化父类
        ):
            super().__init__()
    
            # 创建一个线性层,用于编码器的输出投影
            self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
            # 创建 RMSNorm 层,对编码器输出进行层归一化
            self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
    
            # 初始化 UVit2DConvEmbed,进行输入通道到嵌入的转换
            self.embed = UVit2DConvEmbed(
                in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
            )
    
            # 创建时间步嵌入层,用于条件输入的嵌入
            self.cond_embed = TimestepEmbedding(
                micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
            )
    
            # 创建下采样块,包含多个残差块
            self.down_block = UVitBlock(
                block_out_channels,
                num_res_blocks,
                hidden_size,
                hidden_dropout,
                ln_elementwise_affine,
                layer_norm_eps,
                use_bias,
                block_num_heads,
                attention_dropout,
                downsample,
                False,
            )
    
            # 创建 RMSNorm 层,用于隐藏状态的归一化
            self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
            # 创建线性层,将投影结果转换为隐藏层大小
            self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)
    
            # 创建一个模块列表,包含多个基本的 Transformer 块
            self.transformer_layers = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=hidden_size,
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=hidden_size // num_attention_heads,
                        dropout=hidden_dropout,
                        cross_attention_dim=hidden_size,
                        attention_bias=use_bias,
                        norm_type="ada_norm_continuous",
                        ada_norm_continous_conditioning_embedding_dim=hidden_size,
                        norm_elementwise_affine=ln_elementwise_affine,
                        norm_eps=layer_norm_eps,
                        ada_norm_bias=use_bias,
                        ff_inner_dim=intermediate_size,
                        ff_bias=use_bias,
                        attention_out_bias=use_bias,
                    )
                    for _ in range(num_hidden_layers)  # 遍历生成指定数量的 Transformer 块
                ]
            )
    
            # 创建 RMSNorm 层,用于隐藏状态的归一化
            self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
            # 创建线性层,将隐藏层转换为块输出通道
            self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)
    
            # 创建上采样块,包含多个残差块
            self.up_block = UVitBlock(
                block_out_channels,
                num_res_blocks,
                hidden_size,
                hidden_dropout,
                ln_elementwise_affine,
                layer_norm_eps,
                use_bias,
                block_num_heads,
                attention_dropout,
                downsample=False,
                upsample=upsample,
            )
    
            # 创建卷积 MLM 层,用于生成模型的最终输出
            self.mlm_layer = ConvMlmLayer(
                block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
            )
    
            # 初始化梯度检查点标志为 False
            self.gradient_checkpointing = False
    
        # 定义梯度检查点设置函数,默认不启用
        def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
            pass
    # 定义前向传播方法,接受输入 IDs、编码器隐藏状态、池化文本嵌入、微条件及交叉注意力参数
    def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
        # 对编码器隐藏状态进行线性变换
        encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
        # 对编码器隐藏状态进行层归一化
        encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
    
        # 获取微条件的时间步嵌入,应用特定的配置参数
        micro_cond_embeds = get_timestep_embedding(
            micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
    
        # 调整微条件嵌入的形状,匹配输入 ID 的批大小
        micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))
    
        # 将池化文本嵌入和微条件嵌入在维度1上连接
        pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
        # 将池化文本嵌入转换为指定的数据类型
        pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
        # 对池化文本嵌入进行条件嵌入并转换为编码器隐藏状态的数据类型
        pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)
    
        # 获取输入 ID 的嵌入表示
        hidden_states = self.embed(input_ids)
    
        # 将隐藏状态通过下一个模块处理
        hidden_states = self.down_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )
    
        # 获取隐藏状态的批大小、通道、高度和宽度
        batch_size, channels, height, width = hidden_states.shape
        # 调整隐藏状态的维度顺序并重塑形状
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
    
        # 对隐藏状态进行规范化投影
        hidden_states = self.project_to_hidden_norm(hidden_states)
        # 对隐藏状态进行线性投影
        hidden_states = self.project_to_hidden(hidden_states)
    
        # 遍历每个变换层
        for layer in self.transformer_layers:
            # 如果在训练模式下并启用梯度检查点
            if self.training and self.gradient_checkpointing:
                # 定义一个带检查点的层
                def layer_(*args):
                    return checkpoint(layer, *args)
            else:
                # 否则直接使用层
                layer_ = layer
    
            # 通过当前层处理隐藏状态
            hidden_states = layer_(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
            )
    
        # 对隐藏状态进行规范化投影
        hidden_states = self.project_from_hidden_norm(hidden_states)
        # 对隐藏状态进行线性投影
        hidden_states = self.project_from_hidden(hidden_states)
    
        # 重塑隐藏状态以匹配图像维度并调整维度顺序
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
    
        # 将隐藏状态通过上一个模块处理
        hidden_states = self.up_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )
    
        # 通过 MLM 层获取最终的 logits
        logits = self.mlm_layer(hidden_states)
    
        # 返回最终的 logits
        return logits
    
    # 定义一个只读属性,可能用于后续处理
    @property
    # 定义返回注意力处理器的字典的方法
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        返回值:
            `dict` 的注意力处理器: 一个包含模型中所有注意力处理器的字典,以权重名称为索引。
        """
        # 创建一个空字典,用于存储注意力处理器
        processors = {}
    
        # 定义递归函数,用于添加处理器
        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # 如果模块有获取处理器的方法,则将其添加到字典中
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
    
            # 遍历模块的子模块,递归调用
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    
            return processors
    
        # 遍历当前对象的子模块,调用递归函数
        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
    
        # 返回所有收集到的处理器
        return processors
    
    # 从 UNet2DConditionModel 复制的方法,用于设置注意力处理器
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        设置用于计算注意力的处理器。
    
        参数:
            processor (`dict` of `AttentionProcessor` 或 `AttentionProcessor`):
                实例化的处理器类或将被设置为**所有** `Attention` 层的处理器类的字典。
    
                如果 `processor` 是字典,则键需要定义相应的交叉注意力处理器的路径。强烈建议在设置可训练的注意力处理器时使用。
        """
        # 获取当前注意力处理器的数量
        count = len(self.attn_processors.keys())
    
        # 如果传入的是字典且数量不匹配,抛出异常
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"传入了处理器字典,但处理器数量 {len(processor)} 与注意力层数量: {count} 不匹配。请确保传入 {count} 个处理器类。"
            )
    
        # 定义递归函数,用于设置处理器
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            # 如果模块有设置处理器的方法,则设置处理器
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
    
            # 遍历子模块,递归调用设置处理器
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
    
        # 遍历当前对象的子模块,调用递归设置函数
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    
    # 从 UNet2DConditionModel 复制的方法,用于设置默认注意力处理器
    # 定义一个设置默认注意力处理器的方法
    def set_default_attn_processor(self):
        # 文档字符串,说明该方法的功能
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        # 检查所有注意力处理器是否属于新增的 KV 注意力处理器类型
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,则创建一个新增 KV 注意力处理器实例
            processor = AttnAddedKVProcessor()
        # 检查所有注意力处理器是否属于交叉注意力处理器类型
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            # 如果是,则创建一个标准注意力处理器实例
            processor = AttnProcessor()
        else:
            # 如果既不是新增 KV 注意力处理器也不是交叉注意力处理器,则抛出错误
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )
    
        # 设置当前实例的注意力处理器为刚创建的处理器
        self.set_attn_processor(processor)
# 定义一个二维卷积嵌入类,继承自 nn.Module
class UVit2DConvEmbed(nn.Module):
    # 初始化方法,接收输入通道数、块输出通道数、词汇表大小等参数
    def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
        # 调用父类构造函数
        super().__init__()
        # 创建嵌入层,将词汇表大小映射到输入通道数
        self.embeddings = nn.Embedding(vocab_size, in_channels)
        # 创建 RMSNorm 层,用于归一化嵌入,支持可选的元素级仿射变换
        self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
        # 创建 2D 卷积层,使用指定的输出通道数和偏置选项
        self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)

    # 前向传播方法,定义输入如何经过该层处理
    def forward(self, input_ids):
        # 根据输入 ID 获取对应的嵌入
        embeddings = self.embeddings(input_ids)
        # 对嵌入进行层归一化处理
        embeddings = self.layer_norm(embeddings)
        # 调整嵌入的维度顺序,以适应卷积层的输入格式
        embeddings = embeddings.permute(0, 3, 1, 2)
        # 通过卷积层处理嵌入
        embeddings = self.conv(embeddings)
        # 返回处理后的嵌入
        return embeddings


# 定义一个 UVit 块类,继承自 nn.Module
class UVitBlock(nn.Module):
    # 初始化方法,接收多个配置参数
    def __init__(
        self,
        channels,
        num_res_blocks: int,
        hidden_size,
        hidden_dropout,
        ln_elementwise_affine,
        layer_norm_eps,
        use_bias,
        block_num_heads,
        attention_dropout,
        downsample: bool,
        upsample: bool,
    ):
        # 调用父类构造函数
        super().__init__()

        # 如果需要下采样,初始化下采样层
        if downsample:
            self.downsample = Downsample2D(
                channels,
                use_conv=True,
                padding=0,
                name="Conv2d_0",
                kernel_size=2,
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
            )
        else:
            # 否则将下采样层设为 None
            self.downsample = None

        # 创建残差块列表,包含指定数量的卷积块
        self.res_blocks = nn.ModuleList(
            [
                ConvNextBlock(
                    channels,
                    layer_norm_eps,
                    ln_elementwise_affine,
                    use_bias,
                    hidden_dropout,
                    hidden_size,
                )
                for i in range(num_res_blocks)
            ]
        )

        # 创建注意力块列表,包含指定数量的跳跃前馈变换块
        self.attention_blocks = nn.ModuleList(
            [
                SkipFFTransformerBlock(
                    channels,
                    block_num_heads,
                    channels // block_num_heads,
                    hidden_size,
                    use_bias,
                    attention_dropout,
                    channels,
                    attention_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_res_blocks)
            ]
        )

        # 如果需要上采样,初始化上采样层
        if upsample:
            self.upsample = Upsample2D(
                channels,
                use_conv_transpose=True,
                kernel_size=2,
                padding=0,
                name="conv",
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
                interpolate=False,
            )
        else:
            # 否则将上采样层设为 None
            self.upsample = None
    # 定义前向传播函数,接收输入 x、池化文本嵌入、编码器隐藏状态和交叉注意力参数
    def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
        # 如果存在下采样层,则对输入进行下采样
        if self.downsample is not None:
            x = self.downsample(x)
    
        # 遍历残差块和注意力块的组合
        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            # 将输入通过残差块进行处理
            x = res_block(x, pooled_text_emb)
    
            # 获取当前输出的批量大小、通道数、高度和宽度
            batch_size, channels, height, width = x.shape
            # 将输出形状调整为 (批量大小, 通道数, 高度 * 宽度),然后转置
            x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
            # 将处理后的输入通过注意力块,并传递编码器隐藏状态和交叉注意力参数
            x = attention_block(
                x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
            )
            # 将输出转置并恢复为 (批量大小, 通道数, 高度, 宽度) 的形状
            x = x.permute(0, 2, 1).view(batch_size, channels, height, width)
    
        # 如果存在上采样层,则对输出进行上采样
        if self.upsample is not None:
            x = self.upsample(x)
    
        # 返回最终的输出
        return x
# 定义一个卷积块的类,继承自 nn.Module
class ConvNextBlock(nn.Module):
    # 初始化方法,接受多个参数以配置卷积块
    def __init__(
        self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
    ):
        # 调用父类初始化方法
        super().__init__()
        # 定义深度可分离卷积层,通道数为 channels,卷积核大小为 3,使用 padding 保持输入输出相同大小
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,  # 进行深度卷积
            bias=use_bias,  # 是否使用偏置
        )
        # 定义 RMSNorm 层,用于规范化,接受通道数和层归一化的 epsilon
        self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
        # 定义第一个线性层,将通道数映射到一个更大的维度
        self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
        # 定义激活函数,使用 GELU
        self.channelwise_act = nn.GELU()
        # 定义全局响应规范化层,输入维度为扩展后的通道数
        self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
        # 定义第二个线性层,将大维度映射回原始通道数
        self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
        # 定义 dropout 层,应用于隐藏层,使用给定的丢弃率
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        # 定义条件嵌入映射层,将隐层大小映射到两倍的通道数
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)

    # 前向传播方法,定义数据如何流经网络
    def forward(self, x, cond_embeds):
        # 保存输入以用于残差连接
        x_res = x

        # 通过深度卷积层处理输入
        x = self.depthwise(x)

        # 调整张量维度,将通道维移至最后
        x = x.permute(0, 2, 3, 1)
        # 对张量进行归一化处理
        x = self.norm(x)

        # 通过第一个线性层
        x = self.channelwise_linear_1(x)
        # 应用激活函数
        x = self.channelwise_act(x)
        # 进行规范化处理
        x = self.channelwise_norm(x)
        # 通过第二个线性层映射回通道数
        x = self.channelwise_linear_2(x)
        # 应用 dropout
        x = self.channelwise_dropout(x)

        # 再次调整张量维度,恢复通道维的位置
        x = x.permute(0, 3, 1, 2)

        # 添加残差连接,将输入与输出相加
        x = x + x_res

        # 通过条件嵌入映射生成缩放和偏移值,使用 SiLU 激活
        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        # 应用缩放和偏移调整输出
        x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        # 返回处理后的输出
        return x


# 定义一个卷积 MLM 层的类,继承自 nn.Module
class ConvMlmLayer(nn.Module):
    # 初始化方法,接受多个参数以配置卷积 MLM 层
    def __init__(
        self,
        block_out_channels: int,
        in_channels: int,
        use_bias: bool,
        ln_elementwise_affine: bool,
        layer_norm_eps: float,
        codebook_size: int,
    ):
        # 调用父类初始化方法
        super().__init__()
        # 定义第一个卷积层,将 block_out_channels 映射到 in_channels
        self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
        # 定义 RMSNorm 层,用于规范化,接受输入通道数和层归一化的 epsilon
        self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
        # 定义第二个卷积层,将 in_channels 映射到 codebook_size
        self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)

    # 前向传播方法,定义数据如何流经网络
    def forward(self, hidden_states):
        # 通过第一个卷积层处理隐藏状态
        hidden_states = self.conv1(hidden_states)
        # 对输出进行规范化处理,调整维度顺序
        hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # 通过第二个卷积层生成 logits
        logits = self.conv2(hidden_states)
        # 返回 logits
        return logits

.\diffusers\models\unets\__init__.py

# 从工具模块中导入检测 PyTorch 和 Flax 是否可用的函数
from ...utils import is_flax_available, is_torch_available

# 如果 PyTorch 可用,则导入相应的 UNet 模型
if is_torch_available():
    # 导入一维 UNet 模型
    from .unet_1d import UNet1DModel
    # 导入二维 UNet 模型
    from .unet_2d import UNet2DModel
    # 导入条件二维 UNet 模型
    from .unet_2d_condition import UNet2DConditionModel
    # 导入条件三维 UNet 模型
    from .unet_3d_condition import UNet3DConditionModel
    # 导入 I2VGenXL UNet 模型
    from .unet_i2vgen_xl import I2VGenXLUNet
    # 导入 Kandinsky3 UNet 模型
    from .unet_kandinsky3 import Kandinsky3UNet
    # 导入运动模型适配器和 UNet 运动模型
    from .unet_motion_model import MotionAdapter, UNetMotionModel
    # 导入时空条件的 UNet 模型
    from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
    # 导入稳定级联 UNet 模型
    from .unet_stable_cascade import StableCascadeUNet
    # 导入二维 UVit 模型
    from .uvit_2d import UVit2DModel

# 如果 Flax 可用,则导入相应的条件二维 UNet 模型
if is_flax_available():
    # 导入条件二维 Flax UNet 模型
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel

.\diffusers\models\upsampling.py

# 版权所有 2024 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)授权;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下网址获取许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,否则根据许可证分发的软件是以“原样”基础提供,
# 不提供任何形式的担保或条件,无论是明示或暗示的。
# 有关许可证的特定语言、治理权限和
# 许可证限制,请参阅许可证。

# 从 typing 模块导入 Optional 和 Tuple,用于类型提示
from typing import Optional, Tuple

# 导入 PyTorch 库及其神经网络模块
import torch
import torch.nn as nn
import torch.nn.functional as F

# 从 utils 模块导入 deprecate 装饰器
from ..utils import deprecate
# 从 normalization 模块导入 RMSNorm 类
from .normalization import RMSNorm

# 定义一维上采样层类,继承自 nn.Module
class Upsample1D(nn.Module):
    """一维上采样层,带可选的卷积。

    参数:
        channels (`int`):
            输入和输出的通道数。
        use_conv (`bool`, default `False`):
            是否使用卷积的选项。
        use_conv_transpose (`bool`, default `False`):
            是否使用转置卷积的选项。
        out_channels (`int`, optional):
            输出通道的数量。默认为 `channels`。
        name (`str`, default `conv`):
            一维上采样层的名称。
    """

    # 初始化函数,定义层的参数和卷积层
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 设置输入通道数
        self.channels = channels
        # 设置输出通道数,默认为输入通道数
        self.out_channels = out_channels or channels
        # 设置是否使用卷积
        self.use_conv = use_conv
        # 设置是否使用转置卷积
        self.use_conv_transpose = use_conv_transpose
        # 设置层的名称
        self.name = name

        # 初始化卷积层为 None
        self.conv = None
        # 如果选择使用转置卷积,则初始化相应的卷积层
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        # 否则,如果选择使用卷积,则初始化卷积层
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    # 定义前向传播函数,接收输入并返回输出
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # 断言输入的通道数与设置的通道数一致
        assert inputs.shape[1] == self.channels
        # 如果选择使用转置卷积,直接返回卷积结果
        if self.use_conv_transpose:
            return self.conv(inputs)

        # 否则,使用最近邻插值法对输入进行上采样
        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        # 如果选择使用卷积,则对上采样结果进行卷积
        if self.use_conv:
            outputs = self.conv(outputs)

        # 返回最终的输出结果
        return outputs

# 定义二维上采样层类,继承自 nn.Module
class Upsample2D(nn.Module):
    """二维上采样层,带可选的卷积。

    参数:
        channels (`int`):
            输入和输出的通道数。
        use_conv (`bool`, default `False`):
            是否使用卷积的选项。
        use_conv_transpose (`bool`, default `False`):
            是否使用转置卷积的选项。
        out_channels (`int`, optional):
            输出通道的数量。默认为 `channels`。
        name (`str`, default `conv`):
            二维上采样层的名称。
    """
    # 初始化方法,用于创建该类的实例
    def __init__(
        # 输入通道数
        self,
        channels: int,
        # 是否使用卷积
        use_conv: bool = False,
        # 是否使用转置卷积
        use_conv_transpose: bool = False,
        # 输出通道数,可选,默认为输入通道数
        out_channels: Optional[int] = None,
        # 模块名称,默认为 "conv"
        name: str = "conv",
        # 卷积核大小,可选
        kernel_size: Optional[int] = None,
        # 填充大小,默认为 1
        padding=1,
        # 归一化类型
        norm_type=None,
        # 归一化时的 epsilon
        eps=None,
        # 是否使用逐元素仿射
        elementwise_affine=None,
        # 是否使用偏置,默认为 True
        bias=True,
        # 是否进行插值,默认为 True
        interpolate=True,
    ):
        # 调用父类的初始化方法
        super().__init__()
        # 保存输入通道数
        self.channels = channels
        # 保存输出通道数,默认为输入通道数
        self.out_channels = out_channels or channels
        # 保存是否使用卷积的标志
        self.use_conv = use_conv
        # 保存是否使用转置卷积的标志
        self.use_conv_transpose = use_conv_transpose
        # 保存模块名称
        self.name = name
        # 保存是否进行插值的标志
        self.interpolate = interpolate

        # 根据归一化类型初始化归一化层
        if norm_type == "ln_norm":
            # 使用层归一化
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            # 使用 RMS 归一化
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            # 不使用归一化
            self.norm = None
        else:
            # 抛出未知归一化类型的错误
            raise ValueError(f"unknown norm_type: {norm_type}")

        # 初始化卷积层
        conv = None
        if use_conv_transpose:
            # 如果使用转置卷积且未指定卷积核大小,则默认为 4
            if kernel_size is None:
                kernel_size = 4
            # 创建转置卷积层
            conv = nn.ConvTranspose2d(
                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
            )
        elif use_conv:
            # 如果使用卷积且未指定卷积核大小,则默认为 3
            if kernel_size is None:
                kernel_size = 3
            # 创建卷积层
            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

        # TODO(Suraj, Patrick) - 在权重字典正确重命名后进行清理
        if name == "conv":
            # 如果名称为 "conv",则保存卷积层
            self.conv = conv
        else:
            # 否则将卷积层保存为另一个属性
            self.Conv2d_0 = conv
    # 前向传播函数,接受隐藏状态和可选输出大小,返回张量
    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> torch.Tensor:
        # 检查是否传入多余参数或废弃的 scale 参数
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            # 设置废弃提示信息,告知用户 scale 参数将来会引发错误
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            # 调用 deprecate 函数,记录废弃信息
            deprecate("scale", "1.0.0", deprecation_message)
    
        # 确保隐藏状态的通道数与当前对象的通道数匹配
        assert hidden_states.shape[1] == self.channels
    
        # 如果存在归一化层,则对隐藏状态进行归一化
        if self.norm is not None:
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    
        # 如果使用转置卷积,则调用卷积层处理隐藏状态
        if self.use_conv_transpose:
            return self.conv(hidden_states)
    
        # 将数据类型转换为 float32,解决 bfloat16 在特定操作中不支持的问题
        # TODO(Suraj): 一旦问题修复,移除此转换
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        # 检查数据类型是否为 bfloat16
        if dtype == torch.bfloat16:
            # 转换为 float32
            hidden_states = hidden_states.to(torch.float32)
    
        # 对于大批量大小的情况,确保数据是连续的
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()
    
        # 如果传入了 output_size,则强制进行插值输出
        # 并不使用 scale_factor=2
        if self.interpolate:
            # 如果没有传入 output_size,则使用默认插值因子
            if output_size is None:
                hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
            else:
                # 使用传入的 output_size 进行插值
                hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
    
        # 如果输入为 bfloat16,转换回 bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)
    
        # TODO(Suraj, Patrick) - 在权重字典正确重命名后进行清理
        if self.use_conv:
            # 如果使用卷积,判断卷积层名称
            if self.name == "conv":
                hidden_states = self.conv(hidden_states)  # 调用常规卷积
            else:
                hidden_states = self.Conv2d_0(hidden_states)  # 调用特定卷积层
    
        # 返回处理后的隐藏状态
        return hidden_states
# 定义一个二维 FIR 上采样层,包含可选的卷积操作
class FirUpsample2D(nn.Module):
    """A 2D FIR upsampling layer with an optional convolution.

    Parameters:
        channels (`int`, optional):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    # 初始化 FIR 上采样层
    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        # 调用父类构造函数
        super().__init__()
        # 确定输出通道数,如果未提供则使用输入通道数
        out_channels = out_channels if out_channels else channels
        # 如果选择使用卷积,初始化卷积层
        if use_conv:
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        # 保存卷积使用状态、FIR 核心和输出通道数
        self.use_conv = use_conv
        self.fir_kernel = fir_kernel
        self.out_channels = out_channels

    # 定义一个用于 2D 上采样的私有方法
    def _upsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.Tensor] = None,
        factor: int = 2,
        gain: float = 1,
    ):
        # 定义前向传播方法
        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # 如果使用卷积,执行卷积操作并加上偏置
            if self.use_conv:
                height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
                height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
            # 否则仅执行 FIR 上采样
            else:
                height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

            # 返回上采样后的高度
            return height


# 定义一个二维 K 上采样层
class KUpsample2D(nn.Module):
    r"""A 2D K-upsampling layer.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    # 初始化 K 上采样层,设置填充模式
    def __init__(self, pad_mode: str = "reflect"):
        # 调用父类构造函数
        super().__init__()
        # 保存填充模式
        self.pad_mode = pad_mode
        # 创建一维卷积核,进行标准化
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
        # 计算填充大小
        self.pad = kernel_1d.shape[1] // 2 - 1
        # 注册卷积核缓冲区
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    # 定义前向传播方法
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # 根据填充模式填充输入
        inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
        # 初始化权重张量
        weight = inputs.new_zeros(
            [
                inputs.shape[1],
                inputs.shape[1],
                self.kernel.shape[0],
                self.kernel.shape[1],
            ]
        )
        # 创建索引张量
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        # 扩展卷积核
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        # 设置权重的对应索引
        weight[indices, indices] = kernel
        # 返回经过反卷积后的结果
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


# 定义一个三维上采样层
class CogVideoXUpsample3D(nn.Module):
    r"""
    A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
    # 参数说明
    Args:
        in_channels (`int`):
            # 输入图像的通道数
            Number of channels in the input image.
        out_channels (`int`):
            # 卷积操作产生的通道数
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            # 卷积核的大小,默认值为3
            Size of the convolving kernel.
        stride (`int`, defaults to `1`):
            # 卷积的步幅,默认值为1
            Stride of the convolution.
        padding (`int`, defaults to `1`):
            # 输入数据四周填充的大小,默认值为1
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            # 是否压缩时间维度,默认值为False
            Whether or not to compress the time dimension.
    """

    # 初始化方法
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        compress_time: bool = False,
    ) -> None:
        # 调用父类的初始化方法
        super().__init__()

        # 定义卷积层
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        # 保存压缩时间的标志
        self.compress_time = compress_time

    # 前向传播方法
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # 如果需要压缩时间维度
        if self.compress_time:
            # 检查时间维度是否大于1且为奇数
            if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
                # 分离第一个帧
                x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]

                # 对第一个帧进行插值放大
                x_first = F.interpolate(x_first, scale_factor=2.0)
                # 对其余帧进行插值放大
                x_rest = F.interpolate(x_rest, scale_factor=2.0)
                # 增加一个维度
                x_first = x_first[:, :, None, :, :]
                # 合并第一个帧和其余帧
                inputs = torch.cat([x_first, x_rest], dim=2)
            # 如果时间维度大于1
            elif inputs.shape[2] > 1:
                # 对输入进行插值放大
                inputs = F.interpolate(inputs, scale_factor=2.0)
            else:
                # 如果时间维度等于1,进行处理
                inputs = inputs.squeeze(2)
                # 对输入进行插值放大
                inputs = F.interpolate(inputs, scale_factor=2.0)
                # 增加一个维度
                inputs = inputs[:, :, None, :, :]
        else:
            # 仅对2D进行插值处理
            b, c, t, h, w = inputs.shape
            # 重新排列维度,准备卷积操作
            inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
            # 对输入进行插值放大
            inputs = F.interpolate(inputs, scale_factor=2.0)
            # 还原维度顺序
            inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)

        # 再次获取当前形状
        b, c, t, h, w = inputs.shape
        # 重新排列维度,为卷积操作做准备
        inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        # 通过卷积层处理输入
        inputs = self.conv(inputs)
        # 还原维度顺序
        inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)

        # 返回处理后的输入
        return inputs
# 定义一个用于二维上采样的函数,支持可选的上采样核和因子
def upfirdn2d_native(
    # 输入的张量,通常是图像数据
    tensor: torch.Tensor,
    # 卷积核,决定上采样的滤波效果
    kernel: torch.Tensor,
    # 上采样因子,默认值为1
    up: int = 1,
    # 下采样因子,默认值为1
    down: int = 1,
    # 填充大小,默认不填充
    pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
    # 将上采样因子赋值给x和y方向
    up_x = up_y = up
    # 将下采样因子赋值给x和y方向
    down_x = down_y = down
    # 获取y方向的填充大小
    pad_x0 = pad_y0 = pad[0]
    # 获取y方向的填充大小
    pad_x1 = pad_y1 = pad[1]

    # 获取输入张量的形状信息,包括通道数和高宽
    _, channel, in_h, in_w = tensor.shape
    # 将张量重塑为适合卷积操作的形状
    tensor = tensor.reshape(-1, in_h, in_w, 1)

    # 获取重塑后的张量的形状信息
    _, in_h, in_w, minor = tensor.shape
    # 获取卷积核的高和宽
    kernel_h, kernel_w = kernel.shape

    # 将张量视图转换为适合处理的格式
    out = tensor.view(-1, in_h, 1, in_w, 1, minor)
    # 对张量进行填充以便进行上采样
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    # 重塑填充后的张量为新的形状
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # 应用额外的填充,以便在后续处理中保持一致
    out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    # 将张量移动回原始设备(如果需要)
    out = out.to(tensor.device)  # Move back to mps if necessary
    # 应用负填充以调整输出的边界
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    # 重新排列张量的维度,以便于卷积操作
    out = out.permute(0, 3, 1, 2)
    # 重塑输出张量以匹配卷积的输入格式
    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # 翻转卷积核以应用于卷积操作
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    # 执行卷积操作
    out = F.conv2d(out, w)
    # 重塑输出张量的形状以匹配所需的输出格式
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    # 重新排列输出张量的维度
    out = out.permute(0, 2, 3, 1)
    # 根据下采样因子对输出张量进行下采样
    out = out[:, ::down_y, ::down_x, :]

    # 计算输出张量的高和宽
    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    # 返回最终输出张量,调整为指定形状
    return out.view(-1, channel, out_h, out_w)


# 定义用于二维上采样的辅助函数,支持自定义卷积核和因子
def upsample_2d(
    # 输入的张量,通常是图像数据
    hidden_states: torch.Tensor,
    # 可选的卷积核,用于滤波
    kernel: Optional[torch.Tensor] = None,
    # 上采样因子,默认值为2
    factor: int = 2,
    # 信号幅度的缩放因子,默认值为1
    gain: float = 1,
) -> torch.Tensor:
    r"""Upsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
    a: multiple of the upsampling factor.

    Args:
        hidden_states (`torch.Tensor`):
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to nearest-neighbor upsampling.
        factor (`int`, *optional*, default to `2`):
            Integer upsampling factor.
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude (default: 1.0).

    Returns:
        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    # 确保因子是一个正整数
    assert isinstance(factor, int) and factor >= 1
    # 如果没有提供卷积核,则使用默认的最近邻上采样核
    if kernel is None:
        kernel = [1] * factor
    # 将输入的 kernel 转换为张量,并指定数据类型为 float32
        kernel = torch.tensor(kernel, dtype=torch.float32)
        # 如果 kernel 是一维的,则计算其外积,生成二维卷积核
        if kernel.ndim == 1:
            kernel = torch.outer(kernel, kernel)
        # 将 kernel 归一化,使其所有元素之和为 1
        kernel /= torch.sum(kernel)
    
        # 根据增益和因子调整 kernel 的值
        kernel = kernel * (gain * (factor**2))
        # 计算 padding 的值,用于图像处理
        pad_value = kernel.shape[0] - factor
        # 使用 upfirdn2d_native 函数进行上采样和滤波处理
        output = upfirdn2d_native(
            hidden_states,  # 输入的隐藏状态
            kernel.to(device=hidden_states.device),  # 将 kernel 移动到隐藏状态的设备上
            up=factor,  # 上采样因子
            # 设置 padding,确保输出的尺寸正确
            pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
        )
        # 返回处理后的输出
        return output

.\diffusers\models\vae_flax.py

# 版权所有 2024 The HuggingFace Team. 保留所有权利。
#
# 根据 Apache License, Version 2.0 (以下称为“许可证”) 进行授权;
# 除非遵守该许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证副本:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有规定,软件
# 根据该许可证分发是按“原样”基础进行的,
# 不提供任何形式的明示或暗示的担保或条件。
# 有关许可证规定的权限和限制,请参见许可证。

# JAX 实现 VQGAN,来源于 taming-transformers https://github.com/CompVis/taming-transformers

import math  # 导入数学库,提供数学函数
from functools import partial  # 从 functools 模块导入 partial 函数,用于部分应用函数
from typing import Tuple  # 从 typing 模块导入 Tuple,用于类型注解

import flax  # 导入 flax 库,支持神经网络构建
import flax.linen as nn  # 从 flax 导入 linen 模块,简化神经网络层的创建
import jax  # 导入 jax 库,支持高效数值计算
import jax.numpy as jnp  # 从 jax 导入 numpy 模块,提供类似于 NumPy 的数组操作
from flax.core.frozen_dict import FrozenDict  # 导入 FrozenDict,提供不可变字典的实现

from ..configuration_utils import ConfigMixin, flax_register_to_config  # 导入配置相关的混合类和注册函数
from ..utils import BaseOutput  # 从 utils 模块导入 BaseOutput 基类
from .modeling_flax_utils import FlaxModelMixin  # 从 modeling_flax_utils 导入 FlaxModelMixin


@flax.struct.dataclass  # 使用 flax 的数据类装饰器,自动生成类的初始化和其他方法
class FlaxDecoderOutput(BaseOutput):  # 定义解码器输出类,继承自 BaseOutput
    """
    解码方法的输出。

    参数:
        sample (`jnp.ndarray` 的形状为 `(batch_size, num_channels, height, width)`):
            模型最后一层的解码输出样本。
        dtype (`jnp.dtype`, *可选*, 默认为 `jnp.float32`):
            参数的 `dtype`。
    """

    sample: jnp.ndarray  # 定义解码样本为 jnp.ndarray 类型


@flax.struct.dataclass  # 使用 flax 的数据类装饰器
class FlaxAutoencoderKLOutput(BaseOutput):  # 定义自动编码器 KL 输出类,继承自 BaseOutput
    """
    自动编码器 KL 编码方法的输出。

    参数:
        latent_dist (`FlaxDiagonalGaussianDistribution`):
            编码器的输出表示为 FlaxDiagonalGaussianDistribution 的均值和对数方差。
            `FlaxDiagonalGaussianDistribution` 允许从分布中采样潜在变量。
    """

    latent_dist: "FlaxDiagonalGaussianDistribution"  # 定义潜在分布类型为 FlaxDiagonalGaussianDistribution


class FlaxUpsample2D(nn.Module):  # 定义 2D 上采样层的 Flax 实现,继承自 nn.Module
    """
    Flax 实现的 2D 上采样层

    参数:
        in_channels (`int`):
            输入通道数
        dtype (:obj:`jnp.dtype`, *可选*, 默认为 jnp.float32):
            参数的 `dtype`
    """

    in_channels: int  # 定义输入通道数为整型
    dtype: jnp.dtype = jnp.float32  # 定义参数类型,默认为 jnp.float32

    def setup(self):  # 定义设置方法,在模块初始化时调用
        self.conv = nn.Conv(  # 创建卷积层
            self.in_channels,  # 设置输入通道数
            kernel_size=(3, 3),  # 设置卷积核大小为 3x3
            strides=(1, 1),  # 设置卷积步幅为 1
            padding=((1, 1), (1, 1)),  # 设置填充方式
            dtype=self.dtype,  # 设置卷积层参数的类型
        )

    def __call__(self, hidden_states):  # 定义模块的前向传播方法
        batch, height, width, channels = hidden_states.shape  # 解包输入形状为批量大小、高度、宽度和通道数
        hidden_states = jax.image.resize(  # 对输入进行上采样
            hidden_states,  # 输入的隐藏状态
            shape=(batch, height * 2, width * 2, channels),  # 设置输出形状为原来的两倍
            method="nearest",  # 使用最近邻插值法进行上采样
        )
        hidden_states = self.conv(hidden_states)  # 通过卷积层处理上采样后的状态
        return hidden_states  # 返回处理后的隐藏状态


class FlaxDownsample2D(nn.Module):  # 定义 2D 下采样层的 Flax 实现,继承自 nn.Module
    """
    Flax 实现的 2D 下采样层
    # 参数说明文档
        Args:
            in_channels (`int`):
                输入通道数
            dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
                参数数据类型
    
        # 声明输入通道数为整数类型
        in_channels: int
        # 声明数据类型,默认值为 jnp.float32
        dtype: jnp.dtype = jnp.float32
    
        # 设置方法,用于初始化卷积层
        def setup(self):
            # 创建卷积层,指定输入通道、卷积核大小、步幅和填充方式
            self.conv = nn.Conv(
                self.in_channels,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding="VALID",
                dtype=self.dtype,
            )
    
        # 调用方法,接收隐藏状态作为输入
        def __call__(self, hidden_states):
            # 定义填充的尺寸,增加高度和宽度维度的边界
            pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
            # 对输入的隐藏状态进行填充
            hidden_states = jnp.pad(hidden_states, pad_width=pad)
            # 将填充后的隐藏状态输入卷积层进行处理
            hidden_states = self.conv(hidden_states)
            # 返回处理后的隐藏状态
            return hidden_states
# 定义 Flax 实现的 2D Resnet Block 类,继承自 nn.Module
class FlaxResnetBlock2D(nn.Module):
    """
    Flax 实现的 2D Resnet Block。

    参数:
        in_channels (`int`):
            输入通道数
        out_channels (`int`):
            输出通道数
        dropout (:obj:`float`, *可选*, 默认为 0.0):
            Dropout 率
        groups (:obj:`int`, *可选*, 默认为 `32`):
            用于分组归一化的组数。
        use_nin_shortcut (:obj:`bool`, *可选*, 默认为 `None`):
            是否使用 `nin_shortcut`。这会在 ResNet 块内部激活一个新层
        dtype (:obj:`jnp.dtype`, *可选*, 默认为 jnp.float32):
            参数数据类型
    """

    # 定义输入通道数,输出通道数,dropout 率,分组数,是否使用 nin_shortcut 和数据类型
    in_channels: int
    out_channels: int = None
    dropout: float = 0.0
    groups: int = 32
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    # 设置方法,用于初始化各层
    def setup(self):
        # 如果未指定输出通道数,则设置为输入通道数
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        # 初始化第一层分组归一化
        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        # 初始化第一层卷积
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # 初始化第二层分组归一化
        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        # 初始化 dropout 层
        self.dropout_layer = nn.Dropout(self.dropout)
        # 初始化第二层卷积
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # 根据输入和输出通道数判断是否使用 nin_shortcut
        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        # 初始化快捷连接卷积层
        self.conv_shortcut = None
        if use_nin_shortcut:
            # 如果需要使用 nin_shortcut,则初始化其卷积层
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    # 前向传播方法,接受隐状态和确定性标志
    def __call__(self, hidden_states, deterministic=True):
        # 保存输入作为残差
        residual = hidden_states
        # 通过第一层归一化
        hidden_states = self.norm1(hidden_states)
        # 应用 Swish 激活函数
        hidden_states = nn.swish(hidden_states)
        # 通过第一层卷积
        hidden_states = self.conv1(hidden_states)

        # 通过第二层归一化
        hidden_states = self.norm2(hidden_states)
        # 再次应用 Swish 激活函数
        hidden_states = nn.swish(hidden_states)
        # 应用 dropout
        hidden_states = self.dropout_layer(hidden_states, deterministic)
        # 通过第二层卷积
        hidden_states = self.conv2(hidden_states)

        # 如果使用快捷连接,则通过卷积层处理残差
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        # 返回隐状态与残差的和
        return hidden_states + residual


# 定义 Flax 实现的基于卷积的多头注意力块类,继承自 nn.Module
class FlaxAttentionBlock(nn.Module):
    r"""
    Flax 基于卷积的多头注意力块,用于扩散模型的 VAE。
    # 定义参数文档
    Parameters:
        channels (:obj:`int`):  # 输入通道数
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):  # 注意力头的数量(可选,默认值为None)
            Number of attention heads
        num_groups (:obj:`int`, *optional*, defaults to `32`):  # 用于组归一化的组数(可选,默认值为32)
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):  # 参数的数据类型(可选,默认值为jnp.float32)
            Parameters `dtype`

    """

    channels: int  # 定义输入通道数类型
    num_head_channels: int = None  # 定义注意力头的数量,默认值为None
    num_groups: int = 32  # 定义组归一化的组数,默认值为32
    dtype: jnp.dtype = jnp.float32  # 定义参数的数据类型,默认值为jnp.float32

    def setup(self):  # 设置方法
        # 计算注意力头的数量,如果未定义则默认为1
        self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

        # 定义稠密层的部分,使用指定的通道和数据类型
        dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        # 创建组归一化层,使用指定的组数和小常数
        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        # 创建查询、键和值的稠密层
        self.query, self.key, self.value = dense(), dense(), dense()
        # 创建投影注意力的稠密层
        self.proj_attn = dense()

    def transpose_for_scores(self, projection):  # 转置以适应注意力头
        # 定义新的投影形状,插入头的维度
        new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
        # 将头的维度移动到第二个位置(B, T, H * D)->(B, T, H, D)
        new_projection = projection.reshape(new_projection_shape)
        # (B, T, H, D)->(B, H, T, D)
        new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
        return new_projection  # 返回转置后的投影

    def __call__(self, hidden_states):  # 定义调用方法
        residual = hidden_states  # 保存输入的残差
        batch, height, width, channels = hidden_states.shape  # 获取输入的形状

        hidden_states = self.group_norm(hidden_states)  # 对隐藏状态进行组归一化

        # 重新调整隐藏状态的形状以适应注意力机制
        hidden_states = hidden_states.reshape((batch, height * width, channels))

        query = self.query(hidden_states)  # 计算查询
        key = self.key(hidden_states)  # 计算键
        value = self.value(hidden_states)  # 计算值

        # 转置查询、键和值以适应注意力计算
        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        # 计算注意力权重
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))  # 计算缩放因子
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)  # 计算注意力权重
        attn_weights = nn.softmax(attn_weights, axis=-1)  # 对注意力权重进行归一化

        # 根据注意力权重聚合值
        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)

        hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))  # 转置隐藏状态
        new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)  # 定义新的隐藏状态形状
        hidden_states = hidden_states.reshape(new_hidden_states_shape)  # 重新调整形状

        hidden_states = self.proj_attn(hidden_states)  # 通过投影注意力层处理隐藏状态
        hidden_states = hidden_states.reshape((batch, height, width, channels))  # 还原到原始形状
        hidden_states = hidden_states + residual  # 加上残差
        return hidden_states  # 返回处理后的隐藏状态
# 定义一个基于 Flax 和 Resnet 的二维编码器块,用于扩散式变分自编码器
class FlaxDownEncoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Encoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            输入通道数
        out_channels (:obj:`int`):
            输出通道数
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout 率
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Resnet 层块的数量
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            Resnet 块组归一化使用的组数
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            是否添加下采样层
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            参数的数据类型
    """

    # 初始化输入和输出通道、dropout 率、层数、组数、是否添加下采样和数据类型
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    # 设置函数,用于构建模块的内部结构
    def setup(self):
        # 创建一个空列表用于存放 Resnet 块
        resnets = []
        # 遍历设置的层数,构建 Resnet 块
        for i in range(self.num_layers):
            # 如果是第一层,使用输入通道,否则使用输出通道
            in_channels = self.in_channels if i == 0 else self.out_channels

            # 创建一个 Resnet 块实例
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            # 将创建的 Resnet 块添加到列表中
            resnets.append(res_block)
        # 将所有 Resnet 块存储到实例变量中
        self.resnets = resnets

        # 如果需要添加下采样层,则创建下采样模块
        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    # 前向传播函数,用于处理输入的隐藏状态
    def __call__(self, hidden_states, deterministic=True):
        # 依次通过每个 Resnet 块处理隐藏状态
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        # 如果需要下采样,则调用下采样层处理隐藏状态
        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)

        # 返回处理后的隐藏状态
        return hidden_states


# 定义一个基于 Flax 和 Resnet 的二维解码器块,用于扩散式变分自编码器
class FlaxUpDecoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            输入通道数
        out_channels (:obj:`int`):
            输出通道数
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout 率
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Resnet 层块的数量
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            Resnet 块组归一化使用的组数
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            是否添加上采样层
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            参数的数据类型
    """

    # 初始化输入和输出通道、dropout 率、层数、组数、是否添加上采样和数据类型
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32
    # 设置方法,初始化 ResNet 模块
        def setup(self):
            # 创建一个空列表,用于存储 ResNet 块
            resnets = []
            # 遍历指定数量的层
            for i in range(self.num_layers):
                # 根据层索引确定输入通道数,第一层使用 in_channels,其余层使用 out_channels
                in_channels = self.in_channels if i == 0 else self.out_channels
                # 创建一个 ResNet 块并初始化其参数
                res_block = FlaxResnetBlock2D(
                    in_channels=in_channels,  # 输入通道数
                    out_channels=self.out_channels,  # 输出通道数
                    dropout=self.dropout,  # dropout 概率
                    groups=self.resnet_groups,  # 组数
                    dtype=self.dtype,  # 数据类型
                )
                # 将创建的 ResNet 块添加到列表中
                resnets.append(res_block)
    
            # 将创建的 ResNet 块列表赋值给实例变量
            self.resnets = resnets
    
            # 如果需要添加上采样层,则初始化上采样层
            if self.add_upsample:
                self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
    
        # 前向传播方法,处理隐藏状态
        def __call__(self, hidden_states, deterministic=True):
            # 逐个通过 ResNet 块处理隐藏状态
            for resnet in self.resnets:
                hidden_states = resnet(hidden_states, deterministic=deterministic)
    
            # 如果需要上采样,则应用上采样层
            if self.add_upsample:
                hidden_states = self.upsamplers_0(hidden_states)
    
            # 返回处理后的隐藏状态
            return hidden_states
# 定义 FlaxUNetMidBlock2D 类,继承自 nn.Module
class FlaxUNetMidBlock2D(nn.Module):
    r"""
    Flax Unet 中间块模块。

    参数:
        in_channels (:obj:`int`):
            输入通道数
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout 率
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Resnet 层块的数量
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            Resnet 和注意力块的组归一化使用的组数
        num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
            每个注意力块的注意力头数量
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            参数的数据类型
    """

    # 定义类属性,包含输入通道数、dropout 率等
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    num_attention_heads: int = 1
    dtype: jnp.dtype = jnp.float32

    # 设置模块的初始化方法
    def setup(self):
        # 计算 Resnet 组数,若未指定则取输入通道数的四分之一与 32 的最小值
        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

        # 至少有一个 Resnet 层块
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,  # 输入通道数
                out_channels=self.in_channels,  # 输出通道数
                dropout=self.dropout,  # dropout 率
                groups=resnet_groups,  # 组数
                dtype=self.dtype,  # 数据类型
            )
        ]

        # 初始化注意力块列表
        attentions = []

        # 创建多个层块
        for _ in range(self.num_layers):
            # 创建一个注意力块并添加到列表中
            attn_block = FlaxAttentionBlock(
                channels=self.in_channels,  # 通道数
                num_head_channels=self.num_attention_heads,  # 注意力头数量
                num_groups=resnet_groups,  # 组数
                dtype=self.dtype,  # 数据类型
            )
            attentions.append(attn_block)  # 将注意力块添加到列表

            # 创建一个 Resnet 层块并添加到列表中
            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,  # 输入通道数
                out_channels=self.in_channels,  # 输出通道数
                dropout=self.dropout,  # dropout 率
                groups=resnet_groups,  # 组数
                dtype=self.dtype,  # 数据类型
            )
            resnets.append(res_block)  # 将 Resnet 层块添加到列表

        # 将生成的 Resnet 层块和注意力块存储为类属性
        self.resnets = resnets
        self.attentions = attentions

    # 定义模块的前向调用方法
    def __call__(self, hidden_states, deterministic=True):
        # 使用第一个 Resnet 层块处理隐藏状态
        hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
        # 遍历注意力块和 Resnet 层块进行处理
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states)  # 应用注意力块
            hidden_states = resnet(hidden_states, deterministic=deterministic)  # 应用 Resnet 层块

        # 返回处理后的隐藏状态
        return hidden_states


# 定义 FlaxEncoder 类,继承自 nn.Module
class FlaxEncoder(nn.Module):
    r"""
    Flax 实现的 VAE 编码器。

    该模型是 Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    子类。可将其用作常规 Flax linen 模块,并参考 Flax 文档以获取与
    一般用法和行为相关的所有事项。

    最后,该模型支持固有的 JAX 特性,例如:
    - [即时编译 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    # 自动微分相关链接
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    # 向量化相关链接
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    # 并行化相关链接
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    # 参数说明
    Parameters:
        # 输入通道数,默认为 3
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        # 输出通道数,默认为 3
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        # 下采样块类型,默认为 `(DownEncoderBlock2D)`
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type
        # 每个块的输出通道数元组,默认为 `(64,)`
        block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        # 每个块的 ResNet 层数,默认为 2
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of Resnet layer for each block
        # 归一化分组数,默认为 32
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        # 激活函数类型,默认为 `silu`
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function
        # 是否将最后的输出通道数加倍,默认为 False
        double_z (:obj:`bool`, *optional*, defaults to `False`):
            Whether to double the last output channels
        # 参数数据类型,默认为 jnp.float32
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    # 设置默认输入通道数为 3
    in_channels: int = 3
    # 设置默认输出通道数为 3
    out_channels: int = 3
    # 设置默认下采样块类型
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    # 设置每个块的默认输出通道数
    block_out_channels: Tuple[int] = (64,)
    # 设置每个块的默认层数为 2
    layers_per_block: int = 2
    # 设置默认归一化分组数为 32
    norm_num_groups: int = 32
    # 设置默认激活函数为 "silu"
    act_fn: str = "silu"
    # 设置默认是否加倍输出通道数为 False
    double_z: bool = False
    # 设置默认数据类型为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 设置模型的各个层
    def setup(self):
        # 获取输出通道的数量
        block_out_channels = self.block_out_channels
        # 输入层,定义卷积操作
        self.conv_in = nn.Conv(
            block_out_channels[0],  # 输入通道数
            kernel_size=(3, 3),  # 卷积核大小
            strides=(1, 1),  # 步幅
            padding=((1, 1), (1, 1)),  # 填充方式
            dtype=self.dtype,  # 数据类型
        )

        # 下采样部分
        down_blocks = []  # 初始化下采样块列表
        output_channel = block_out_channels[0]  # 当前输出通道
        for i, _ in enumerate(self.down_block_types):  # 遍历下采样块类型
            input_channel = output_channel  # 当前输入通道
            output_channel = block_out_channels[i]  # 更新输出通道
            is_final_block = i == len(block_out_channels) - 1  # 检查是否为最后一个块

            # 创建下采样块
            down_block = FlaxDownEncoderBlock2D(
                in_channels=input_channel,  # 输入通道数
                out_channels=output_channel,  # 输出通道数
                num_layers=self.layers_per_block,  # 块内层数
                resnet_groups=self.norm_num_groups,  # 归一化组数
                add_downsample=not is_final_block,  # 是否添加下采样
                dtype=self.dtype,  # 数据类型
            )
            down_blocks.append(down_block)  # 将下采样块添加到列表
        self.down_blocks = down_blocks  # 保存下采样块

        # 中间层
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],  # 输入通道数为最后一个块的输出通道
            resnet_groups=self.norm_num_groups,  # 归一化组数
            num_attention_heads=None,  # 注意力头数(未使用)
            dtype=self.dtype,  # 数据类型
        )

        # 结束层
        conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels  # 输出通道数
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)  # 归一化层
        self.conv_out = nn.Conv(
            conv_out_channels,  # 输出通道数
            kernel_size=(3, 3),  # 卷积核大小
            strides=(1, 1),  # 步幅
            padding=((1, 1), (1, 1)),  # 填充方式
            dtype=self.dtype,  # 数据类型
        )

    # 前向传播方法
    def __call__(self, sample, deterministic: bool = True):
        # 输入层处理
        sample = self.conv_in(sample)  # 对输入样本应用卷积

        # 下采样处理
        for block in self.down_blocks:  # 遍历下采样块
            sample = block(sample, deterministic=deterministic)  # 处理样本

        # 中间层处理
        sample = self.mid_block(sample, deterministic=deterministic)  # 对样本应用中间块

        # 结束层处理
        sample = self.conv_norm_out(sample)  # 应用归一化
        sample = nn.swish(sample)  # 使用 Swish 激活函数
        sample = self.conv_out(sample)  # 应用最后的卷积层

        return sample  # 返回处理后的样本
# 定义 FlaxDecoder 类,继承自 nn.Module
class FlaxDecoder(nn.Module):
    r"""
    Flax 实现的 VAE 解码器。

    该模型是 Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    的子类。可将其作为常规 Flax linen 模块使用,并参考 Flax 文档以了解所有相关的
    使用和行为。

    最后,该模型支持固有的 JAX 特性,例如:
    - [即时编译 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [自动微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [向量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [并行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    参数:
        in_channels (:obj:`int`, *可选*, 默认为 3):
            输入通道数
        out_channels (:obj:`int`, *可选*, 默认为 3):
            输出通道数
        up_block_types (:obj:`Tuple[str]`, *可选*, 默认为 `(UpDecoderBlock2D)`):
            UpDecoder 块类型
        block_out_channels (:obj:`Tuple[str]`, *可选*, 默认为 `(64,)`):
            包含每个块输出通道数量的元组
        layers_per_block (:obj:`int`, *可选*, 默认为 `2`):
            每个块的 Resnet 层数量
        norm_num_groups (:obj:`int`, *可选*, 默认为 `32`):
            规范的组数量
        act_fn (:obj:`str`, *可选*, 默认为 `silu`):
            激活函数
        double_z (:obj:`bool`, *可选*, 默认为 `False`):
            是否加倍最后的输出通道数
        dtype (:obj:`jnp.dtype`, *可选*, 默认为 jnp.float32):
            参数的 `dtype`
    """

    # 定义输入通道数,默认为 3
    in_channels: int = 3
    # 定义输出通道数,默认为 3
    out_channels: int = 3
    # 定义 UpDecoder 块类型,默认为一个元组,包含 "UpDecoderBlock2D"
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    # 定义每个块的输出通道数量,默认为一个元组,包含 64
    block_out_channels: int = (64,)
    # 定义每个块的层数,默认为 2
    layers_per_block: int = 2
    # 定义规范的组数量,默认为 32
    norm_num_groups: int = 32
    # 定义激活函数,默认为 "silu"
    act_fn: str = "silu"
    # 定义参数的数据类型,默认为 jnp.float32
    dtype: jnp.dtype = jnp.float32
    # 初始化设置方法
    def setup(self):
        # 获取输出通道数
        block_out_channels = self.block_out_channels

        # 输入层,将 z 转换为 block_in
        self.conv_in = nn.Conv(
            block_out_channels[-1],  # 输入通道数为输出通道数列表的最后一个元素
            kernel_size=(3, 3),  # 卷积核大小为 3x3
            strides=(1, 1),  # 步幅为 1
            padding=((1, 1), (1, 1)),  # 上下左右各填充 1 像素
            dtype=self.dtype,  # 数据类型
        )

        # 中间层
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],  # 输入通道数为输出通道数列表的最后一个元素
            resnet_groups=self.norm_num_groups,  # 归一化组数
            num_attention_heads=None,  # 注意力头数设为 None
            dtype=self.dtype,  # 数据类型
        )

        # 上采样
        reversed_block_out_channels = list(reversed(block_out_channels))  # 反转输出通道数列表
        output_channel = reversed_block_out_channels[0]  # 当前输出通道数为反转列表的第一个元素
        up_blocks = []  # 初始化上采样块列表
        for i, _ in enumerate(self.up_block_types):  # 遍历上采样块类型
            prev_output_channel = output_channel  # 保存前一个输出通道数
            output_channel = reversed_block_out_channels[i]  # 更新当前输出通道数

            is_final_block = i == len(block_out_channels) - 1  # 检查是否为最后一个块

            # 创建上采样解码块
            up_block = FlaxUpDecoderBlock2D(
                in_channels=prev_output_channel,  # 输入通道数为前一个输出通道数
                out_channels=output_channel,  # 输出通道数
                num_layers=self.layers_per_block + 1,  # 层数为每个块的层数加一
                resnet_groups=self.norm_num_groups,  # 归一化组数
                add_upsample=not is_final_block,  # 如果不是最后一个块则添加上采样
                dtype=self.dtype,  # 数据类型
            )
            up_blocks.append(up_block)  # 将上采样块添加到列表
            prev_output_channel = output_channel  # 更新前一个输出通道数

        self.up_blocks = up_blocks  # 将上采样块列表赋值给实例变量

        # 结束层
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)  # 归一化层
        self.conv_out = nn.Conv(
            self.out_channels,  # 输出通道数
            kernel_size=(3, 3),  # 卷积核大小为 3x3
            strides=(1, 1),  # 步幅为 1
            padding=((1, 1), (1, 1)),  # 上下左右各填充 1 像素
            dtype=self.dtype,  # 数据类型
        )

    # 前向传播方法
    def __call__(self, sample, deterministic: bool = True):
        # 将 z 转换为 block_in
        sample = self.conv_in(sample)  # 通过输入卷积层处理样本

        # 中间层
        sample = self.mid_block(sample, deterministic=deterministic)  # 通过中间块处理样本

        # 上采样
        for block in self.up_blocks:  # 遍历所有上采样块
            sample = block(sample, deterministic=deterministic)  # 处理样本

        sample = self.conv_norm_out(sample)  # 通过归一化层处理样本
        sample = nn.swish(sample)  # 应用 Swish 激活函数
        sample = self.conv_out(sample)  # 通过输出卷积层处理样本

        return sample  # 返回处理后的样本
# 定义一个类表示对角高斯分布
class FlaxDiagonalGaussianDistribution(object):
    # 初始化函数,接受参数和一个可选的确定性标志
    def __init__(self, parameters, deterministic=False):
        # 将参数拆分为均值和对数方差,最后一维用于通道最后的情况
        self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
        # 限制对数方差在-30到20之间
        self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
        # 设置确定性标志
        self.deterministic = deterministic
        # 计算标准差
        self.std = jnp.exp(0.5 * self.logvar)
        # 计算方差
        self.var = jnp.exp(self.logvar)
        # 如果是确定性模式,则将方差和标准差设置为均值的零张量
        if self.deterministic:
            self.var = self.std = jnp.zeros_like(self.mean)

    # 从分布中采样
    def sample(self, key):
        # 使用均值和标准差生成样本
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    # 计算KL散度
    def kl(self, other=None):
        # 如果是确定性模式,返回零
        if self.deterministic:
            return jnp.array([0.0])

        # 如果没有提供其他分布,计算与标准正态分布的KL散度
        if other is None:
            return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])

        # 计算两个分布之间的KL散度
        return 0.5 * jnp.sum(
            jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            axis=[1, 2, 3],
        )

    # 计算负对数似然
    def nll(self, sample, axis=[1, 2, 3]):
        # 如果是确定性模式,返回零
        if self.deterministic:
            return jnp.array([0.0])

        # 计算2π的对数
        logtwopi = jnp.log(2.0 * jnp.pi)
        # 计算负对数似然
        return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)

    # 返回分布的众数
    def mode(self):
        return self.mean


# 使用装饰器将类注册到配置中
@flax_register_to_config
# 定义一个Flax自编码器类,使用KL损失解码潜在表示
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax实现的变分自编码器(VAE)模型,带有KL损失以解码潜在表示。

    该模型继承自[`FlaxModelMixin`]。请查看超类文档以了解所有模型的通用方法
    (如下载或保存)。

    该模型是Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    子类。将其用作常规Flax Linen模块,并参考Flax文档了解其
    一般用法和行为。

    该模型支持JAX的固有特性,例如以下内容:

    - [即时(JIT)编译](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [自动微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [矢量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [并行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
    # 参数说明
        Parameters:
            in_channels (`int`, *optional*, defaults to 3):  # 输入图像的通道数,默认为3
                Number of channels in the input image.
            out_channels (`int`, *optional*, defaults to 3):  # 输出图像的通道数,默认为3
                Number of channels in the output.
            down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):  # 下采样模块类型的元组,默认为 DownEncoderBlock2D
                Tuple of downsample block types.
            up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):  # 上采样模块类型的元组,默认为 UpDecoderBlock2D
                Tuple of upsample block types.
            block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`):  # 每个模块的输出通道数的元组,默认为 64
                Tuple of block output channels.
            layers_per_block (`int`, *optional*, defaults to `2`):  # 每个模块中的 ResNet 层数,默认为 2
                Number of ResNet layer for each block.
            act_fn (`str`, *optional*, defaults to `silu`):  # 使用的激活函数,默认为 silu
                The activation function to use.
            latent_channels (`int`, *optional*, defaults to `4`):  # 潜在空间中的通道数,默认为 4
                Number of channels in the latent space.
            norm_num_groups (`int`, *optional*, defaults to `32`):  # 归一化的组数,默认为 32
                The number of groups for normalization.
            sample_size (`int`, *optional*, defaults to 32):  # 输入样本的大小,默认为 32
                Sample input size.
            scaling_factor (`float`, *optional*, defaults to 0.18215):  # 用于缩放潜在空间的标准差,默认为 0.18215
                The component-wise standard deviation of the trained latent space computed using the first batch of the
                training set. This is used to scale the latent space to have unit variance when training the diffusion
                model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
                diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
                / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
                Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
            dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):  # 参数的数据类型,默认为 jnp.float32
                The `dtype` of the parameters.
        """  # 结束参数说明
    
        in_channels: int = 3  # 定义输入通道数,默认值为3
        out_channels: int = 3  # 定义输出通道数,默认值为3
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",)  # 定义下采样模块类型,默认使用 DownEncoderBlock2D
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",)  # 定义上采样模块类型,默认使用 UpDecoderBlock2D
        block_out_channels: Tuple[int] = (64,)  # 定义模块输出通道数,默认为 64
        layers_per_block: int = 1  # 定义每个模块的 ResNet 层数,默认为 1
        act_fn: str = "silu"  # 定义激活函数,默认为 silu
        latent_channels: int = 4  # 定义潜在空间通道数,默认为 4
        norm_num_groups: int = 32  # 定义归一化组数,默认为 32
        sample_size: int = 32  # 定义样本输入大小,默认为 32
        scaling_factor: float = 0.18215  # 定义缩放因子,默认为 0.18215
        dtype: jnp.dtype = jnp.float32  # 定义参数数据类型,默认为 jnp.float32
    # 设置模型的编码器和解码器等组件
    def setup(self):
        # 初始化编码器,配置输入和输出通道及其他参数
        self.encoder = FlaxEncoder(
            in_channels=self.config.in_channels,
            out_channels=self.config.latent_channels,
            down_block_types=self.config.down_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            act_fn=self.config.act_fn,
            norm_num_groups=self.config.norm_num_groups,
            double_z=True,
            dtype=self.dtype,
        )
        # 初始化解码器,配置输入和输出通道及其他参数
        self.decoder = FlaxDecoder(
            in_channels=self.config.latent_channels,
            out_channels=self.config.out_channels,
            up_block_types=self.config.up_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            norm_num_groups=self.config.norm_num_groups,
            act_fn=self.config.act_fn,
            dtype=self.dtype,
        )
        # 初始化量化卷积,配置输入通道和卷积参数
        self.quant_conv = nn.Conv(
            2 * self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        # 初始化后量化卷积,配置输入通道和卷积参数
        self.post_quant_conv = nn.Conv(
            self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    # 初始化权重,返回冻结的参数字典
    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # 初始化输入张量,设置样本形状
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)

        # 将随机数生成器分割为参数、丢弃和高斯随机数生成器
        params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
        rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}

        # 初始化并返回参数
        return self.init(rngs, sample)["params"]

    # 编码样本,返回潜在分布
    def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
        # 调整样本维度顺序
        sample = jnp.transpose(sample, (0, 2, 3, 1))

        # 使用编码器生成隐藏状态
        hidden_states = self.encoder(sample, deterministic=deterministic)
        # 通过量化卷积处理隐藏状态
        moments = self.quant_conv(hidden_states)
        # 创建潜在分布
        posterior = FlaxDiagonalGaussianDistribution(moments)

        # 根据 return_dict 决定返回的格式
        if not return_dict:
            return (posterior,)

        return FlaxAutoencoderKLOutput(latent_dist=posterior)

    # 解码潜在变量,返回生成的样本
    def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
        # 检查潜在变量的通道数,必要时调整维度顺序
        if latents.shape[-1] != self.config.latent_channels:
            latents = jnp.transpose(latents, (0, 2, 3, 1))

        # 通过后量化卷积处理潜在变量
        hidden_states = self.post_quant_conv(latents)
        # 使用解码器生成隐藏状态
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)

        # 调整隐藏状态维度顺序
        hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

        # 根据 return_dict 决定返回的格式
        if not return_dict:
            return (hidden_states,)

        return FlaxDecoderOutput(sample=hidden_states)
    # 定义一个可调用的函数,用于处理样本,带有一些可选参数
        def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
            # 编码输入样本,获取后验分布,参数控制编码行为
            posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
            # 如果需要样本后验分布
            if sample_posterior:
                # 创建一个高斯分布的随机数生成器
                rng = self.make_rng("gaussian")
                # 从后验分布中采样隐状态
                hidden_states = posterior.latent_dist.sample(rng)
            else:
                # 获取后验分布的模态值作为隐状态
                hidden_states = posterior.latent_dist.mode()
    
            # 解码隐状态,返回解码后的样本
            sample = self.decode(hidden_states, return_dict=return_dict).sample
    
            # 如果不需要以字典形式返回结果
            if not return_dict:
                # 返回解码后的样本元组
                return (sample,)
    
            # 返回一个包含解码样本的输出对象
            return FlaxDecoderOutput(sample=sample)
posted @ 2024-10-22 12:36  绝不原创的飞龙  阅读(78)  评论(0编辑  收藏  举报