diffusers 源码解析(十四)
.\diffusers\models\unets\unet_2d_blocks_flax.py
# 版权声明,说明该文件的版权信息及相关许可协议
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 许可信息,使用 Apache License 2.0 许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 该文件只能在符合许可证的情况下使用
# you may not use this file except in compliance with the License.
# 提供许可证获取链接
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 免责声明,表明不提供任何形式的保证或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 许可证的相关权限及限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入 flax.linen 模块,用于构建神经网络
import flax.linen as nn
# 导入 jax.numpy,用于数值计算
import jax.numpy as jnp
# 从其他模块导入特定的类,用于构建模型的各个组件
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
# 定义 FlaxCrossAttnDownBlock2D 类,表示一个 2D 跨注意力下采样模块
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
跨注意力 2D 下采样块 - 原始架构来自 Unet transformers:
https://arxiv.org/abs/2103.06104
参数说明:
in_channels (:obj:`int`):
输入通道数
out_channels (:obj:`int`):
输出通道数
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout 率
num_layers (:obj:`int`, *optional*, defaults to 1):
注意力块层数
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
每个空间变换块的注意力头数
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
是否在每个最终输出之前添加下采样层
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
启用内存高效的注意力 https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
是否将头维度拆分为一个新的轴进行自注意力计算。在大多数情况下,
启用此标志应加快 Stable Diffusion 2.x 和 Stable Diffusion XL 的计算速度。
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
参数的数据类型
"""
# 定义输入通道数
in_channels: int
# 定义输出通道数
out_channels: int
# 定义 Dropout 率,默认为 0.0
dropout: float = 0.0
# 定义注意力块的层数,默认为 1
num_layers: int = 1
# 定义注意力头数,默认为 1
num_attention_heads: int = 1
# 定义是否添加下采样层,默认为 True
add_downsample: bool = True
# 定义是否使用线性投影,默认为 False
use_linear_projection: bool = False
# 定义是否仅使用跨注意力,默认为 False
only_cross_attention: bool = False
# 定义是否启用内存高效注意力,默认为 False
use_memory_efficient_attention: bool = False
# 定义是否拆分头维度,默认为 False
split_head_dim: bool = False
# 定义参数的数据类型,默认为 jnp.float32
dtype: jnp.dtype = jnp.float32
# 定义每个块的变换器层数,默认为 1
transformer_layers_per_block: int = 1
# 设置模型的各个组成部分,包括残差块和注意力块
def setup(self):
# 初始化残差块列表
resnets = []
# 初始化注意力块列表
attentions = []
# 遍历每一层,构建残差块和注意力块
for i in range(self.num_layers):
# 第一层的输入通道为 in_channels,其他层为 out_channels
in_channels = self.in_channels if i == 0 else self.out_channels
# 创建一个 FlaxResnetBlock2D 实例
res_block = FlaxResnetBlock2D(
in_channels=in_channels, # 输入通道
out_channels=self.out_channels, # 输出通道
dropout_prob=self.dropout, # 丢弃率
dtype=self.dtype, # 数据类型
)
# 将残差块添加到列表中
resnets.append(res_block)
# 创建一个 FlaxTransformer2DModel 实例
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels, # 输入通道
n_heads=self.num_attention_heads, # 注意力头数
d_head=self.out_channels // self.num_attention_heads, # 每个头的维度
depth=self.transformer_layers_per_block, # 每个块的层数
use_linear_projection=self.use_linear_projection, # 是否使用线性投影
only_cross_attention=self.only_cross_attention, # 是否只使用交叉注意力
use_memory_efficient_attention=self.use_memory_efficient_attention, # 是否使用内存高效的注意力
split_head_dim=self.split_head_dim, # 是否拆分头的维度
dtype=self.dtype, # 数据类型
)
# 将注意力块添加到列表中
attentions.append(attn_block)
# 将残差块列表赋值给实例变量
self.resnets = resnets
# 将注意力块列表赋值给实例变量
self.attentions = attentions
# 如果需要下采样,则创建下采样层
if self.add_downsample:
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
# 定义前向调用方法,处理隐藏状态和编码器隐藏状态
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
# 初始化输出状态元组
output_states = ()
# 遍历残差块和注意力块并进行处理
for resnet, attn in zip(self.resnets, self.attentions):
# 通过残差块处理隐藏状态
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 通过注意力块处理隐藏状态
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
# 将当前隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 如果需要下采样,则进行下采样
if self.add_downsample:
hidden_states = self.downsamplers_0(hidden_states)
# 将下采样后的隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 返回最终的隐藏状态和输出状态元组
return hidden_states, output_states
# 定义 Flax 2D 降维块类,继承自 nn.Module
class FlaxDownBlock2D(nn.Module):
r"""
Flax 2D downsizing block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
# 声明输入输出通道和其他参数
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
# 设置方法,用于初始化模型的层
def setup(self):
# 创建空列表以存储残差块
resnets = []
# 根据层数创建残差块
for i in range(self.num_layers):
# 第一个块的输入通道为 in_channels,其余为 out_channels
in_channels = self.in_channels if i == 0 else self.out_channels
# 创建残差块实例
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
# 将残差块添加到列表中
resnets.append(res_block)
# 将列表赋值给实例属性
self.resnets = resnets
# 如果需要,添加降采样层
if self.add_downsample:
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
# 调用方法,执行前向传播
def __call__(self, hidden_states, temb, deterministic=True):
# 创建空元组以存储输出状态
output_states = ()
# 遍历所有残差块进行前向传播
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 将当前隐藏状态添加到输出状态中
output_states += (hidden_states,)
# 如果需要,应用降采样层
if self.add_downsample:
hidden_states = self.downsamplers_0(hidden_states)
# 将降采样后的隐藏状态添加到输出状态中
output_states += (hidden_states,)
# 返回最终的隐藏状态和输出状态
return hidden_states, output_states
# 定义 Flax 交叉注意力 2D 上采样块类,继承自 nn.Module
class FlaxCrossAttnUpBlock2D(nn.Module):
r"""
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104
# 定义参数的文档字符串,描述各个参数的用途和类型
Parameters:
in_channels (:obj:`int`): # 输入通道数
Input channels
out_channels (:obj:`int`): # 输出通道数
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0): # Dropout 率,默认值为 0.0
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1): # 注意力块的层数,默认值为 1
Number of attention blocks layers
num_attention_heads (:obj:`int`, *optional*, defaults to 1): # 每个空间变换块的注意力头数量,默认值为 1
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`): # 是否在每个最终输出前添加上采样层,默认值为 True
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): # 启用内存高效注意力,默认值为 False
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`): # 是否将头维度拆分为新轴以进行自注意力计算,默认值为 False
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): # 数据类型参数,默认值为 jnp.float32
Parameters `dtype`
"""
in_channels: int # 输入通道数的声明
out_channels: int # 输出通道数的声明
prev_output_channel: int # 前一个输出通道数的声明
dropout: float = 0.0 # Dropout 率的声明,默认值为 0.0
num_layers: int = 1 # 注意力层数的声明,默认值为 1
num_attention_heads: int = 1 # 注意力头数量的声明,默认值为 1
add_upsample: bool = True # 是否添加上采样的声明,默认值为 True
use_linear_projection: bool = False # 是否使用线性投影的声明,默认值为 False
only_cross_attention: bool = False # 是否仅使用交叉注意力的声明,默认值为 False
use_memory_efficient_attention: bool = False # 是否启用内存高效注意力的声明,默认值为 False
split_head_dim: bool = False # 是否拆分头维度的声明,默认值为 False
dtype: jnp.dtype = jnp.float32 # 数据类型的声明,默认值为 jnp.float32
transformer_layers_per_block: int = 1 # 每个块的变换层数的声明,默认值为 1
# 设置方法,初始化网络结构
def setup(self):
# 初始化空列表以存储 ResNet 块
resnets = []
# 初始化空列表以存储注意力块
attentions = []
# 遍历每一层以创建相应的 ResNet 和注意力块
for i in range(self.num_layers):
# 设置跳跃连接的通道数,最后一层使用输入通道,否则使用输出通道
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
# 设置当前 ResNet 块的输入通道,第一层使用前一层的输出通道
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
# 创建 FlaxResnetBlock2D 实例
res_block = FlaxResnetBlock2D(
# 设置输入通道为当前 ResNet 块输入通道加跳跃连接通道
in_channels=resnet_in_channels + res_skip_channels,
# 设置输出通道为指定的输出通道
out_channels=self.out_channels,
# 设置 dropout 概率
dropout_prob=self.dropout,
# 设置数据类型
dtype=self.dtype,
)
# 将创建的 ResNet 块添加到列表中
resnets.append(res_block)
# 创建 FlaxTransformer2DModel 实例
attn_block = FlaxTransformer2DModel(
# 设置输入通道为输出通道
in_channels=self.out_channels,
# 设置注意力头数
n_heads=self.num_attention_heads,
# 设置每个注意力头的维度
d_head=self.out_channels // self.num_attention_heads,
# 设置 transformer 块的深度
depth=self.transformer_layers_per_block,
# 设置是否使用线性投影
use_linear_projection=self.use_linear_projection,
# 设置是否仅使用交叉注意力
only_cross_attention=self.only_cross_attention,
# 设置是否使用内存高效的注意力机制
use_memory_efficient_attention=self.use_memory_efficient_attention,
# 设置是否分割头部维度
split_head_dim=self.split_head_dim,
# 设置数据类型
dtype=self.dtype,
)
# 将创建的注意力块添加到列表中
attentions.append(attn_block)
# 将 ResNet 列表保存到实例属性
self.resnets = resnets
# 将注意力列表保存到实例属性
self.attentions = attentions
# 如果需要添加上采样层,则创建相应的 FlaxUpsample2D 实例
if self.add_upsample:
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
# 定义调用方法,接受隐藏状态和其他参数
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
# 遍历 ResNet 和注意力块
for resnet, attn in zip(self.resnets, self.attentions):
# 从跳跃连接的隐藏状态元组中取出最后一个状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新跳跃连接的隐藏状态元组,去掉最后一个状态
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 将隐藏状态与跳跃连接的隐藏状态在最后一个轴上拼接
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
# 使用当前的 ResNet 块处理隐藏状态
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 使用当前的注意力块处理隐藏状态
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
# 如果需要添加上采样,则使用上采样层处理隐藏状态
if self.add_upsample:
hidden_states = self.upsamplers_0(hidden_states)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个 2D 上采样块类,继承自 nn.Module
class FlaxUpBlock2D(nn.Module):
r"""
Flax 2D upsampling block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
prev_output_channel (:obj:`int`):
Output channels from the previous block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention blocks layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
# 定义输入输出通道和其他参数
in_channels: int
out_channels: int
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
# 设置方法用于初始化块的结构
def setup(self):
resnets = [] # 创建一个空列表用于存储 ResNet 块
# 遍历每一层,创建 ResNet 块
for i in range(self.num_layers):
# 计算跳跃连接通道数
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
# 设置输入通道数
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
# 创建一个新的 FlaxResnetBlock2D 实例
res_block = FlaxResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block) # 将块添加到列表中
self.resnets = resnets # 将列表赋值给实例变量
# 如果需要上采样,初始化上采样层
if self.add_upsample:
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
# 定义前向传播方法
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
# 遍历每个 ResNet 块进行前向传播
for resnet in self.resnets:
# 从元组中弹出最后的残差隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1] # 更新元组,去掉最后一项
# 连接当前隐藏状态与残差隐藏状态
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
# 通过 ResNet 块处理隐藏状态
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
# 如果需要上采样,调用上采样层
if self.add_upsample:
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states # 返回处理后的隐藏状态
# 定义一个 2D 中级交叉注意力块类,继承自 nn.Module
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
r"""
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
# 定义参数的文档字符串
Parameters:
in_channels (:obj:`int`): # 输入通道数
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0): # Dropout比率,默认为0.0
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1): # 注意力层的数量,默认为1
Number of attention blocks layers
num_attention_heads (:obj:`int`, *optional*, defaults to 1): # 每个空间变换块的注意力头数量,默认为1
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): # 是否启用内存高效的注意力机制,默认为False
enable memory efficient attention https://arxiv.org/abs/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`): # 是否将头维度分割为新的轴以加速计算,默认为False
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): # 数据类型参数,默认为jnp.float32
Parameters `dtype`
"""
in_channels: int # 输入通道数的类型
dropout: float = 0.0 # Dropout比率的默认值
num_layers: int = 1 # 注意力层数量的默认值
num_attention_heads: int = 1 # 注意力头数量的默认值
use_linear_projection: bool = False # 是否使用线性投影的默认值
use_memory_efficient_attention: bool = False # 是否使用内存高效注意力的默认值
split_head_dim: bool = False # 是否分割头维度的默认值
dtype: jnp.dtype = jnp.float32 # 数据类型的默认值
transformer_layers_per_block: int = 1 # 每个块中的变换层数量的默认值
def setup(self): # 设置方法,用于初始化
# 至少会有一个ResNet块
resnets = [ # 创建ResNet块列表
FlaxResnetBlock2D( # 创建一个ResNet块
in_channels=self.in_channels, # 输入通道数
out_channels=self.in_channels, # 输出通道数
dropout_prob=self.dropout, # Dropout概率
dtype=self.dtype, # 数据类型
)
]
attentions = [] # 初始化注意力块列表
for _ in range(self.num_layers): # 遍历指定的注意力层数
attn_block = FlaxTransformer2DModel( # 创建一个Transformer块
in_channels=self.in_channels, # 输入通道数
n_heads=self.num_attention_heads, # 注意力头数量
d_head=self.in_channels // self.num_attention_heads, # 每个头的维度
depth=self.transformer_layers_per_block, # 变换层深度
use_linear_projection=self.use_linear_projection, # 是否使用线性投影
use_memory_efficient_attention=self.use_memory_efficient_attention, # 是否使用内存高效注意力
split_head_dim=self.split_head_dim, # 是否分割头维度
dtype=self.dtype, # 数据类型
)
attentions.append(attn_block) # 将注意力块添加到列表中
res_block = FlaxResnetBlock2D( # 创建一个ResNet块
in_channels=self.in_channels, # 输入通道数
out_channels=self.in_channels, # 输出通道数
dropout_prob=self.dropout, # Dropout概率
dtype=self.dtype, # 数据类型
)
resnets.append(res_block) # 将ResNet块添加到列表中
self.resnets = resnets # 将ResNet块列表赋值给实例属性
self.attentions = attentions # 将注意力块列表赋值给实例属性
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): # 调用方法
hidden_states = self.resnets[0](hidden_states, temb) # 通过第一个ResNet块处理隐藏状态
for attn, resnet in zip(self.attentions, self.resnets[1:]): # 遍历每个注意力块和后续ResNet块
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) # 处理隐藏状态
hidden_states = resnet(hidden_states, temb, deterministic=deterministic) # 再次处理隐藏状态
return hidden_states # 返回处理后的隐藏状态
.\diffusers\models\unets\unet_2d_condition.py
# 版权声明,标明版权信息和使用许可
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 按照 Apache License 2.0 版本进行许可
# Licensed under the Apache License, Version 2.0 (the "License");
# 你不得在未遵守许可的情况下使用此文件
# you may not use this file except in compliance with the License.
# 你可以在以下网址获得许可证的副本
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非法律要求或书面同意,软件以“按原样”方式分发,不提供任何形式的担保或条件
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 查看许可证以了解特定语言所适用的权限和限制
# See the License for the specific language governing permissions and
# limitations under the License.
# 从 dataclasses 模块导入 dataclass 装饰器,用于简化类的定义
from dataclasses import dataclass
# 导入所需的类型注释
from typing import Any, Dict, List, Optional, Tuple, Union
# 导入 PyTorch 库和相关模块
import torch
import torch.nn as nn
import torch.utils.checkpoint
# 从配置和加载器模块中导入所需的类和函数
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, # 导入与注意力机制相关的处理器
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..embeddings import (
GaussianFourierProjection, # 导入多种嵌入方法
GLIGENTextBoundingboxProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from ..modeling_utils import ModelMixin # 导入模型混合类
from .unet_2d_blocks import (
get_down_block, # 导入下采样块的构造函数
get_mid_block, # 导入中间块的构造函数
get_up_block, # 导入上采样块的构造函数
)
# 创建一个日志记录器,用于记录模型相关信息
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定义 UNet2DConditionOutput 数据类,用于存储 UNet2DConditionModel 的输出
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
UNet2DConditionModel 的输出。
参数:
sample (`torch.Tensor`,形状为 `(batch_size, num_channels, height, width)`):
基于 `encoder_hidden_states` 输入的隐藏状态输出,模型最后一层的输出。
"""
sample: torch.Tensor = None # 定义一个样本属性,默认为 None
# 定义 UNet2DConditionModel 类,表示一个条件 2D UNet 模型
class UNet2DConditionModel(
ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
一个条件 2D UNet 模型,接受一个噪声样本、条件状态和时间步,并返回样本形状的输出。
该模型继承自 [`ModelMixin`]。查看超类文档以获取其为所有模型实现的通用方法
(例如下载或保存)。
"""
_supports_gradient_checkpointing = True # 表示该模型支持梯度检查点
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] # 不进行拆分的模块列表
@register_to_config # 将该方法注册到配置中
# 初始化方法,设置类的基本属性
def __init__(
# 样本大小,默认为 None
self,
sample_size: Optional[int] = None,
# 输入通道数,默认为 4
in_channels: int = 4,
# 输出通道数,默认为 4
out_channels: int = 4,
# 是否将输入样本中心化,默认为 False
center_input_sample: bool = False,
# 是否将正弦函数翻转为余弦函数,默认为 True
flip_sin_to_cos: bool = True,
# 频率偏移量,默认为 0
freq_shift: int = 0,
# 向下采样的块类型,包含多种块类型
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
# 中间块的类型,默认为 UNet 的中间块类型
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
# 向上采样的块类型,包含多种块类型
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
# 是否仅使用交叉注意力,默认为 False
only_cross_attention: Union[bool, Tuple[bool]] = False,
# 每个块的输出通道数
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
# 每个块的层数,默认为 2
layers_per_block: Union[int, Tuple[int]] = 2,
# 下采样时的填充大小,默认为 1
downsample_padding: int = 1,
# 中间块的缩放因子,默认为 1
mid_block_scale_factor: float = 1,
# dropout 概率,默认为 0.0
dropout: float = 0.0,
# 激活函数类型,默认为 "silu"
act_fn: str = "silu",
# 归一化的组数,默认为 32
norm_num_groups: Optional[int] = 32,
# 归一化的 epsilon 值,默认为 1e-5
norm_eps: float = 1e-5,
# 交叉注意力的维度,默认为 1280
cross_attention_dim: Union[int, Tuple[int]] = 1280,
# 每个块的变换层数,默认为 1
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
# 反向变换层的块数,默认为 None
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
# 编码器隐藏层的维度,默认为 None
encoder_hid_dim: Optional[int] = None,
# 编码器隐藏层类型,默认为 None
encoder_hid_dim_type: Optional[str] = None,
# 注意力头的维度,默认为 8
attention_head_dim: Union[int, Tuple[int]] = 8,
# 注意力头的数量,默认为 None
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
# 是否使用双交叉注意力,默认为 False
dual_cross_attention: bool = False,
# 是否使用线性投影,默认为 False
use_linear_projection: bool = False,
# 类嵌入类型,默认为 None
class_embed_type: Optional[str] = None,
# 附加嵌入类型,默认为 None
addition_embed_type: Optional[str] = None,
# 附加时间嵌入维度,默认为 None
addition_time_embed_dim: Optional[int] = None,
# 类嵌入数量,默认为 None
num_class_embeds: Optional[int] = None,
# 是否上溯注意力,默认为 False
upcast_attention: bool = False,
# ResNet 时间缩放偏移类型,默认为 "default"
resnet_time_scale_shift: str = "default",
# ResNet 是否跳过时间激活,默认为 False
resnet_skip_time_act: bool = False,
# ResNet 输出缩放因子,默认为 1.0
resnet_out_scale_factor: float = 1.0,
# 时间嵌入类型,默认为 "positional"
time_embedding_type: str = "positional",
# 时间嵌入维度,默认为 None
time_embedding_dim: Optional[int] = None,
# 时间嵌入激活函数,默认为 None
time_embedding_act_fn: Optional[str] = None,
# 时间步后激活函数,默认为 None
timestep_post_act: Optional[str] = None,
# 时间条件投影维度,默认为 None
time_cond_proj_dim: Optional[int] = None,
# 输入卷积核大小,默认为 3
conv_in_kernel: int = 3,
# 输出卷积核大小,默认为 3
conv_out_kernel: int = 3,
# 投影类嵌入输入维度,默认为 None
projection_class_embeddings_input_dim: Optional[int] = None,
# 注意力类型,默认为 "default"
attention_type: str = "default",
# 类嵌入是否拼接,默认为 False
class_embeddings_concat: bool = False,
# 中间块是否仅使用交叉注意力,默认为 None
mid_block_only_cross_attention: Optional[bool] = None,
# 交叉注意力归一化类型,默认为 None
cross_attention_norm: Optional[str] = None,
# 附加嵌入类型的头数量,默认为 64
addition_embed_type_num_heads: int = 64,
# 定义一个私有方法,用于检查配置参数
def _check_config(
self,
# 定义下行块类型的元组,表示模型的结构
down_block_types: Tuple[str],
# 定义上行块类型的元组,表示模型的结构
up_block_types: Tuple[str],
# 定义仅使用交叉注意力的标志,可以是布尔值或布尔值的元组
only_cross_attention: Union[bool, Tuple[bool]],
# 定义每个块的输出通道数的元组,表示层的宽度
block_out_channels: Tuple[int],
# 定义每个块的层数,可以是整数或整数的元组
layers_per_block: Union[int, Tuple[int]],
# 定义交叉注意力维度,可以是整数或整数的元组
cross_attention_dim: Union[int, Tuple[int]],
# 定义每个块的变换器层数,可以是整数、整数的元组或元组的元组
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
# 定义是否反转变换器层的布尔值
reverse_transformer_layers_per_block: bool,
# 定义注意力头的维度,表示注意力的分辨率
attention_head_dim: int,
# 定义注意力头的数量,可以是可选的整数或整数的元组
num_attention_heads: Optional[Union[int, Tuple[int]],
):
# 检查 down_block_types 和 up_block_types 的长度是否相同
if len(down_block_types) != len(up_block_types):
# 如果不同,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
# 检查 block_out_channels 和 down_block_types 的长度是否相同
if len(block_out_channels) != len(down_block_types):
# 如果不同,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# 检查 only_cross_attention 是否为布尔值且长度与 down_block_types 相同
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
# 如果不满足条件,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
# 检查 num_attention_heads 是否为整数且长度与 down_block_types 相同
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
# 如果不满足条件,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
# 检查 attention_head_dim 是否为整数且长度与 down_block_types 相同
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
# 如果不满足条件,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
# 检查 cross_attention_dim 是否为列表且长度与 down_block_types 相同
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
# 如果不满足条件,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
# 检查 layers_per_block 是否为整数且长度与 down_block_types 相同
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
# 如果不满足条件,抛出值错误并提供详细信息
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# 检查 transformer_layers_per_block 是否为列表且 reverse_transformer_layers_per_block 为 None
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
# 遍历 transformer_layers_per_block 中的每个层
for layer_number_per_block in transformer_layers_per_block:
# 检查每个层是否为列表
if isinstance(layer_number_per_block, list):
# 如果是,则抛出值错误,提示需要提供 reverse_transformer_layers_per_block
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# 定义设置时间投影的私有方法
def _set_time_proj(
self,
# 时间嵌入类型
time_embedding_type: str,
# 块输出通道数
block_out_channels: int,
# 是否翻转正弦和余弦
flip_sin_to_cos: bool,
# 频率偏移
freq_shift: float,
# 时间嵌入维度
time_embedding_dim: int,
# 返回时间嵌入维度和时间步输入维度的元组
) -> Tuple[int, int]:
# 判断时间嵌入类型是否为傅里叶
if time_embedding_type == "fourier":
# 计算时间嵌入维度,默认为 block_out_channels[0] * 2
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
# 确保时间嵌入维度为偶数
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
# 初始化高斯傅里叶投影,设定相关参数
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
# 设置时间步输入维度为时间嵌入维度
timestep_input_dim = time_embed_dim
# 判断时间嵌入类型是否为位置编码
elif time_embedding_type == "positional":
# 计算时间嵌入维度,默认为 block_out_channels[0] * 4
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
# 初始化时间步对象,设定相关参数
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
# 设置时间步输入维度为 block_out_channels[0]
timestep_input_dim = block_out_channels[0]
# 如果时间嵌入类型不合法,抛出错误
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
# 返回时间嵌入维度和时间步输入维度
return time_embed_dim, timestep_input_dim
# 定义设置编码器隐藏投影的方法
def _set_encoder_hid_proj(
self,
encoder_hid_dim_type: Optional[str],
cross_attention_dim: Union[int, Tuple[int]],
encoder_hid_dim: Optional[int],
):
# 如果编码器隐藏维度类型为空且隐藏维度已定义
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
# 默认将编码器隐藏维度类型设为'text_proj'
encoder_hid_dim_type = "text_proj"
# 注册编码器隐藏维度类型到配置中
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
# 记录信息日志
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
# 如果编码器隐藏维度为空且隐藏维度类型已定义,抛出错误
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
# 判断编码器隐藏维度类型是否为'text_proj'
if encoder_hid_dim_type == "text_proj":
# 初始化线性投影层,输入维度为encoder_hid_dim,输出维度为cross_attention_dim
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
# 判断编码器隐藏维度类型是否为'text_image_proj'
elif encoder_hid_dim_type == "text_image_proj":
# 初始化文本-图像投影对象,设定相关参数
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
# 判断编码器隐藏维度类型是否为'image_proj'
elif encoder_hid_dim_type == "image_proj":
# 初始化图像投影对象,设定相关参数
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
# 如果编码器隐藏维度类型不合法,抛出错误
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
# 如果都不符合,将编码器隐藏投影设为None
else:
self.encoder_hid_proj = None
# 设置类嵌入的私有方法
def _set_class_embedding(
self,
class_embed_type: Optional[str], # 嵌入类型,可能为 None 或特定字符串
act_fn: str, # 激活函数的名称
num_class_embeds: Optional[int], # 类嵌入数量,可能为 None
projection_class_embeddings_input_dim: Optional[int], # 投影类嵌入输入维度,可能为 None
time_embed_dim: int, # 时间嵌入的维度
timestep_input_dim: int, # 时间步输入的维度
):
# 如果嵌入类型为 None 且类嵌入数量不为 None
if class_embed_type is None and num_class_embeds is not None:
# 创建嵌入层,大小为类嵌入数量和时间嵌入维度
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
# 如果嵌入类型为 "timestep"
elif class_embed_type == "timestep":
# 创建时间步嵌入对象
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
# 如果嵌入类型为 "identity"
elif class_embed_type == "identity":
# 创建恒等层,输入和输出维度相同
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
# 如果嵌入类型为 "projection"
elif class_embed_type == "projection":
# 如果投影类嵌入输入维度为 None,抛出错误
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# 创建投影时间步嵌入对象
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# 如果嵌入类型为 "simple_projection"
elif class_embed_type == "simple_projection":
# 如果投影类嵌入输入维度为 None,抛出错误
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
# 创建线性层作为简单投影
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
# 如果没有匹配的嵌入类型
else:
# 将类嵌入设置为 None
self.class_embedding = None
# 设置附加嵌入的私有方法
def _set_add_embedding(
self,
addition_embed_type: str, # 附加嵌入类型
addition_embed_type_num_heads: int, # 附加嵌入类型的头数
addition_time_embed_dim: Optional[int], # 附加时间嵌入维度,可能为 None
flip_sin_to_cos: bool, # 是否翻转正弦到余弦
freq_shift: float, # 频率偏移量
cross_attention_dim: Optional[int], # 交叉注意力维度,可能为 None
encoder_hid_dim: Optional[int], # 编码器隐藏维度,可能为 None
projection_class_embeddings_input_dim: Optional[int], # 投影类嵌入输入维度,可能为 None
time_embed_dim: int, # 时间嵌入维度
):
# 检查附加嵌入类型是否为 "text"
if addition_embed_type == "text":
# 如果编码器隐藏维度不为 None,则使用该维度
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
# 否则使用交叉注意力维度
else:
text_time_embedding_from_dim = cross_attention_dim
# 创建文本时间嵌入对象
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
# 检查附加嵌入类型是否为 "text_image"
elif addition_embed_type == "text_image":
# text_embed_dim 和 image_embed_dim 不必是 `cross_attention_dim`,为了避免 __init__ 过于繁杂
# 在这里设置为 `cross_attention_dim`,因为这是当前唯一使用情况的所需维度 (Kandinsky 2.1)
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
# 检查附加嵌入类型是否为 "text_time"
elif addition_embed_type == "text_time":
# 创建时间投影对象
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
# 创建时间嵌入对象
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
# 检查附加嵌入类型是否为 "image"
elif addition_embed_type == "image":
# Kandinsky 2.2
# 创建图像时间嵌入对象
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
# 检查附加嵌入类型是否为 "image_hint"
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
# 创建图像提示时间嵌入对象
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
# 检查附加嵌入类型是否为 None 以外的值
elif addition_embed_type is not None:
# 抛出值错误,提示无效的附加嵌入类型
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
# 定义一个属性方法,用于设置位置网络
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
# 检查注意力类型是否为 "gated" 或 "gated-text-image"
if attention_type in ["gated", "gated-text-image"]:
positive_len = 768 # 默认的正向长度
# 如果交叉注意力维度是整数,则使用该值
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
# 如果交叉注意力维度是列表或元组,则使用第一个值
elif isinstance(cross_attention_dim, (list, tuple)):
positive_len = cross_attention_dim[0]
# 根据注意力类型确定特征类型
feature_type = "text-only" if attention_type == "gated" else "text-image"
# 创建 GLIGEN 文本边界框投影对象
self.position_net = GLIGENTextBoundingboxProjection(
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
)
# 定义一个属性
@property
# 定义一个方法,返回一个字典,包含模型中所有的注意力处理器
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# 初始化一个空字典,用于存储注意力处理器
processors = {}
# 定义一个递归函数,用于添加处理器到字典
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 检查模块是否有获取处理器的方法
if hasattr(module, "get_processor"):
# 将处理器添加到字典中,键为名称,值为处理器
processors[f"{name}.processor"] = module.get_processor()
# 遍历模块的所有子模块
for sub_name, child in module.named_children():
# 递归调用,处理子模块
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
# 返回更新后的处理器字典
return processors
# 遍历当前模块的所有子模块
for name, module in self.named_children():
# 调用递归函数,添加处理器
fn_recursive_add_processors(name, module, processors)
# 返回包含所有处理器的字典
return processors
# 定义一个方法,设置用于计算注意力的处理器
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
# 获取当前处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的是字典,且字典长度与注意力层数量不匹配,抛出错误
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
# 定义一个递归函数,用于设置处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 检查模块是否有设置处理器的方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,直接设置
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中弹出对应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的所有子模块
for sub_name, child in module.named_children():
# 递归调用,设置子模块的处理器
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前模块的所有子模块
for name, module in self.named_children():
# 调用递归函数,设置处理器
fn_recursive_attn_processor(name, module, processor)
# 定义设置默认注意力处理器的方法
def set_default_attn_processor(self):
"""
禁用自定义注意力处理器并设置默认的注意力实现。
"""
# 检查所有注意力处理器是否属于添加的键值注意力处理器类
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 创建添加键值注意力处理器的实例
processor = AttnAddedKVProcessor()
# 检查所有注意力处理器是否属于交叉注意力处理器类
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
# 创建标准注意力处理器的实例
processor = AttnProcessor()
else:
# 如果注意力处理器类型不匹配,则抛出错误
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
# 设置选定的注意力处理器
self.set_attn_processor(processor)
# 定义设置梯度检查点的方法
def _set_gradient_checkpointing(self, module, value=False):
# 如果模块具有梯度检查点属性,则设置其值
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# 定义启用 FreeU 机制的方法
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""启用 FreeU 机制,详细信息请见 https://arxiv.org/abs/2309.11497。
在缩放因子后面的后缀表示它们被应用的阶段块。
请参考 [官方仓库](https://github.com/ChenyangSi/FreeU) 以获取已知在不同管道(如 Stable Diffusion v1、v2 和 Stable Diffusion XL)中效果良好的值组合。
参数:
s1 (`float`):
阶段 1 的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
s2 (`float`):
阶段 2 的缩放因子,用于减弱跳跃特征的贡献,以减轻增强去噪过程中的“过平滑效应”。
b1 (`float`): 阶段 1 的缩放因子,用于增强骨干特征的贡献。
b2 (`float`): 阶段 2 的缩放因子,用于增强骨干特征的贡献。
"""
# 遍历上采样块并设置相应的缩放因子
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1) # 设置阶段 1 的缩放因子
setattr(upsample_block, "s2", s2) # 设置阶段 2 的缩放因子
setattr(upsample_block, "b1", b1) # 设置阶段 1 的骨干缩放因子
setattr(upsample_block, "b2", b2) # 设置阶段 2 的骨干缩放因子
# 定义禁用 FreeU 机制的方法
def disable_freeu(self):
"""禁用 FreeU 机制。"""
freeu_keys = {"s1", "s2", "b1", "b2"} # 定义 FreeU 相关的键
# 遍历上采样块
for i, upsample_block in enumerate(self.up_blocks):
# 遍历每个 FreeU 键
for k in freeu_keys:
# 如果上采样块具有该键的属性或其值不为 None,则将其值设置为 None
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# 定义一个方法,用于启用融合的 QKV 投影
def fuse_qkv_projections(self):
"""
启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)都被融合。
对于交叉注意力模块,键和值的投影矩阵被融合。
<Tip warning={true}>
此 API 是 🧪 实验性的。
</Tip>
"""
# 初始化原始注意力处理器为 None
self.original_attn_processors = None
# 遍历注意力处理器,检查是否包含“Added”字样
for _, attn_processor in self.attn_processors.items():
# 如果发现添加的 KV 投影,抛出错误
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
# 保存当前的注意力处理器
self.original_attn_processors = self.attn_processors
# 遍历模块,查找类型为 Attention 的模块
for module in self.modules():
if isinstance(module, Attention):
# 启用投影融合
module.fuse_projections(fuse=True)
# 设置注意力处理器为融合的处理器
self.set_attn_processor(FusedAttnProcessor2_0())
# 定义一个方法,用于禁用已启用的融合 QKV 投影
def unfuse_qkv_projections(self):
"""禁用已启用的融合 QKV 投影。
<Tip warning={true}>
此 API 是 🧪 实验性的。
</Tip>
"""
# 如果原始注意力处理器不为 None,则恢复到原始处理器
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# 定义一个方法,用于获取时间嵌入
def get_time_embed(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
# 将时间步长赋值给 timesteps
timesteps = timestep
# 如果 timesteps 不是张量
if not torch.is_tensor(timesteps):
# TODO: 这需要在 CPU 和 GPU 之间同步。因此,如果可以的话,尽量将 timesteps 作为张量传递
# 这将是使用 `match` 语句的好例子(Python 3.10+)
is_mps = sample.device.type == "mps" # 检查设备类型是否为 MPS
# 根据时间步长类型设置数据类型
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64 # 浮点数类型
else:
dtype = torch.int32 if is_mps else torch.int64 # 整数类型
# 将 timesteps 转换为张量
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
# 如果 timesteps 是标量(零维张量),则扩展维度
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device) # 增加一个维度并转移到样本设备
# 将 timesteps 广播到与样本批次维度兼容的方式
timesteps = timesteps.expand(sample.shape[0]) # 扩展到批次大小
# 通过时间投影获得时间嵌入
t_emb = self.time_proj(timesteps)
# `Timesteps` 不包含任何权重,总是返回 f32 张量
# 但时间嵌入可能实际在 fp16 中运行,因此需要进行类型转换。
# 可能有更好的方法来封装这一点。
t_emb = t_emb.to(dtype=sample.dtype) # 转换 t_emb 的数据类型
# 返回时间嵌入
return t_emb
# 获取类嵌入的方法,接受样本张量和可选的类标签
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
# 初始化类嵌入为 None
class_emb = None
# 检查类嵌入是否存在
if self.class_embedding is not None:
# 如果类标签为 None,抛出错误
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
# 检查类嵌入类型是否为时间步
if self.config.class_embed_type == "timestep":
# 将类标签通过时间投影处理
class_labels = self.time_proj(class_labels)
# `Timesteps` 不包含权重,总是返回 f32 张量
# 可能有更好的方式来封装这一点
class_labels = class_labels.to(dtype=sample.dtype)
# 获取类嵌入并转换为与样本相同的数据类型
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
# 返回类嵌入
return class_emb
# 获取增强嵌入的方法,接受嵌入张量、编码器隐藏状态和额外条件参数
def get_aug_embed(
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
# 处理编码器隐藏状态的方法,接受编码器隐藏状态和额外条件参数
def process_encoder_hidden_states(
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
# 定义返回类型为 torch.Tensor
) -> torch.Tensor:
# 检查是否存在隐藏层投影,并且配置为 "text_proj"
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
# 使用文本投影对编码隐藏状态进行转换
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
# 检查是否存在隐藏层投影,并且配置为 "text_image_proj"
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
# 检查条件中是否包含 "image_embeds"
if "image_embeds" not in added_cond_kwargs:
# 抛出错误提示缺少必要参数
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
# 获取传入的图像嵌入
image_embeds = added_cond_kwargs.get("image_embeds")
# 对编码隐藏状态和图像嵌入进行投影转换
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
# 检查是否存在隐藏层投影,并且配置为 "image_proj"
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# 检查条件中是否包含 "image_embeds"
if "image_embeds" not in added_cond_kwargs:
# 抛出错误提示缺少必要参数
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
# 获取传入的图像嵌入
image_embeds = added_cond_kwargs.get("image_embeds")
# 使用图像嵌入对编码隐藏状态进行投影转换
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
# 检查是否存在隐藏层投影,并且配置为 "ip_image_proj"
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
# 检查条件中是否包含 "image_embeds"
if "image_embeds" not in added_cond_kwargs:
# 抛出错误提示缺少必要参数
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
# 如果存在文本编码器的隐藏层投影,则对编码隐藏状态进行投影转换
if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
# 获取传入的图像嵌入
image_embeds = added_cond_kwargs.get("image_embeds")
# 对图像嵌入进行投影转换
image_embeds = self.encoder_hid_proj(image_embeds)
# 将编码隐藏状态和图像嵌入打包成元组
encoder_hidden_states = (encoder_hidden_states, image_embeds)
# 返回最终的编码隐藏状态
return encoder_hidden_states
# 定义前向传播函数
def forward(
# 输入的样本数据,类型为 PyTorch 张量
sample: torch.Tensor,
# 当前时间步,类型可以是张量、浮点数或整数
timestep: Union[torch.Tensor, float, int],
# 编码器的隐藏状态,类型为 PyTorch 张量
encoder_hidden_states: torch.Tensor,
# 可选的类别标签,类型为 PyTorch 张量
class_labels: Optional[torch.Tensor] = None,
# 可选的时间步条件,类型为 PyTorch 张量
timestep_cond: Optional[torch.Tensor] = None,
# 可选的注意力掩码,类型为 PyTorch 张量
attention_mask: Optional[torch.Tensor] = None,
# 可选的交叉注意力参数,类型为字典,包含额外的关键字参数
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
# 可选的附加条件参数,类型为字典,键为字符串,值为 PyTorch 张量
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
# 可选的下层块附加残差,类型为元组,包含 PyTorch 张量
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
# 可选的中间块附加残差,类型为 PyTorch 张量
mid_block_additional_residual: Optional[torch.Tensor] = None,
# 可选的下层内部块附加残差,类型为元组,包含 PyTorch 张量
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
# 可选的编码器注意力掩码,类型为 PyTorch 张量
encoder_attention_mask: Optional[torch.Tensor] = None,
# 返回结果的标志,布尔值,默认值为 True
return_dict: bool = True,
.\diffusers\models\unets\unet_2d_condition_flax.py
# 版权声明,表明该文件的版权所有者及相关信息
#
# 根据 Apache License 2.0 版本的许可协议
# 除非遵守该许可协议,否则不得使用本文件
# 可以在以下地址获取许可证副本
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件在“按现状”基础上分发,
# 不提供任何明示或暗示的保证或条件
# 请参阅许可证了解管理权限和限制的具体条款
from typing import Dict, Optional, Tuple, Union # 从 typing 模块导入类型注释工具
import flax # 导入 flax 库用于构建神经网络
import flax.linen as nn # 从 flax 中导入 linen 模块,方便定义神经网络层
import jax # 导入 jax 库用于高效数值计算
import jax.numpy as jnp # 导入 jax 的 numpy 模块,提供张量操作功能
from flax.core.frozen_dict import FrozenDict # 从 flax 导入 FrozenDict,用于不可变字典
from ...configuration_utils import ConfigMixin, flax_register_to_config # 导入配置相关工具
from ...utils import BaseOutput # 导入基础输出类
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps # 导入时间步嵌入相关类
from ..modeling_flax_utils import FlaxModelMixin # 导入模型混合类
from .unet_2d_blocks_flax import ( # 导入 UNet 的不同构建块
FlaxCrossAttnDownBlock2D, # 导入交叉注意力下采样块
FlaxCrossAttnUpBlock2D, # 导入交叉注意力上采样块
FlaxDownBlock2D, # 导入下采样块
FlaxUNetMidBlock2DCrossAttn, # 导入中间块,带有交叉注意力
FlaxUpBlock2D, # 导入上采样块
)
@flax.struct.dataclass # 使用 flax 的数据类装饰器
class FlaxUNet2DConditionOutput(BaseOutput): # 定义 UNet 条件输出类,继承自基础输出类
"""
[`FlaxUNet2DConditionModel`] 的输出。
参数:
sample (`jnp.ndarray` 的形状为 `(batch_size, num_channels, height, width)`):
基于 `encoder_hidden_states` 输入的隐藏状态输出。模型最后一层的输出。
"""
sample: jnp.ndarray # 定义输出样本,数据类型为 jnp.ndarray
@flax_register_to_config # 使用装饰器将模型注册到配置中
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): # 定义条件 UNet 模型类,继承多个混合类
r"""
一个条件 2D UNet 模型,接收噪声样本、条件状态和时间步,并返回样本形状的输出。
此模型继承自 [`FlaxModelMixin`]。请查看超类文档以了解其通用方法
(例如下载或保存)。
此模型也是 Flax Linen 的 [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
子类。将其作为常规 Flax Linen 模块使用,具体使用和行为请参阅 Flax 文档。
支持以下 JAX 特性:
- [即时编译 (JIT)](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [自动微分](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [向量化](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [并行化](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
# 参数说明部分
Parameters:
# 输入样本的大小,类型为整型,选填参数
sample_size (`int`, *optional*):
The size of the input sample.
# 输入样本的通道数,类型为整型,默认为4
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
# 输出的通道数,类型为整型,默认为4
out_channels (`int`, *optional*, defaults to 4):
The number of channels in the output.
# 使用的下采样块的元组,类型为字符串元组,默认为特定的下采样块
down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`):
The tuple of downsample blocks to use.
# 使用的上采样块的元组,类型为字符串元组,默认为特定的上采样块
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
# UNet中间块的类型,类型为字符串,默认为"UNetMidBlock2DCrossAttn"
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
is skipped.
# 每个块的输出通道的元组,类型为整型元组,默认为特定的输出通道
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
# 每个块的层数,类型为整型,默认为2
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
# 注意力头的维度,可以是整型或整型元组,默认为8
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
# 注意力头的数量,可以是整型或整型元组,选填参数
num_attention_heads (`int` or `Tuple[int]`, *optional*):
The number of attention heads.
# 交叉注意力特征的维度,类型为整型,默认为768
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
# dropout的概率,类型为浮点数,默认为0
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
# 是否在时间嵌入中将正弦转换为余弦,类型为布尔值,默认为True
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
# 应用于时间嵌入的频率偏移,类型为整型,默认为0
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
# 是否启用内存高效的注意力机制,类型为布尔值,默认为False
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
# 是否将头维度拆分为新的轴进行自注意力计算,类型为布尔值,默认为False
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
# 定义样本大小,默认为32
sample_size: int = 32
# 定义输入通道数,默认为4
in_channels: int = 4
# 定义输出通道数,默认为4
out_channels: int = 4
# 定义下采样块的类型元组
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", # 第一个下采样块
"CrossAttnDownBlock2D", # 第二个下采样块
"CrossAttnDownBlock2D", # 第三个下采样块
"DownBlock2D", # 第四个下采样块
)
# 定义上采样块的类型元组
up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
# 定义中间块类型,默认为"UNetMidBlock2DCrossAttn"
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
# 定义是否只使用交叉注意力,默认为False
only_cross_attention: Union[bool, Tuple[bool]] = False
# 定义每个块的输出通道元组
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
# 每个块的层数设为 2
layers_per_block: int = 2
# 注意力头的维度设为 8
attention_head_dim: Union[int, Tuple[int, ...]] = 8
# 可选的注意力头数量,默认为 None
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
# 跨注意力的维度设为 1280
cross_attention_dim: int = 1280
# dropout 比率设为 0.0
dropout: float = 0.0
# 是否使用线性投影,默认为 False
use_linear_projection: bool = False
# 数据类型设为 float32
dtype: jnp.dtype = jnp.float32
# flip_sin_to_cos 设为 True
flip_sin_to_cos: bool = True
# 频移设为 0
freq_shift: int = 0
# 是否使用内存高效的注意力,默认为 False
use_memory_efficient_attention: bool = False
# 是否拆分头维度,默认为 False
split_head_dim: bool = False
# 每个块的变换层数设为 1
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
# 可选的附加嵌入类型,默认为 None
addition_embed_type: Optional[str] = None
# 可选的附加时间嵌入维度,默认为 None
addition_time_embed_dim: Optional[int] = None
# 附加嵌入类型的头数量设为 64
addition_embed_type_num_heads: int = 64
# 可选的投影类嵌入输入维度,默认为 None
projection_class_embeddings_input_dim: Optional[int] = None
# 初始化权重函数,接受随机数生成器作为参数
def init_weights(self, rng: jax.Array) -> FrozenDict:
# 初始化输入张量的形状
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
# 创建全零的输入样本
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
# 创建全一的时间步张量
timesteps = jnp.ones((1,), dtype=jnp.int32)
# 初始化编码器的隐藏状态
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
# 分割随机数生成器,用于参数和 dropout
params_rng, dropout_rng = jax.random.split(rng)
# 创建随机数字典
rngs = {"params": params_rng, "dropout": dropout_rng}
# 初始化附加条件关键字参数
added_cond_kwargs = None
# 判断嵌入类型是否为 "text_time"
if self.addition_embed_type == "text_time":
# 通过反向计算获取期望的文本嵌入维度
is_refiner = (
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
== self.config.projection_class_embeddings_input_dim
)
# 确定微条件的数量
num_micro_conditions = 5 if is_refiner else 6
# 计算文本嵌入维度
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
num_micro_conditions * self.config.addition_time_embed_dim
)
# 计算时间 ID 的通道数和维度
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
# 创建附加条件关键字参数字典
added_cond_kwargs = {
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
}
# 返回初始化后的参数字典
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
# 定义调用函数,接收多个输入参数
def __call__(
self,
sample: jnp.ndarray,
timesteps: Union[jnp.ndarray, float, int],
encoder_hidden_states: jnp.ndarray,
# 可选的附加条件关键字参数
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
# 可选的下块附加残差
down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
# 可选的中块附加残差
mid_block_additional_residual: Optional[jnp.ndarray] = None,
# 是否返回字典,默认为 True
return_dict: bool = True,
# 是否为训练模式,默认为 False
train: bool = False,
.\diffusers\models\unets\unet_3d_blocks.py
# 版权所有 2024 HuggingFace 团队。所有权利保留。
#
# 根据 Apache 许可证 2.0 版本("许可证")授权;
# 除非遵守许可证,否则您不得使用此文件。
# 您可以在以下位置获得许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件是按“原样”基础分发,
# 不提供任何形式的保证或条件,无论是明示或暗示。
# 有关许可证的具体权限和限制,请参阅许可证。
# 导入类型提示中的任何类型
from typing import Any, Dict, Optional, Tuple, Union
# 导入 PyTorch 库
import torch
# 从 PyTorch 导入神经网络模块
from torch import nn
# 导入实用工具函数,包括弃用和日志记录
from ...utils import deprecate, is_torch_version, logging
# 导入 PyTorch 相关的工具函数
from ...utils.torch_utils import apply_freeu
# 导入注意力机制相关的类
from ..attention import Attention
# 导入 ResNet 相关的类
from ..resnet import (
Downsample2D, # 导入 2D 下采样模块
ResnetBlock2D, # 导入 2D ResNet 块
SpatioTemporalResBlock, # 导入时空 ResNet 块
TemporalConvLayer, # 导入时间卷积层
Upsample2D, # 导入 2D 上采样模块
)
# 导入 2D 变换器模型
from ..transformers.transformer_2d import Transformer2DModel
# 导入时间相关的变换器模型
from ..transformers.transformer_temporal import (
TransformerSpatioTemporalModel, # 导入时空变换器模型
TransformerTemporalModel, # 导入时间变换器模型
)
# 导入运动模型的 UNet 相关类
from .unet_motion_model import (
CrossAttnDownBlockMotion, # 导入交叉注意力下块运动类
CrossAttnUpBlockMotion, # 导入交叉注意力上块运动类
DownBlockMotion, # 导入下块运动类
UNetMidBlockCrossAttnMotion, # 导入中间块交叉注意力运动类
UpBlockMotion, # 导入上块运动类
)
# 创建一个日志记录器,使用当前模块的名称
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# 定义 DownBlockMotion 类,继承自 DownBlockMotion
class DownBlockMotion(DownBlockMotion):
# 初始化方法,接受任意参数和关键字参数
def __init__(self, *args, **kwargs):
# 设置弃用消息,提醒用户变更
deprecation_message = "Importing `DownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import DownBlockMotion` instead."
# 调用弃用函数,记录弃用信息
deprecate("DownBlockMotion", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 定义 CrossAttnDownBlockMotion 类,继承自 CrossAttnDownBlockMotion
class CrossAttnDownBlockMotion(CrossAttnDownBlockMotion):
# 初始化方法,接受任意参数和关键字参数
def __init__(self, *args, **kwargs):
# 设置弃用消息,提醒用户变更
deprecation_message = "Importing `CrossAttnDownBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnDownBlockMotion` instead."
# 调用弃用函数,记录弃用信息
deprecate("CrossAttnDownBlockMotion", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 定义 UpBlockMotion 类,继承自 UpBlockMotion
class UpBlockMotion(UpBlockMotion):
# 初始化方法,接受任意参数和关键字参数
def __init__(self, *args, **kwargs):
# 设置弃用消息,提醒用户变更
deprecation_message = "Importing `UpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UpBlockMotion` instead."
# 调用弃用函数,记录弃用信息
deprecate("UpBlockMotion", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 定义 CrossAttnUpBlockMotion 类,继承自 CrossAttnUpBlockMotion
class CrossAttnUpBlockMotion(CrossAttnUpBlockMotion):
# 初始化方法,用于创建类的实例
def __init__(self, *args, **kwargs):
# 定义一个关于导入的弃用警告信息
deprecation_message = "Importing `CrossAttnUpBlockMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import CrossAttnUpBlockMotion` instead."
# 调用弃用警告函数,记录该功能的弃用信息及版本
deprecate("CrossAttnUpBlockMotion", "1.0.0", deprecation_message)
# 调用父类的初始化方法,传递参数以初始化父类部分
super().__init__(*args, **kwargs)
# 定义一个名为 UNetMidBlockCrossAttnMotion 的类,继承自同名父类
class UNetMidBlockCrossAttnMotion(UNetMidBlockCrossAttnMotion):
# 初始化方法,接收可变参数和关键字参数
def __init__(self, *args, **kwargs):
# 定义弃用警告信息,提示用户更新导入路径
deprecation_message = "Importing `UNetMidBlockCrossAttnMotion` from `diffusers.models.unets.unet_3d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_motion_model import UNetMidBlockCrossAttnMotion` instead."
# 触发弃用警告
deprecate("UNetMidBlockCrossAttnMotion", "1.0.0", deprecation_message)
# 调用父类的初始化方法
super().__init__(*args, **kwargs)
# 定义一个函数,返回不同类型的下采样块
def get_down_block(
# 定义参数,类型和含义
down_block_type: str, # 下采样块的类型
num_layers: int, # 层数
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
add_downsample: bool, # 是否添加下采样
resnet_eps: float, # ResNet 的 epsilon 参数
resnet_act_fn: str, # ResNet 的激活函数
num_attention_heads: int, # 注意力头数
resnet_groups: Optional[int] = None, # ResNet 的分组数,可选
cross_attention_dim: Optional[int] = None, # 交叉注意力维度,可选
downsample_padding: Optional[int] = None, # 下采样填充,可选
dual_cross_attention: bool = False, # 是否使用双重交叉注意力
use_linear_projection: bool = True, # 是否使用线性投影
only_cross_attention: bool = False, # 是否仅使用交叉注意力
upcast_attention: bool = False, # 是否提升注意力精度
resnet_time_scale_shift: str = "default", # ResNet 时间尺度偏移
temporal_num_attention_heads: int = 8, # 时间注意力头数
temporal_max_seq_length: int = 32, # 时间序列最大长度
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每个块的变换器层数
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 时间变换器每块层数
dropout: float = 0.0, # dropout 概率
) -> Union[
"DownBlock3D", # 返回的可能类型之一:3D 下采样块
"CrossAttnDownBlock3D", # 返回的可能类型之二:交叉注意力下采样块
"DownBlockSpatioTemporal", # 返回的可能类型之三:时空下采样块
"CrossAttnDownBlockSpatioTemporal", # 返回的可能类型之四:时空交叉注意力下采样块
]:
# 检查下采样块类型是否为 DownBlock3D
if down_block_type == "DownBlock3D":
# 创建并返回 DownBlock3D 实例
return DownBlock3D(
num_layers=num_layers, # 传入层数
in_channels=in_channels, # 传入输入通道数
out_channels=out_channels, # 传入输出通道数
temb_channels=temb_channels, # 传入时间嵌入通道数
add_downsample=add_downsample, # 传入是否添加下采样
resnet_eps=resnet_eps, # 传入 ResNet 的 epsilon 参数
resnet_act_fn=resnet_act_fn, # 传入激活函数
resnet_groups=resnet_groups, # 传入分组数
downsample_padding=downsample_padding, # 传入下采样填充
resnet_time_scale_shift=resnet_time_scale_shift, # 传入时间尺度偏移
dropout=dropout, # 传入 dropout 概率
)
# 检查下采样块类型是否为 CrossAttnDownBlock3D
elif down_block_type == "CrossAttnDownBlock3D":
# 如果交叉注意力维度未指定,抛出错误
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
# 创建并返回 CrossAttnDownBlock3D 实例
return CrossAttnDownBlock3D(
num_layers=num_layers, # 传入层数
in_channels=in_channels, # 传入输入通道数
out_channels=out_channels, # 传入输出通道数
temb_channels=temb_channels, # 传入时间嵌入通道数
add_downsample=add_downsample, # 传入是否添加下采样
resnet_eps=resnet_eps, # 传入 ResNet 的 epsilon 参数
resnet_act_fn=resnet_act_fn, # 传入激活函数
resnet_groups=resnet_groups, # 传入分组数
downsample_padding=downsample_padding, # 传入下采样填充
cross_attention_dim=cross_attention_dim, # 传入交叉注意力维度
num_attention_heads=num_attention_heads, # 传入注意力头数
dual_cross_attention=dual_cross_attention, # 传入是否使用双重交叉注意力
use_linear_projection=use_linear_projection, # 传入是否使用线性投影
only_cross_attention=only_cross_attention, # 传入是否仅使用交叉注意力
upcast_attention=upcast_attention, # 传入是否提升注意力精度
resnet_time_scale_shift=resnet_time_scale_shift, # 传入时间尺度偏移
dropout=dropout, # 传入 dropout 概率
)
# 检查下一个块的类型是否为时空下采样块
elif down_block_type == "DownBlockSpatioTemporal":
# 为 SDV 进行了添加
# 返回一个时空下采样块的实例
return DownBlockSpatioTemporal(
# 设置层数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 是否添加下采样
add_downsample=add_downsample,
)
# 检查下一个块的类型是否为交叉注意力时空下采样块
elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
# 为 SDV 进行了添加
# 如果没有指定交叉注意力维度,抛出错误
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
# 返回一个交叉注意力时空下采样块的实例
return CrossAttnDownBlockSpatioTemporal(
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 设置层数
num_layers=num_layers,
# 每个块的变换层数
transformer_layers_per_block=transformer_layers_per_block,
# 是否添加下采样
add_downsample=add_downsample,
# 设置交叉注意力维度
cross_attention_dim=cross_attention_dim,
# 注意力头数
num_attention_heads=num_attention_heads,
)
# 如果块类型不匹配,则抛出错误
raise ValueError(f"{down_block_type} does not exist.")
# 定义函数 get_up_block,返回不同类型的上采样模块
def get_up_block(
# 上采样块类型
up_block_type: str,
# 层数
num_layers: int,
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 上一层的输出通道数
prev_output_channel: int,
# 时间嵌入通道数
temb_channels: int,
# 是否添加上采样
add_upsample: bool,
# ResNet 的 epsilon 值
resnet_eps: float,
# ResNet 的激活函数类型
resnet_act_fn: str,
# 注意力头的数量
num_attention_heads: int,
# 分辨率索引(可选)
resolution_idx: Optional[int] = None,
# ResNet 组数(可选)
resnet_groups: Optional[int] = None,
# 跨注意力维度(可选)
cross_attention_dim: Optional[int] = None,
# 是否使用双重跨注意力
dual_cross_attention: bool = False,
# 是否使用线性投影
use_linear_projection: bool = True,
# 是否仅使用跨注意力
only_cross_attention: bool = False,
# 是否提升注意力计算
upcast_attention: bool = False,
# ResNet 时间尺度移位的设置
resnet_time_scale_shift: str = "default",
# 时间上的注意力头数量
temporal_num_attention_heads: int = 8,
# 时间上的跨注意力维度(可选)
temporal_cross_attention_dim: Optional[int] = None,
# 时间序列最大长度
temporal_max_seq_length: int = 32,
# 每个块的变换器层数
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# 每个块的时间变换器层数
temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# dropout 概率
dropout: float = 0.0,
) -> Union[
# 返回类型为不同的上采样块
"UpBlock3D",
"CrossAttnUpBlock3D",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
]:
# 判断上采样块类型是否为 UpBlock3D
if up_block_type == "UpBlock3D":
# 创建并返回 UpBlock3D 实例
return UpBlock3D(
# 传入参数设置
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
# 判断上采样块类型是否为 CrossAttnUpBlock3D
elif up_block_type == "CrossAttnUpBlock3D":
# 检查是否提供跨注意力维度
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
# 创建并返回 CrossAttnUpBlock3D 实例
return CrossAttnUpBlock3D(
# 传入参数设置
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
dropout=dropout,
)
# 检查上升块类型是否为 "UpBlockSpatioTemporal"
elif up_block_type == "UpBlockSpatioTemporal":
# 为 SDV 添加的内容
# 返回 UpBlockSpatioTemporal 实例,使用指定的参数
return UpBlockSpatioTemporal(
# 层数参数
num_layers=num_layers,
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 前一个输出通道数
prev_output_channel=prev_output_channel,
# 时间嵌入通道数
temb_channels=temb_channels,
# 分辨率索引
resolution_idx=resolution_idx,
# 是否添加上采样
add_upsample=add_upsample,
)
# 检查上升块类型是否为 "CrossAttnUpBlockSpatioTemporal"
elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
# 为 SDV 添加的内容
# 如果没有指定交叉注意力维度,抛出错误
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
# 返回 CrossAttnUpBlockSpatioTemporal 实例,使用指定的参数
return CrossAttnUpBlockSpatioTemporal(
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 前一个输出通道数
prev_output_channel=prev_output_channel,
# 时间嵌入通道数
temb_channels=temb_channels,
# 层数参数
num_layers=num_layers,
# 每个块的变换层数
transformer_layers_per_block=transformer_layers_per_block,
# 是否添加上采样
add_upsample=add_upsample,
# 交叉注意力维度
cross_attention_dim=cross_attention_dim,
# 注意力头数
num_attention_heads=num_attention_heads,
# 分辨率索引
resolution_idx=resolution_idx,
)
# 如果上升块类型不符合任何已知类型,抛出错误
raise ValueError(f"{up_block_type} does not exist.")
# 定义一个名为 UNetMidBlock3DCrossAttn 的类,继承自 nn.Module
class UNetMidBlock3DCrossAttn(nn.Module):
# 初始化方法,设置类的参数
def __init__(
self,
in_channels: int, # 输入通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
resnet_eps: float = 1e-6, # ResNet 中的小常数,避免除零
resnet_time_scale_shift: str = "default", # ResNet 时间缩放偏移
resnet_act_fn: str = "swish", # ResNet 激活函数类型
resnet_groups: int = 32, # ResNet 分组数
resnet_pre_norm: bool = True, # 是否在 ResNet 前进行归一化
num_attention_heads: int = 1, # 注意力头数
output_scale_factor: float = 1.0, # 输出缩放因子
cross_attention_dim: int = 1280, # 交叉注意力维度
dual_cross_attention: bool = False, # 是否使用双交叉注意力
use_linear_projection: bool = True, # 是否使用线性投影
upcast_attention: bool = False, # 是否使用上采样注意力
):
# 省略具体初始化代码,通常会在这里初始化各个层和参数
# 定义前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
num_frames: int = 1, # 帧数
cross_attention_kwargs: Optional[Dict[str, Any]] = None, # 交叉注意力的额外参数
) -> torch.Tensor: # 返回一个张量
# 通过第一个 ResNet 层处理隐藏状态
hidden_states = self.resnets[0](hidden_states, temb)
# 通过第一个时间卷积层处理隐藏状态
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
# 遍历所有的注意力层、时间注意力层、ResNet 层和时间卷积层
for attn, temp_attn, resnet, temp_conv in zip(
self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
):
# 通过当前注意力层处理隐藏状态
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 只取返回的第一个元素
# 通过当前时间注意力层处理隐藏状态
hidden_states = temp_attn(
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 只取返回的第一个元素
# 通过当前 ResNet 层处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过当前时间卷积层处理隐藏状态
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 返回最终的隐藏状态
return hidden_states
# 定义一个名为 CrossAttnDownBlock3D 的类,继承自 nn.Module
class CrossAttnDownBlock3D(nn.Module):
# 初始化方法,设置类的参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
resnet_eps: float = 1e-6, # ResNet 中的小常数,避免除零
resnet_time_scale_shift: str = "default", # ResNet 时间缩放偏移
resnet_act_fn: str = "swish", # ResNet 激活函数类型
resnet_groups: int = 32, # ResNet 分组数
resnet_pre_norm: bool = True, # 是否在 ResNet 前进行归一化
num_attention_heads: int = 1, # 注意力头数
cross_attention_dim: int = 1280, # 交叉注意力维度
output_scale_factor: float = 1.0, # 输出缩放因子
downsample_padding: int = 1, # 下采样的填充大小
add_downsample: bool = True, # 是否添加下采样层
dual_cross_attention: bool = False, # 是否使用双交叉注意力
use_linear_projection: bool = False, # 是否使用线性投影
only_cross_attention: bool = False, # 是否只使用交叉注意力
upcast_attention: bool = False, # 是否使用上采样注意力
):
# 调用父类的初始化方法
super().__init__()
# 初始化残差块列表
resnets = []
# 初始化注意力层列表
attentions = []
# 初始化临时注意力层列表
temp_attentions = []
# 初始化临时卷积层列表
temp_convs = []
# 设置是否使用交叉注意力
self.has_cross_attention = True
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 根据层数创建各层模块
for i in range(num_layers):
# 设置输入通道数,第一层使用传入的输入通道数,后续层使用输出通道数
in_channels = in_channels if i == 0 else out_channels
# 创建残差块并添加到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # 残差块的 epsilon 值
groups=resnet_groups, # 组卷积的组数
dropout=dropout, # Dropout 比例
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入的归一化方式
non_linearity=resnet_act_fn, # 非线性激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否进行预归一化
)
)
# 创建时间卷积层并添加到列表中
temp_convs.append(
TemporalConvLayer(
out_channels, # 输入通道数
out_channels, # 输出通道数
dropout=0.1, # Dropout 比例
norm_num_groups=resnet_groups, # 归一化的组数
)
)
# 创建二维变换器模型并添加到列表中
attentions.append(
Transformer2DModel(
out_channels // num_attention_heads, # 每个注意力头的通道数
num_attention_heads, # 注意力头的数量
in_channels=out_channels, # 输入通道数
num_layers=1, # 变换器层数
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
norm_num_groups=resnet_groups, # 归一化的组数
use_linear_projection=use_linear_projection, # 是否使用线性映射
only_cross_attention=only_cross_attention, # 是否只使用交叉注意力
upcast_attention=upcast_attention, # 是否上溢注意力
)
)
# 创建时间变换器模型并添加到列表中
temp_attentions.append(
TransformerTemporalModel(
out_channels // num_attention_heads, # 每个注意力头的通道数
num_attention_heads, # 注意力头的数量
in_channels=out_channels, # 输入通道数
num_layers=1, # 变换器层数
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
norm_num_groups=resnet_groups, # 归一化的组数
)
)
# 将残差块列表转换为模块列表
self.resnets = nn.ModuleList(resnets)
# 将临时卷积层列表转换为模块列表
self.temp_convs = nn.ModuleList(temp_convs)
# 将注意力层列表转换为模块列表
self.attentions = nn.ModuleList(attentions)
# 将临时注意力层列表转换为模块列表
self.temp_attentions = nn.ModuleList(temp_attentions)
# 如果需要添加下采样层
if add_downsample:
# 创建下采样模块列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 输出通道数
use_conv=True, # 是否使用卷积进行下采样
out_channels=out_channels, # 输出通道数
padding=downsample_padding, # 下采样时的填充
name="op", # 模块名称
)
]
)
else:
# 如果不需要下采样,设置为 None
self.downsamplers = None
# 初始化梯度检查点设置为 False
self.gradient_checkpointing = False
# 定义前向传播方法,接受多个输入参数并返回张量或元组
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
num_frames: int = 1,
cross_attention_kwargs: Dict[str, Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# TODO(Patrick, William) - 注意力掩码未使用
output_states = () # 初始化输出状态为元组
# 遍历所有的残差网络、临时卷积、注意力和临时注意力层
for resnet, temp_conv, attn, temp_attn in zip(
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
):
# 使用残差网络处理隐状态和时间嵌入
hidden_states = resnet(hidden_states, temb)
# 使用临时卷积处理隐状态,考虑帧数
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 使用注意力层处理隐状态,返回字典设为 False
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 取返回的第一个元素
# 使用临时注意力层处理隐状态,返回字典设为 False
hidden_states = temp_attn(
hidden_states,
num_frames=num_frames,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0] # 取返回的第一个元素
# 将当前隐状态添加到输出状态中
output_states += (hidden_states,)
# 如果存在下采样器,则逐个应用
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states) # 应用下采样器
# 将下采样后的隐状态添加到输出状态中
output_states += (hidden_states,)
# 返回最终的隐状态和所有输出状态
return hidden_states, output_states
# 定义一个 3D 下采样模块,继承自 nn.Module
class DownBlock3D(nn.Module):
# 初始化方法,设置各个参数及模块
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
resnet_eps: float = 1e-6, # ResNet 中的 epsilon 值
resnet_time_scale_shift: str = "default", # 时间尺度偏移方式
resnet_act_fn: str = "swish", # ResNet 激活函数类型
resnet_groups: int = 32, # ResNet 中的组数
resnet_pre_norm: bool = True, # 是否使用预归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_downsample: bool = True, # 是否添加下采样层
downsample_padding: int = 1, # 下采样的填充大小
):
# 调用父类构造函数
super().__init__()
# 初始化 ResNet 模块列表
resnets = []
# 初始化时间卷积层列表
temp_convs = []
# 遍历层数,构建 ResNet 模块和时间卷积层
for i in range(num_layers):
# 第一层使用输入通道数,后续层使用输出通道数
in_channels = in_channels if i == 0 else out_channels
# 添加 ResNet 模块到列表中
resnets.append(
ResnetBlock2D(
in_channels=in_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # ResNet 中的 epsilon 值
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化方式
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 是否预归一化
)
)
# 添加时间卷积层到列表中
temp_convs.append(
TemporalConvLayer(
out_channels, # 输入通道数
out_channels, # 输出通道数
dropout=0.1, # dropout 概率
norm_num_groups=resnet_groups, # 组数
)
)
# 将 ResNet 模块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 将时间卷积层列表转换为 nn.ModuleList
self.temp_convs = nn.ModuleList(temp_convs)
# 如果需要添加下采样层
if add_downsample:
# 初始化下采样模块列表
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, # 输出通道数
use_conv=True, # 使用卷积
out_channels=out_channels, # 输出通道数
padding=downsample_padding, # 填充大小
name="op", # 模块名称
)
]
)
else:
# 不添加下采样层
self.downsamplers = None
# 初始化梯度检查点开关为 False
self.gradient_checkpointing = False
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
num_frames: int = 1, # 帧数
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
# 初始化输出状态元组
output_states = ()
# 遍历 ResNet 模块和时间卷积层
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
# 通过 ResNet 模块处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过时间卷积层处理隐藏状态
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 将当前的隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 如果存在下采样层
if self.downsamplers is not None:
# 遍历下采样层
for downsampler in self.downsamplers:
# 通过下采样层处理隐藏状态
hidden_states = downsampler(hidden_states)
# 将当前的隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 返回最终的隐藏状态和输出状态元组
return hidden_states, output_states
# 定义一个 3D 交叉注意力上采样模块,继承自 nn.Module
class CrossAttnUpBlock3D(nn.Module):
# 初始化方法,用于设置类的属性
def __init__(
# 输入通道数
self,
in_channels: int,
# 输出通道数
out_channels: int,
# 前一个输出通道数
prev_output_channel: int,
# 时间嵌入通道数
temb_channels: int,
# dropout比率,默认为0.0
dropout: float = 0.0,
# 网络层数,默认为1
num_layers: int = 1,
# ResNet的epsilon值,默认为1e-6
resnet_eps: float = 1e-6,
# ResNet的时间尺度偏移,默认为"default"
resnet_time_scale_shift: str = "default",
# ResNet激活函数,默认为"swish"
resnet_act_fn: str = "swish",
# ResNet的组数,默认为32
resnet_groups: int = 32,
# ResNet是否使用预归一化,默认为True
resnet_pre_norm: bool = True,
# 注意力头的数量,默认为1
num_attention_heads: int = 1,
# 交叉注意力的维度,默认为1280
cross_attention_dim: int = 1280,
# 输出缩放因子,默认为1.0
output_scale_factor: float = 1.0,
# 是否添加上采样,默认为True
add_upsample: bool = True,
# 双重交叉注意力的开关,默认为False
dual_cross_attention: bool = False,
# 是否使用线性投影,默认为False
use_linear_projection: bool = False,
# 仅使用交叉注意力的开关,默认为False
only_cross_attention: bool = False,
# 是否上采样注意力,默认为False
upcast_attention: bool = False,
# 分辨率索引,默认为None
resolution_idx: Optional[int] = None,
):
# 调用父类的构造函数进行初始化
super().__init__()
# 初始化用于存储 ResNet 块的列表
resnets = []
# 初始化用于存储时间卷积层的列表
temp_convs = []
# 初始化用于存储注意力模型的列表
attentions = []
# 初始化用于存储时间注意力模型的列表
temp_attentions = []
# 设置是否使用交叉注意力
self.has_cross_attention = True
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 遍历每一层以构建网络结构
for i in range(num_layers):
# 确定残差跳过通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 确定 ResNet 的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 添加一个 ResNet 块到列表中
resnets.append(
ResnetBlock2D(
# 设置输入通道数,包括残差跳过通道
in_channels=resnet_in_channels + res_skip_channels,
# 设置输出通道数
out_channels=out_channels,
# 设置时间嵌入通道数
temb_channels=temb_channels,
# 设置 epsilon 值
eps=resnet_eps,
# 设置组数
groups=resnet_groups,
# 设置 dropout 概率
dropout=dropout,
# 设置时间嵌入归一化方式
time_embedding_norm=resnet_time_scale_shift,
# 设置非线性激活函数
non_linearity=resnet_act_fn,
# 设置输出缩放因子
output_scale_factor=output_scale_factor,
# 设置是否在预归一化
pre_norm=resnet_pre_norm,
)
)
# 添加一个时间卷积层到列表中
temp_convs.append(
TemporalConvLayer(
# 设置输出通道数
out_channels,
# 设置输入通道数
out_channels,
# 设置 dropout 概率
dropout=0.1,
# 设置组数
norm_num_groups=resnet_groups,
)
)
# 添加一个 2D 转换器模型到列表中
attentions.append(
Transformer2DModel(
# 设置每个注意力头的通道数
out_channels // num_attention_heads,
# 设置注意力头的数量
num_attention_heads,
# 设置输入通道数
in_channels=out_channels,
# 设置层数
num_layers=1,
# 设置交叉注意力维度
cross_attention_dim=cross_attention_dim,
# 设置组数
norm_num_groups=resnet_groups,
# 是否使用线性投影
use_linear_projection=use_linear_projection,
# 是否仅使用交叉注意力
only_cross_attention=only_cross_attention,
# 是否提升注意力精度
upcast_attention=upcast_attention,
)
)
# 添加一个时间转换器模型到列表中
temp_attentions.append(
TransformerTemporalModel(
# 设置每个注意力头的通道数
out_channels // num_attention_heads,
# 设置注意力头的数量
num_attention_heads,
# 设置输入通道数
in_channels=out_channels,
# 设置层数
num_layers=1,
# 设置交叉注意力维度
cross_attention_dim=cross_attention_dim,
# 设置组数
norm_num_groups=resnet_groups,
)
)
# 将 ResNet 块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 将时间卷积层列表转换为 nn.ModuleList
self.temp_convs = nn.ModuleList(temp_convs)
# 将注意力模型列表转换为 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 将时间注意力模型列表转换为 nn.ModuleList
self.temp_attentions = nn.ModuleList(temp_attentions)
# 如果需要上采样,则初始化上采样层
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否则将上采样层设置为 None
self.upsamplers = None
# 设置梯度检查点标志
self.gradient_checkpointing = False
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播函数,接受多个参数并返回张量
def forward(
self,
hidden_states: torch.Tensor, # 当前层的隐藏状态
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 以前层的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
upsample_size: Optional[int] = None, # 可选的上采样大小
attention_mask: Optional[torch.Tensor] = None, # 可选的注意力掩码
num_frames: int = 1, # 帧数,默认值为1
cross_attention_kwargs: Dict[str, Any] = None, # 可选的交叉注意力参数
) -> torch.Tensor: # 返回一个张量
# 检查 FreeU 是否启用,基于多个属性的存在性
is_freeu_enabled = (
getattr(self, "s1", None) # 获取属性 s1
and getattr(self, "s2", None) # 获取属性 s2
and getattr(self, "b1", None) # 获取属性 b1
and getattr(self, "b2", None) # 获取属性 b2
)
# TODO(Patrick, William) - 注意力掩码尚未使用
for resnet, temp_conv, attn, temp_attn in zip( # 遍历网络模块
self.resnets, self.temp_convs, self.attentions, self.temp_attentions # 从各模块提取
):
# 从元组中弹出最后一个残差隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新元组,去掉最后一个隐藏状态
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# FreeU:仅对前两个阶段操作
if is_freeu_enabled:
hidden_states, res_hidden_states = apply_freeu( # 应用 FreeU 操作
self.resolution_idx, # 当前分辨率索引
hidden_states, # 当前隐藏状态
res_hidden_states, # 残差隐藏状态
s1=self.s1, # 属性 s1
s2=self.s2, # 属性 s2
b1=self.b1, # 属性 b1
b2=self.b2, # 属性 b2
)
# 将当前隐藏状态与残差隐藏状态在维度1上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 通过 ResNet 模块处理隐藏状态
hidden_states = resnet(hidden_states, temb)
# 通过临时卷积模块处理隐藏状态
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 通过注意力模块处理隐藏状态,并提取第一个返回值
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states, # 传递编码器隐藏状态
cross_attention_kwargs=cross_attention_kwargs, # 传递交叉注意力参数
return_dict=False, # 不返回字典形式的结果
)[0] # 提取第一个返回值
# 通过临时注意力模块处理隐藏状态,并提取第一个返回值
hidden_states = temp_attn(
hidden_states,
num_frames=num_frames, # 传递帧数
cross_attention_kwargs=cross_attention_kwargs, # 传递交叉注意力参数
return_dict=False, # 不返回字典形式的结果
)[0] # 提取第一个返回值
# 如果存在上采样模块
if self.upsamplers is not None:
for upsampler in self.upsamplers: # 遍历上采样模块
hidden_states = upsampler(hidden_states, upsample_size) # 应用上采样模块
# 返回最终的隐藏状态
return hidden_states
# 定义一个名为 UpBlock3D 的类,继承自 nn.Module
class UpBlock3D(nn.Module):
# 初始化函数,接受多个参数以配置网络层
def __init__(
self,
in_channels: int, # 输入通道数
prev_output_channel: int, # 前一层的输出通道数
out_channels: int, # 当前层的输出通道数
temb_channels: int, # 时间嵌入通道数
dropout: float = 0.0, # dropout 概率
num_layers: int = 1, # 层数
resnet_eps: float = 1e-6, # ResNet 的 epsilon 参数,用于数值稳定
resnet_time_scale_shift: str = "default", # ResNet 时间缩放偏移设置
resnet_act_fn: str = "swish", # ResNet 激活函数
resnet_groups: int = 32, # ResNet 的组数
resnet_pre_norm: bool = True, # 是否使用前置归一化
output_scale_factor: float = 1.0, # 输出缩放因子
add_upsample: bool = True, # 是否添加上采样层
resolution_idx: Optional[int] = None, # 分辨率索引,默认为 None
):
# 调用父类构造函数
super().__init__()
# 创建一个空列表,用于存储 ResNet 层
resnets = []
# 创建一个空列表,用于存储时间卷积层
temp_convs = []
# 根据层数创建 ResNet 和时间卷积层
for i in range(num_layers):
# 确定跳跃连接的通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 确定当前 ResNet 层的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 创建 ResNet 层,并添加到 resnets 列表中
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # epsilon 参数
groups=resnet_groups, # 组数
dropout=dropout, # dropout 概率
time_embedding_norm=resnet_time_scale_shift, # 时间嵌入归一化
non_linearity=resnet_act_fn, # 激活函数
output_scale_factor=output_scale_factor, # 输出缩放因子
pre_norm=resnet_pre_norm, # 前置归一化
)
)
# 创建时间卷积层,并添加到 temp_convs 列表中
temp_convs.append(
TemporalConvLayer(
out_channels, # 输入通道数
out_channels, # 输出通道数
dropout=0.1, # dropout 概率
norm_num_groups=resnet_groups, # 归一化组数
)
)
# 将 ResNet 层的列表转为 nn.ModuleList,以便于管理
self.resnets = nn.ModuleList(resnets)
# 将时间卷积层的列表转为 nn.ModuleList,以便于管理
self.temp_convs = nn.ModuleList(temp_convs)
# 如果需要添加上采样层,则创建并添加
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 如果不需要上采样层,则设置为 None
self.upsamplers = None
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播函数
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 跳跃连接的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
upsample_size: Optional[int] = None, # 可选的上采样尺寸
num_frames: int = 1, # 帧数
) -> torch.Tensor:
# 判断是否启用 FreeU,检查相关属性是否存在且不为 None
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
and getattr(self, "b1", None)
and getattr(self, "b2", None)
)
# 遍历自定义的 resnets 和 temp_convs,进行逐对处理
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
# 从 res_hidden_states_tuple 中弹出最后一个隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新 res_hidden_states_tuple,去掉最后一个隐藏状态
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 如果启用了 FreeU,则仅对前两个阶段进行操作
if is_freeu_enabled:
# 调用 apply_freeu 函数,处理隐藏状态
hidden_states, res_hidden_states = apply_freeu(
self.resolution_idx,
hidden_states,
res_hidden_states,
s1=self.s1,
s2=self.s2,
b1=self.b1,
b2=self.b2,
)
# 将当前的隐藏状态与残差隐藏状态在维度 1 上连接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 通过当前的 resnet 处理隐藏状态和 temb
hidden_states = resnet(hidden_states, temb)
# 通过当前的 temp_conv 处理隐藏状态,传入 num_frames 参数
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
# 如果存在上采样器,则对每个上采样器进行处理
if self.upsamplers is not None:
for upsampler in self.upsamplers:
# 通过当前的 upsampler 处理隐藏状态,传入 upsample_size 参数
hidden_states = upsampler(hidden_states, upsample_size)
# 返回最终的隐藏状态
return hidden_states
# 定义一个中间块时间解码器类,继承自 nn.Module
class MidBlockTemporalDecoder(nn.Module):
# 初始化函数,定义输入输出通道、注意力头维度、层数等参数
def __init__(
self,
in_channels: int,
out_channels: int,
attention_head_dim: int = 512,
num_layers: int = 1,
upcast_attention: bool = False,
):
# 调用父类的初始化函数
super().__init__()
# 初始化 ResNet 和 Attention 列表
resnets = []
attentions = []
# 根据层数创建相应数量的 ResNet
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
# 将 SpatioTemporalResBlock 实例添加到 ResNet 列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=1e-6,
temporal_eps=1e-5,
merge_factor=0.0,
merge_strategy="learned",
switch_spatial_to_temporal_mix=True,
)
)
# 添加 Attention 实例到 Attention 列表中
attentions.append(
Attention(
query_dim=in_channels,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
eps=1e-6,
upcast_attention=upcast_attention,
norm_num_groups=32,
bias=True,
residual_connection=True,
)
)
# 将 Attention 和 ResNet 列表转换为 ModuleList
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
# 前向传播函数,定义输入的隐藏状态和图像指示器的处理
def forward(
self,
hidden_states: torch.Tensor,
image_only_indicator: torch.Tensor,
):
# 处理第一个 ResNet 的输出
hidden_states = self.resnets[0](
hidden_states,
image_only_indicator=image_only_indicator,
)
# 遍历剩余的 ResNet 和 Attention,交替处理
for resnet, attn in zip(self.resnets[1:], self.attentions):
hidden_states = attn(hidden_states) # 应用注意力机制
# 处理 ResNet 的输出
hidden_states = resnet(
hidden_states,
image_only_indicator=image_only_indicator,
)
# 返回最终的隐藏状态
return hidden_states
# 定义一个上采样块时间解码器类,继承自 nn.Module
class UpBlockTemporalDecoder(nn.Module):
# 初始化函数,定义输入输出通道、层数等参数
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
add_upsample: bool = True,
):
# 调用父类的初始化函数
super().__init__()
# 初始化 ResNet 列表
resnets = []
# 根据层数创建相应数量的 ResNet
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
# 将 SpatioTemporalResBlock 实例添加到 ResNet 列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
eps=1e-6,
temporal_eps=1e-5,
merge_factor=0.0,
merge_strategy="learned",
switch_spatial_to_temporal_mix=True,
)
)
# 将 ResNet 列表转换为 ModuleList
self.resnets = nn.ModuleList(resnets)
# 如果需要,初始化上采样层
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
# 定义前向传播函数,接收隐藏状态和图像指示器作为输入,返回处理后的张量
def forward(
self,
hidden_states: torch.Tensor,
image_only_indicator: torch.Tensor,
) -> torch.Tensor:
# 遍历每个 ResNet 模块,更新隐藏状态
for resnet in self.resnets:
hidden_states = resnet(
hidden_states,
image_only_indicator=image_only_indicator,
)
# 如果存在上采样模块,则对隐藏状态进行上采样处理
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
# 返回最终的隐藏状态
return hidden_states
# 定义一个名为 UNetMidBlockSpatioTemporal 的类,继承自 nn.Module
class UNetMidBlockSpatioTemporal(nn.Module):
# 初始化方法,接收多个参数以配置该模块
def __init__(
self,
in_channels: int, # 输入通道数
temb_channels: int, # 时间嵌入通道数
num_layers: int = 1, # 层数,默认为 1
transformer_layers_per_block: Union[int, Tuple[int]] = 1, # 每个块的变换层数,默认为 1
num_attention_heads: int = 1, # 注意力头数,默认为 1
cross_attention_dim: int = 1280, # 交叉注意力维度,默认为 1280
):
# 调用父类的初始化方法
super().__init__()
# 设置是否使用交叉注意力标志
self.has_cross_attention = True
# 存储注意力头的数量
self.num_attention_heads = num_attention_heads
# 支持每个块的变换层数为可变的
if isinstance(transformer_layers_per_block, int):
# 如果是整数,则将其转换为包含 num_layers 个相同元素的列表
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 至少有一个 ResNet 块
resnets = [
# 创建第一个时空残差块
SpatioTemporalResBlock(
in_channels=in_channels, # 输入通道数
out_channels=in_channels, # 输出通道数与输入相同
temb_channels=temb_channels, # 时间嵌入通道数
eps=1e-5, # 小常数用于数值稳定性
)
]
# 初始化注意力模块列表
attentions = []
# 遍历层数以添加注意力和残差块
for i in range(num_layers):
# 添加时空变换模型到注意力列表
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads, # 注意力头数
in_channels // num_attention_heads, # 每个头的通道数
in_channels=in_channels, # 输入通道数
num_layers=transformer_layers_per_block[i], # 当前层的变换层数
cross_attention_dim=cross_attention_dim, # 交叉注意力维度
)
)
# 添加另一个时空残差块到残差列表
resnets.append(
SpatioTemporalResBlock(
in_channels=in_channels, # 输入通道数
out_channels=in_channels, # 输出通道数与输入相同
temb_channels=temb_channels, # 时间嵌入通道数
eps=1e-5, # 小常数用于数值稳定性
)
)
# 将注意力模块列表转换为 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 将残差模块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 初始化梯度检查点标志为 False
self.gradient_checkpointing = False
# 前向传播方法,接收多个输入参数
def forward(
self,
hidden_states: torch.Tensor, # 隐藏状态张量
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 可选的编码器隐藏状态
image_only_indicator: Optional[torch.Tensor] = None, # 可选的图像指示张量
# 返回类型为 torch.Tensor 的函数
) -> torch.Tensor:
# 使用第一个残差网络处理隐藏状态,传入时间嵌入和图像指示器
hidden_states = self.resnets[0](
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
# 遍历注意力层和后续残差网络的组合
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# 检查是否在训练中且开启了梯度检查点
if self.training and self.gradient_checkpointing: # TODO
# 创建自定义前向传播函数
def create_custom_forward(module, return_dict=None):
# 定义自定义前向传播内部函数
def custom_forward(*inputs):
# 根据返回字典参数决定是否使用返回字典
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 设置检查点参数,取决于 PyTorch 版本
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用注意力层处理隐藏状态,传入编码器的隐藏状态和图像指示器
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
# 使用检查点进行残差网络的前向传播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
else:
# 使用注意力层处理隐藏状态,传入编码器的隐藏状态和图像指示器
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
# 使用残差网络处理隐藏状态
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
# 返回处理后的隐藏状态
return hidden_states
# 定义一个下采样的时空块,继承自 nn.Module
class DownBlockSpatioTemporal(nn.Module):
# 初始化方法,设置输入输出通道和层数等参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
temb_channels: int, # 时间嵌入通道数
num_layers: int = 1, # 层数,默认为1
add_downsample: bool = True, # 是否添加下采样
):
super().__init__() # 调用父类的初始化方法
resnets = [] # 初始化一个空列表以存储残差块
# 根据层数创建相应数量的 SpatioTemporalResBlock
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels # 确定当前层的输入通道数
resnets.append(
SpatioTemporalResBlock( # 添加一个新的时空残差块
in_channels=in_channels, # 设置输入通道
out_channels=out_channels, # 设置输出通道
temb_channels=temb_channels, # 设置时间嵌入通道
eps=1e-5, # 设置 epsilon 值
)
)
self.resnets = nn.ModuleList(resnets) # 将残差块列表转化为 ModuleList
# 如果需要下采样,创建下采样模块
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D( # 添加一个下采样层
out_channels, # 设置输入通道
use_conv=True, # 是否使用卷积进行下采样
out_channels=out_channels, # 设置输出通道
name="op", # 下采样层名称
)
]
)
else:
self.downsamplers = None # 不添加下采样模块
self.gradient_checkpointing = False # 初始化梯度检查点为 False
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 隐藏状态输入
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入
image_only_indicator: Optional[torch.Tensor] = None, # 可选的图像指示器
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
output_states = () # 初始化输出状态元组
for resnet in self.resnets: # 遍历每个残差块
if self.training and self.gradient_checkpointing: # 如果在训练且启用了梯度检查点
# 定义一个创建自定义前向传播的方法
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs) # 返回模块的前向输出
return custom_forward
# 根据 PyTorch 版本进行检查点操作
if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定义前向传播
hidden_states, # 输入隐藏状态
temb, # 输入时间嵌入
image_only_indicator, # 输入图像指示器
use_reentrant=False, # 不使用重入
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), # 使用自定义前向传播
hidden_states, # 输入隐藏状态
temb, # 输入时间嵌入
image_only_indicator, # 输入图像指示器
)
else:
hidden_states = resnet( # 直接通过残差块进行前向传播
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
output_states = output_states + (hidden_states,) # 将当前隐藏状态添加到输出状态中
if self.downsamplers is not None: # 如果存在下采样模块
for downsampler in self.downsamplers: # 遍历下采样模块
hidden_states = downsampler(hidden_states) # 对隐藏状态进行下采样
output_states = output_states + (hidden_states,) # 将下采样后的状态添加到输出状态中
return hidden_states, output_states # 返回最终的隐藏状态和输出状态元组
# 定义一个交叉注意力下采样时空块,继承自 nn.Module
class CrossAttnDownBlockSpatioTemporal(nn.Module):
# 初始化方法,设置模型的基本参数
def __init__(
# 输入通道数
self,
in_channels: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# 层数,默认为1
num_layers: int = 1,
# 每个块的变换层数,可以是整数或元组,默认为1
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# 注意力头数,默认为1
num_attention_heads: int = 1,
# 交叉注意力的维度,默认为1280
cross_attention_dim: int = 1280,
# 是否添加下采样,默认为True
add_downsample: bool = True,
):
# 调用父类初始化方法
super().__init__()
# 初始化残差网络列表
resnets = []
# 初始化注意力层列表
attentions = []
# 设置是否使用交叉注意力为True
self.has_cross_attention = True
# 设置注意力头数
self.num_attention_heads = num_attention_heads
# 如果变换层数是整数,则扩展为列表
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍历层数以创建残差块和注意力层
for i in range(num_layers):
# 如果是第一层,使用输入通道数,否则使用输出通道数
in_channels = in_channels if i == 0 else out_channels
# 添加残差块到列表
resnets.append(
SpatioTemporalResBlock(
# 输入通道数
in_channels=in_channels,
# 输出通道数
out_channels=out_channels,
# 时间嵌入通道数
temb_channels=temb_channels,
# 防止除零的微小值
eps=1e-6,
)
)
# 添加注意力模型到列表
attentions.append(
TransformerSpatioTemporalModel(
# 注意力头数
num_attention_heads,
# 每个头的输出通道数
out_channels // num_attention_heads,
# 输入通道数
in_channels=out_channels,
# 该层的变换层数
num_layers=transformer_layers_per_block[i],
# 交叉注意力的维度
cross_attention_dim=cross_attention_dim,
)
)
# 将注意力层转换为nn.ModuleList以支持PyTorch模型
self.attentions = nn.ModuleList(attentions)
# 将残差层转换为nn.ModuleList以支持PyTorch模型
self.resnets = nn.ModuleList(resnets)
# 如果需要添加下采样层
if add_downsample:
# 添加下采样层到nn.ModuleList
self.downsamplers = nn.ModuleList(
[
Downsample2D(
# 输出通道数
out_channels,
# 是否使用卷积
use_conv=True,
# 输出通道数
out_channels=out_channels,
# 填充大小
padding=1,
# 操作名称
name="op",
)
]
)
else:
# 如果不需要下采样层,则设置为None
self.downsamplers = None
# 初始化梯度检查点为False
self.gradient_checkpointing = False
# 前向传播方法
def forward(
# 隐藏状态的张量
hidden_states: torch.Tensor,
# 时间嵌入的可选张量
temb: Optional[torch.Tensor] = None,
# 编码器隐藏状态的可选张量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 仅图像指示的可选张量
image_only_indicator: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: # 定义返回类型为一个元组,包含一个张量和多个张量的元组
output_states = () # 初始化一个空元组,用于存储输出状态
blocks = list(zip(self.resnets, self.attentions)) # 将自定义的残差网络和注意力模块打包成一个列表
for resnet, attn in blocks: # 遍历每个残差网络和对应的注意力模块
if self.training and self.gradient_checkpointing: # 如果处于训练模式并且启用了梯度检查点
def create_custom_forward(module, return_dict=None): # 定义一个创建自定义前向传播函数的辅助函数
def custom_forward(*inputs): # 定义自定义前向传播函数,接受任意数量的输入
if return_dict is not None: # 如果提供了返回字典参数
return module(*inputs, return_dict=return_dict) # 使用返回字典参数调用模块
else: # 如果没有提供返回字典
return module(*inputs) # 直接调用模块并返回结果
return custom_forward # 返回自定义前向传播函数
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} # 根据 PyTorch 版本设置检查点参数
hidden_states = torch.utils.checkpoint.checkpoint( # 使用检查点功能计算隐藏状态以节省内存
create_custom_forward(resnet), # 将残差网络传入自定义前向函数
hidden_states, # 将当前隐藏状态作为输入
temb, # 传递时间嵌入
image_only_indicator, # 传递图像指示器
**ckpt_kwargs, # 解包检查点参数
)
hidden_states = attn( # 使用注意力模块处理隐藏状态
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 传递编码器隐藏状态
image_only_indicator=image_only_indicator, # 传递图像指示器
return_dict=False, # 不返回字典形式的结果
)[0] # 获取输出的第一个元素
else: # 如果不处于训练模式或未启用梯度检查点
hidden_states = resnet( # 直接使用残差网络处理隐藏状态
hidden_states, # 输入当前隐藏状态
temb, # 传递时间嵌入
image_only_indicator=image_only_indicator, # 传递图像指示器
)
hidden_states = attn( # 使用注意力模块处理隐藏状态
hidden_states, # 输入隐藏状态
encoder_hidden_states=encoder_hidden_states, # 传递编码器隐藏状态
image_only_indicator=image_only_indicator, # 传递图像指示器
return_dict=False, # 不返回字典形式的结果
)[0] # 获取输出的第一个元素
output_states = output_states + (hidden_states,) # 将当前的隐藏状态添加到输出状态元组中
if self.downsamplers is not None: # 如果存在下采样模块
for downsampler in self.downsamplers: # 遍历每个下采样模块
hidden_states = downsampler(hidden_states) # 使用下采样模块处理隐藏状态
output_states = output_states + (hidden_states,) # 将处理后的隐藏状态添加到输出状态元组中
return hidden_states, output_states # 返回当前的隐藏状态和所有输出状态
# 定义一个上采样的时空块类,继承自 nn.Module
class UpBlockSpatioTemporal(nn.Module):
# 初始化方法,定义类的参数
def __init__(
self,
in_channels: int, # 输入通道数
prev_output_channel: int, # 前一层输出通道数
out_channels: int, # 当前层输出通道数
temb_channels: int, # 时间嵌入通道数
resolution_idx: Optional[int] = None, # 可选的分辨率索引
num_layers: int = 1, # 层数,默认为1
resnet_eps: float = 1e-6, # ResNet 的 epsilon 值
add_upsample: bool = True, # 是否添加上采样层
):
# 调用父类的初始化方法
super().__init__()
# 初始化一个空的列表,用于存储 ResNet 模块
resnets = []
# 根据层数创建对应的 ResNet 块
for i in range(num_layers):
# 如果是最后一层,使用输入通道数,否则使用输出通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 第一层的输入通道为前一层输出通道,后续层为输出通道
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 创建时空 ResNet 块并添加到列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=resnet_in_channels + res_skip_channels, # 输入通道数
out_channels=out_channels, # 输出通道数
temb_channels=temb_channels, # 时间嵌入通道数
eps=resnet_eps, # ResNet 的 epsilon 值
)
)
# 将 ResNet 模块列表转换为 nn.ModuleList,以便在模型中管理
self.resnets = nn.ModuleList(resnets)
# 如果需要添加上采样层,则创建对应的 nn.ModuleList
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 如果不添加上采样层,设置为 None
self.upsamplers = None
# 初始化梯度检查点标志
self.gradient_checkpointing = False
# 设置分辨率索引
self.resolution_idx = resolution_idx
# 定义前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态张量
res_hidden_states_tuple: Tuple[torch.Tensor, ...], # 之前层的隐藏状态元组
temb: Optional[torch.Tensor] = None, # 可选的时间嵌入张量
image_only_indicator: Optional[torch.Tensor] = None, # 可选的图像指示器张量
) -> torch.Tensor: # 定义函数返回类型为 PyTorch 的张量
for resnet in self.resnets: # 遍历当前对象中的所有 ResNet 模型
# pop res hidden states # 从隐藏状态元组中提取最后一个隐藏状态
res_hidden_states = res_hidden_states_tuple[-1] # 获取最后的 ResNet 隐藏状态
res_hidden_states_tuple = res_hidden_states_tuple[:-1] # 更新元组,移除最后一个隐藏状态
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) # 将当前隐藏状态与 ResNet 隐藏状态在维度 1 上拼接
if self.training and self.gradient_checkpointing: # 如果处于训练状态并启用了梯度检查点
def create_custom_forward(module): # 定义用于创建自定义前向传播函数的内部函数
def custom_forward(*inputs): # 自定义前向传播,接收任意数量的输入
return module(*inputs) # 调用原始模块的前向传播
return custom_forward # 返回自定义前向传播函数
if is_torch_version(">=", "1.11.0"): # 检查当前 PyTorch 版本是否大于等于 1.11.0
hidden_states = torch.utils.checkpoint.checkpoint( # 使用检查点机制保存内存
create_custom_forward(resnet), # 传入自定义前向传播函数
hidden_states, # 传入当前的隐藏状态
temb, # 传入时间嵌入
image_only_indicator, # 传入图像指示器
use_reentrant=False, # 禁用重入
)
else: # 如果 PyTorch 版本小于 1.11.0
hidden_states = torch.utils.checkpoint.checkpoint( # 使用检查点机制保存内存
create_custom_forward(resnet), # 传入自定义前向传播函数
hidden_states, # 传入当前的隐藏状态
temb, # 传入时间嵌入
image_only_indicator, # 传入图像指示器
)
else: # 如果不是训练状态或没有启用梯度检查点
hidden_states = resnet( # 直接调用 ResNet 模型处理隐藏状态
hidden_states, # 传入当前的隐藏状态
temb, # 传入时间嵌入
image_only_indicator=image_only_indicator, # 传入图像指示器
)
if self.upsamplers is not None: # 如果存在上采样模块
for upsampler in self.upsamplers: # 遍历所有上采样模块
hidden_states = upsampler(hidden_states) # 调用上采样模块处理隐藏状态
return hidden_states # 返回处理后的隐藏状态
# 定义一个时空交叉注意力上采样块类,继承自 nn.Module
class CrossAttnUpBlockSpatioTemporal(nn.Module):
# 初始化方法,设置网络的各个参数
def __init__(
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 前一层输出通道数
prev_output_channel: int,
# 时间嵌入通道数
temb_channels: int,
# 分辨率索引,可选
resolution_idx: Optional[int] = None,
# 层数
num_layers: int = 1,
# 每个块的变换器层数,支持单个整数或元组
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
# ResNet 的 epsilon 值,防止除零错误
resnet_eps: float = 1e-6,
# 注意力头的数量
num_attention_heads: int = 1,
# 交叉注意力维度
cross_attention_dim: int = 1280,
# 是否添加上采样层
add_upsample: bool = True,
):
# 调用父类的初始化方法
super().__init__()
# 存储 ResNet 层的列表
resnets = []
# 存储注意力层的列表
attentions = []
# 指示是否使用交叉注意力
self.has_cross_attention = True
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 如果是整数,将其转换为列表,包含 num_layers 个相同的元素
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
# 遍历每一层,构建 ResNet 和注意力层
for i in range(num_layers):
# 根据当前层是否是最后一层来决定跳过连接的通道数
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
# 根据当前层决定 ResNet 的输入通道数
resnet_in_channels = prev_output_channel if i == 0 else out_channels
# 添加时空 ResNet 块到列表中
resnets.append(
SpatioTemporalResBlock(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
)
)
# 添加时空变换器模型到列表中
attentions.append(
TransformerSpatioTemporalModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
)
)
# 将注意力层列表转换为 nn.ModuleList,以便于管理
self.attentions = nn.ModuleList(attentions)
# 将 ResNet 层列表转换为 nn.ModuleList,以便于管理
self.resnets = nn.ModuleList(resnets)
# 如果需要,添加上采样层
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
else:
# 否则上采样层设置为 None
self.upsamplers = None
# 设置梯度检查点标志为 False
self.gradient_checkpointing = False
# 存储分辨率索引
self.resolution_idx = resolution_idx
# 前向传播方法,定义输入和输出
def forward(
# 隐藏状态张量
hidden_states: torch.Tensor,
# 上一层隐藏状态的元组
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
# 可选的时间嵌入张量
temb: Optional[torch.Tensor] = None,
# 可选的编码器隐藏状态张量
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选的图像指示器张量
image_only_indicator: Optional[torch.Tensor] = None,
# 返回一个 torch.Tensor 类型的结果
) -> torch.Tensor:
# 遍历每个 resnet 和 attention 模块的组合
for resnet, attn in zip(self.resnets, self.attentions):
# 从隐藏状态元组中弹出最后一个 res 隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 更新隐藏状态元组,去掉最后一个元素
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# 在指定维度连接当前的 hidden_states 和 res_hidden_states
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 如果处于训练模式且开启了梯度检查点
if self.training and self.gradient_checkpointing: # TODO
# 定义一个用于创建自定义前向传播函数的函数
def create_custom_forward(module, return_dict=None):
# 定义自定义前向传播逻辑
def custom_forward(*inputs):
# 根据是否返回字典,调用模块的前向传播
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
# 根据 PyTorch 版本选择 checkpoint 的参数
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
# 使用检查点保存内存,调用自定义前向传播
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
image_only_indicator,
**ckpt_kwargs,
)
# 通过 attention 模块处理 hidden_states
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
else:
# 如果不使用检查点,直接通过 resnet 模块处理 hidden_states
hidden_states = resnet(
hidden_states,
temb,
image_only_indicator=image_only_indicator,
)
# 通过 attention 模块处理 hidden_states
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
# 如果存在上采样模块,逐个应用于 hidden_states
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
# 返回处理后的 hidden_states
return hidden_states