diffusers 源码解析(十二)
# 版权声明,指明版权归 2024 年 HuggingFace 团队所有
#
# 根据 Apache 许可证第 2.0 版("许可证")进行许可;
# 除非符合许可证,否则您不得使用此文件。
# 您可以在以下网址获得许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有约定,根据许可证分发的软件是以 "按现状" 的基础进行分发,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证以获取有关权限和
# 限制的特定语言。
from typing import Any, Dict, Optional # 导入用于类型注释的 Any、Dict 和 Optional 模块
import torch # 导入 PyTorch 库
import torch.nn.functional as F # 导入 PyTorch 的函数式神经网络模块,通常用于激活函数等
from torch import nn # 从 PyTorch 中导入 nn 模块,用于构建神经网络
from ...configuration_utils import LegacyConfigMixin, register_to_config # 从配置工具导入遗留配置混合类和注册配置函数
from ...utils import deprecate, is_torch_version, logging # 从工具模块导入弃用函数、PyTorch 版本检查函数和日志功能
from ..attention import BasicTransformerBlock # 从注意力模块导入基础变换器块
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection # 从嵌入模块导入图像位置嵌入、补丁嵌入和 PixArt Alpha 文本投影
from ..modeling_outputs import Transformer2DModelOutput # 从建模输出模块导入 2D 变换器模型输出类
from ..modeling_utils import LegacyModelMixin # 从建模工具模块导入遗留模型混合类
from ..normalization import AdaLayerNormSingle # 从归一化模块导入 AdaLayerNormSingle 类
logger = logging.get_logger(__name__) # 创建一个与当前模块名称相关的日志记录器,禁用 pylint 的无效名称警告
class Transformer2DModelOutput(Transformer2DModelOutput): # 定义 Transformer2DModelOutput 类,继承自 Transformer2DModelOutput
def __init__(self, *args, **kwargs): # 构造函数,接受任意参数
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead." # 设置弃用信息
deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message) # 调用弃用函数,记录弃用信息
super().__init__(*args, **kwargs) # 调用父类的构造函数,传递参数
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): # 定义 Transformer2DModel 类,继承遗留模型混合类和遗留配置混合类
"""
A 2D Transformer model for image-like data. # 类文档字符串,说明这是一个用于图像类数据的 2D 变换器模型。
# 定义参数部分,用于描述模型配置
Parameters:
# 多头注意力中使用的头数,默认为16
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
# 每个头中的通道数,默认为88
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
# 输入和输出中的通道数(如果输入是**连续**,则需指定)
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
# 使用的Transformer块的层数,默认为1
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
# 使用的丢弃概率,默认为0.0
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
# 使用的`encoder_hidden_states`维度数
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
# 潜在图像的宽度(如果输入是**离散**,则需指定)
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
# 在训练期间固定使用,以学习多个位置嵌入
This is fixed during training since it is used to learn a number of position embeddings.
# 潜在像素的向量嵌入类数(如果输入是**离散**,则需指定)
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
# 包含用于掩蔽潜在像素的类
Includes the class for the masked latent pixel.
# 前馈中的激活函数,默认为"geglu"
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
# 训练期间使用的扩散步骤数。如果至少有一个norm_layers是
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
# `AdaLayerNorm`。在训练期间固定使用,以学习多个嵌入
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
# 添加到隐藏状态。
added to the hidden states.
# 在推理期间,可以去噪的步骤数最多不超过`num_embeds_ada_norm`
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
# 配置是否`TransformerBlocks`的注意力应包含偏差参数
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
# 支持梯度检查点功能的标志
_supports_gradient_checkpointing = True
# 不进行拆分的模块列表
_no_split_modules = ["BasicTransformerBlock"]
# 注册到配置的装饰器
@register_to_config
# 初始化方法,用于设置模型的参数
def __init__(
# 设置注意力头的数量,默认为16
self,
num_attention_heads: int = 16,
# 每个注意力头的维度,默认为88
attention_head_dim: int = 88,
# 输入通道数,默认为None,表示未指定
in_channels: Optional[int] = None,
# 输出通道数,默认为None,表示未指定
out_channels: Optional[int] = None,
# 模型层数,默认为1
num_layers: int = 1,
# dropout比率,默认为0.0
dropout: float = 0.0,
# 归一化时的组数,默认为32
norm_num_groups: int = 32,
# 交叉注意力维度,默认为None,表示未指定
cross_attention_dim: Optional[int] = None,
# 是否使用注意力偏差,默认为False
attention_bias: bool = False,
# 采样大小,默认为None,表示未指定
sample_size: Optional[int] = None,
# 向量嵌入的数量,默认为None,表示未指定
num_vector_embeds: Optional[int] = None,
# patch大小,默认为None,表示未指定
patch_size: Optional[int] = None,
# 激活函数类型,默认为"geglu"
activation_fn: str = "geglu",
# 自适应归一化嵌入的数量,默认为None,表示未指定
num_embeds_ada_norm: Optional[int] = None,
# 是否使用线性投影,默认为False
use_linear_projection: bool = False,
# 是否仅使用交叉注意力,默认为False
only_cross_attention: bool = False,
# 是否使用双重自注意力,默认为False
double_self_attention: bool = False,
# 是否提高注意力精度,默认为False
upcast_attention: bool = False,
# 归一化类型,默认为"layer_norm"
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
# 归一化时是否使用元素级仿射,默认为True
norm_elementwise_affine: bool = True,
# 归一化的epsilon值,默认为1e-5
norm_eps: float = 1e-5,
# 注意力类型,默认为"default"
attention_type: str = "default",
# 说明通道数,默认为None,表示未指定
caption_channels: int = None,
# 插值缩放因子,默认为None,表示未指定
interpolation_scale: float = None,
# 是否使用额外条件,默认为None,表示未指定
use_additional_conditions: Optional[bool] = None,
# 初始化连续输入的方法,接受归一化类型作为参数
def _init_continuous_input(self, norm_type):
# 创建归一化层,使用组归一化,设置组数、通道数和epsilon
self.norm = torch.nn.GroupNorm(
num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
)
# 如果使用线性投影,则创建线性层进行输入投影
if self.use_linear_projection:
self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
# 否则,创建卷积层进行输入投影
else:
self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
# 创建变换器块的模块列表
self.transformer_blocks = nn.ModuleList(
[
# 对于每一层,初始化一个基本变换器块
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
cross_attention_dim=self.config.cross_attention_dim,
activation_fn=self.config.activation_fn,
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
attention_bias=self.config.attention_bias,
only_cross_attention=self.config.only_cross_attention,
double_self_attention=self.config.double_self_attention,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
)
# 重复上面的块,根据模型层数
for _ in range(self.config.num_layers)
]
)
# 如果使用线性投影,则创建线性层进行输出投影
if self.use_linear_projection:
self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
# 否则,创建卷积层进行输出投影
else:
self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
# 初始化向量化输入的方法,接收规范类型作为参数
def _init_vectorized_inputs(self, norm_type):
# 确保配置中的样本大小不为 None,否则抛出错误信息
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
# 确保配置中的向量嵌入数量不为 None,否则抛出错误信息
assert (
self.config.num_vector_embeds is not None
), "Transformer2DModel over discrete input must provide num_embed"
# 从配置中获取样本大小并赋值给高度
self.height = self.config.sample_size
# 从配置中获取样本大小并赋值给宽度
self.width = self.config.sample_size
# 计算潜在像素的总数量,等于高度乘以宽度
self.num_latent_pixels = self.height * self.width
# 创建图像位置嵌入对象,用于处理向量嵌入和图像维度
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
)
# 创建一个包含基本变换块的模块列表
self.transformer_blocks = nn.ModuleList(
[
# 为每一层创建一个基本变换块
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
cross_attention_dim=self.config.cross_attention_dim,
activation_fn=self.config.activation_fn,
num_embeds_ada_norm=self.config.num_embeds_ada_norm,
attention_bias=self.config.attention_bias,
only_cross_attention=self.config.only_cross_attention,
double_self_attention=self.config.double_self_attention,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
attention_type=self.config.attention_type,
)
# 通过配置中的层数决定变换块的数量
for _ in range(self.config.num_layers)
]
)
# 创建输出层归一化层
self.norm_out = nn.LayerNorm(self.inner_dim)
# 创建线性层,将内部维度映射到向量嵌入数量减一
self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
# 设置梯度检查点的方法,接收模块和布尔值作为参数
def _set_gradient_checkpointing(self, module, value=False):
# 检查模块是否具有梯度检查点属性
if hasattr(module, "gradient_checkpointing"):
# 设置模块的梯度检查点属性
module.gradient_checkpointing = value
# 前向传播方法,处理输入的隐藏状态及其他可选参数
def forward(
self,
hidden_states: torch.Tensor, # 隐藏状态张量
encoder_hidden_states: Optional[torch.Tensor] = None, # 编码器隐藏状态,默认为 None
timestep: Optional[torch.LongTensor] = None, # 时间步长,默认为 None
added_cond_kwargs: Dict[str, torch.Tensor] = None, # 额外条件的字典,默认为 None
class_labels: Optional[torch.LongTensor] = None, # 类标签,默认为 None
cross_attention_kwargs: Dict[str, Any] = None, # 交叉注意力参数字典,默认为 None
attention_mask: Optional[torch.Tensor] = None, # 注意力掩码,默认为 None
encoder_attention_mask: Optional[torch.Tensor] = None, # 编码器注意力掩码,默认为 None
return_dict: bool = True, # 是否返回字典格式的结果,默认为 True
# 对连续输入进行操作,处理隐藏状态
def _operate_on_continuous_inputs(self, hidden_states):
# 获取隐藏状态的批次大小、高度和宽度
batch, _, height, width = hidden_states.shape
# 对隐藏状态进行归一化处理
hidden_states = self.norm(hidden_states)
# 如果不使用线性投影
if not self.use_linear_projection:
# 通过输入投影层处理隐藏状态
hidden_states = self.proj_in(hidden_states)
# 获取内部维度
inner_dim = hidden_states.shape[1]
# 调整隐藏状态的维度
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
# 获取内部维度
inner_dim = hidden_states.shape[1]
# 调整隐藏状态的维度
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
# 通过输入投影层处理隐藏状态
hidden_states = self.proj_in(hidden_states)
# 返回处理后的隐藏状态和内部维度
return hidden_states, inner_dim
# 对修补输入进行操作,处理隐藏状态和编码器隐藏状态
def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
# 获取批次大小
batch_size = hidden_states.shape[0]
# 对隐藏状态进行位置嵌入
hidden_states = self.pos_embed(hidden_states)
# 初始化嵌入时间步
embedded_timestep = None
# 如果自适应归一化单元存在
if self.adaln_single is not None:
# 如果使用额外条件且未提供额外参数,抛出错误
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
# 处理时间步和嵌入时间步
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 如果存在标题投影
if self.caption_projection is not None:
# 对编码器隐藏状态进行投影
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
# 调整编码器隐藏状态的维度
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
# 返回处理后的隐藏状态、编码器隐藏状态、时间步和嵌入时间步
return hidden_states, encoder_hidden_states, timestep, embedded_timestep
# 获取连续输入的输出,处理隐藏状态和残差
def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
# 如果不使用线性投影
if not self.use_linear_projection:
# 调整隐藏状态的维度
hidden_states = (
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
)
# 通过输出投影层处理隐藏状态
hidden_states = self.proj_out(hidden_states)
else:
# 通过输出投影层处理隐藏状态
hidden_states = self.proj_out(hidden_states)
# 调整隐藏状态的维度
hidden_states = (
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
)
# 将处理后的隐藏状态与残差相加
output = hidden_states + residual
# 返回最终输出
return output
# 获取向量化输入的输出,处理隐藏状态
def _get_output_for_vectorized_inputs(self, hidden_states):
# 对隐藏状态进行归一化处理
hidden_states = self.norm_out(hidden_states)
# 通过输出层处理隐藏状态,得到 logits
logits = self.out(hidden_states)
# 调整 logits 的维度
logits = logits.permute(0, 2, 1)
# 对 logits 应用 log_softmax,获取最终输出
output = F.log_softmax(logits.double(), dim=1).float()
# 返回最终输出
return output
# 获取修补输入的输出,处理隐藏状态和时间步
def _get_output_for_patched_inputs(
self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
):
# 检查配置中的归一化类型是否不是 "ada_norm_single"
if self.config.norm_type != "ada_norm_single":
# 使用第一个变换块的归一化层对时间步和类别标签进行嵌入处理
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
# 将条件信息通过线性变换获得偏移和缩放因子
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
# 对隐藏状态进行归一化和调整,应用偏移和缩放
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
# 将调整后的隐藏状态通过第二个线性变换
hidden_states = self.proj_out_2(hidden_states)
# 检查配置中的归一化类型是否是 "ada_norm_single"
elif self.config.norm_type == "ada_norm_single":
# 从缩放偏移表中获得偏移和缩放因子
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
# 对隐藏状态进行归一化处理
hidden_states = self.norm_out(hidden_states)
# 调整隐藏状态,应用偏移和缩放
hidden_states = hidden_states * (1 + scale) + shift
# 将调整后的隐藏状态通过线性变换
hidden_states = self.proj_out(hidden_states)
# 压缩维度,去掉多余的维度
hidden_states = hidden_states.squeeze(1)
# 取消补丁化处理
if self.adaln_single is None:
# 计算高度和宽度,基于隐藏状态的形状
height = width = int(hidden_states.shape[1] ** 0.5)
# 重新调整隐藏状态的形状
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
# 使用爱因斯坦求和约定重排维度
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
# 再次调整输出的形状
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
# 返回最终的输出结果
return output
# 版权声明,标明版权归属及相关许可信息
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证第 2.0 版的规定,使用该文件的条款
# Licensed under the Apache License, Version 2.0 (the "License");
# 你不得使用该文件,除非遵守许可证
# You may not use this file except in compliance with the License.
# 许可证的副本可以在以下网址获得
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面同意,按许可证分发的软件在 "按现状" 的基础上提供,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 参见许可证了解有关权限和限制的具体条款
# See the License for the specific language governing permissions and
# limitations under the License.
# 导入类型注解以支持类型提示
from typing import Any, Dict, List, Optional, Union
# 导入 PyTorch 及其神经网络模块
import torch
import torch.nn as nn
import torch.nn.functional as F
# 从配置和加载模块导入所需的类和混合
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
# 从注意力模块导入所需的类
from ...models.attention import FeedForward
from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
# 从模型工具导入基础模型类
from ...models.modeling_utils import ModelMixin
# 从归一化模块导入自适应层归一化类
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
# 从工具模块导入各种实用功能
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
# 从嵌入模块导入所需的类
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
# 从输出模块导入 Transformer 模型输出类
from ..modeling_outputs import Transformer2DModelOutput
# 获取日志记录器以供本模块使用
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi 待办事项: 重构与 rope 相关的函数/类
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
# 确保输入维度为偶数
assert dim % 2 == 0, "The dimension must be even."
# 计算缩放因子,范围从 0 到 dim,步长为 2
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
# 计算 omega 值
omega = 1.0 / (theta**scale)
# 获取批次大小和序列长度
batch_size, seq_length = pos.shape
# 使用爱因斯坦求和约定计算输出
out = torch.einsum("...n,d->...nd", pos, omega)
# 计算余弦和正弦值
cos_out = torch.cos(out)
sin_out = torch.sin(out)
# 堆叠余弦和正弦结果
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
# 重塑输出形状
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
# 返回输出的浮点数形式
return out.float()
# YiYi 待办事项: 重构与 rope 相关的函数/类
class EmbedND(nn.Module):
# 初始化方法,接收维度、theta 和轴维度
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__() # 调用父类初始化
self.dim = dim # 存储维度
self.theta = theta # 存储 theta 参数
self.axes_dim = axes_dim # 存储轴维度
# 前向传播方法,接受输入 ID
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1] # 获取轴的数量
# 计算嵌入并沿着维度 -3 连接结果
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
# 返回添加维度的嵌入
return emb.unsqueeze(1)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
# 参数说明文档
Parameters:
dim (`int`): 输入和输出的通道数量
num_attention_heads (`int`): 多头注意力机制中使用的头数量
attention_head_dim (`int`): 每个头的通道数量
context_pre_only (`bool`): 布尔值,确定是否添加与处理 `context` 条件相关的一些模块
"""
# 初始化函数,设置模型参数
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
# 调用父类构造函数
super().__init__()
# 计算 MLP 隐藏层的维度
self.mlp_hidden_dim = int(dim * mlp_ratio)
# 创建自适应层归一化实例
self.norm = AdaLayerNormZeroSingle(dim)
# 创建线性变换层,将输入维度映射到 MLP 隐藏维度
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
# 使用 GELU 激活函数
self.act_mlp = nn.GELU(approximate="tanh")
# 创建线性变换层,将 MLP 输出和输入维度合并并映射回输入维度
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
# 创建注意力处理器实例
processor = FluxSingleAttnProcessor2_0()
# 初始化注意力机制,配置相关参数
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
# 前向传播函数,定义输入的处理方式
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
):
# 保存输入的残差用于后续相加
residual = hidden_states
# 进行层归一化,并得到归一化后的隐藏状态和门控值
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
# 对归一化后的隐藏状态进行线性变换和激活
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
# 计算注意力输出
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
# 将注意力输出和 MLP 隐藏状态在最后一维拼接
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
# 扩展门控值维度以便后续运算
gate = gate.unsqueeze(1)
# 使用门控机制对输出进行加权,并将其与残差相加
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
# 如果数据类型为 float16,则对输出进行裁剪,避免溢出
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
# 返回最终的隐藏状态
return hidden_states
# 装饰器,可能允许该类在计算图中使用
@maybe_allow_in_graph
# 定义 FluxTransformerBlock 类,继承自 nn.Module
class FluxTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
# 初始化方法,接受多个参数以配置 Transformer 块
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
# 调用父类构造函数
super().__init__()
# 初始化第一个自适应层归一化
self.norm1 = AdaLayerNormZero(dim)
# 初始化上下文的自适应层归一化
self.norm1_context = AdaLayerNormZero(dim)
# 检查 PyTorch 是否支持 scaled_dot_product_attention
if hasattr(F, "scaled_dot_product_attention"):
# 创建 Attention 处理器
processor = FluxAttnProcessor2_0()
else:
# 如果不支持,抛出异常
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
# 初始化注意力层
self.attn = Attention(
query_dim=dim, # 查询维度
cross_attention_dim=None, # 交叉注意力维度
added_kv_proj_dim=dim, # 额外键值投影维度
dim_head=attention_head_dim, # 每个头的维度
heads=num_attention_heads, # 注意力头的数量
out_dim=dim, # 输出维度
context_pre_only=False, # 上下文预处理标志
bias=True, # 是否使用偏置
processor=processor, # 注意力处理器
qk_norm=qk_norm, # 查询键的归一化方式
eps=eps, # 稳定性常数
)
# 初始化第二个层归一化
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
# 初始化前馈网络
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
# 初始化上下文的第二个层归一化
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
# 初始化上下文的前馈网络
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
# 让块大小默认为 None
self._chunk_size = None
# 设定块维度为 0
self._chunk_dim = 0
# 前向传播方法,定义输入及其处理
def forward(
self,
hidden_states: torch.FloatTensor, # 输入的隐藏状态
encoder_hidden_states: torch.FloatTensor, # 编码器的隐藏状态
temb: torch.FloatTensor, # 额外的嵌入信息
image_rotary_emb=None, # 可选的图像旋转嵌入
):
# 对隐藏状态进行归一化处理,并计算门控相关的多头自注意力值
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# 对编码器的隐藏状态进行归一化处理,并计算门控相关的多头自注意力值
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 注意力机制计算
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
# 处理注意力输出以更新 `hidden_states`
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
# 对更新后的隐藏状态进行第二次归一化处理
norm_hidden_states = self.norm2(hidden_states)
# 结合门控机制调整归一化后的隐藏状态
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
# 前馈网络处理
ff_output = self.ff(norm_hidden_states)
# 结合门控机制调整前馈网络输出
ff_output = gate_mlp.unsqueeze(1) * ff_output
# 更新 `hidden_states`
hidden_states = hidden_states + ff_output
# 处理编码器隐藏状态的注意力输出
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
# 对编码器隐藏状态进行第二次归一化处理
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
# 结合门控机制调整归一化后的编码器隐藏状态
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 对编码器的前馈网络处理
context_ff_output = self.ff_context(norm_encoder_hidden_states)
# 更新编码器隐藏状态
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
# 对半精度数据进行范围裁剪,避免溢出
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
# 返回更新后的编码器和隐藏状态
return encoder_hidden_states, hidden_states
# 定义 FluxTransformer2DModel 类,继承自多个混合类以获取其功能
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
Flux 中引入的 Transformer 模型。
参考文献: https://blackforestlabs.ai/announcing-black-forest-labs/
参数:
patch_size (`int`): 将输入数据转换为小块的块大小。
in_channels (`int`, *可选*, 默认为 16): 输入的通道数量。
num_layers (`int`, *可选*, 默认为 18): 使用的 MMDiT 块的层数。
num_single_layers (`int`, *可选*, 默认为 18): 使用的单 DiT 块的层数。
attention_head_dim (`int`, *可选*, 默认为 64): 每个头的通道数。
num_attention_heads (`int`, *可选*, 默认为 18): 用于多头注意力的头数。
joint_attention_dim (`int`, *可选*): 用于 `encoder_hidden_states` 维度的数量。
pooled_projection_dim (`int`): 投影 `pooled_projections` 时使用的维度数量。
guidance_embeds (`bool`, 默认为 False): 是否使用引导嵌入。
"""
# 支持梯度检查点,减少内存使用
_supports_gradient_checkpointing = True
# 注册到配置中,初始化模型参数
@register_to_config
def __init__(
# 定义块的大小,默认为 1
self,
patch_size: int = 1,
# 定义输入通道的数量,默认为 64
in_channels: int = 64,
# 定义 MMDiT 块的层数,默认为 19
num_layers: int = 19,
# 定义单 DiT 块的层数,默认为 38
num_single_layers: int = 38,
# 定义每个注意力头的通道数,默认为 128
attention_head_dim: int = 128,
# 定义多头注意力的头数,默认为 24
num_attention_heads: int = 24,
# 定义用于 `encoder_hidden_states` 的维度,默认为 4096
joint_attention_dim: int = 4096,
# 定义投影的维度,默认为 768
pooled_projection_dim: int = 768,
# 定义是否使用引导嵌入,默认为 False
guidance_embeds: bool = False,
# 定义 ROPE 的轴维度,默认值为 [16, 56, 56]
axes_dims_rope: List[int] = [16, 56, 56],
):
# 调用父类构造函数初始化
super().__init__()
# 设置输出通道数为输入通道数
self.out_channels = in_channels
# 计算内层维度为注意力头数量乘以每个头的维度
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
# 创建位置嵌入对象,用于维度和轴的设置
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
# 根据是否使用引导嵌入选择合并时间步引导文本投影嵌入类
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
# 创建时间文本嵌入对象,使用前面选择的类
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
)
# 创建线性层用于上下文嵌入
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
# 创建线性层用于输入嵌入
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
# 创建多个变换器块的模块列表
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
# 创建多个单一变换器块的模块列表
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
# 创建自适应层归一化层作为输出层
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
# 创建线性投影层,将内层维度映射到输出通道的形状
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
# 设置梯度检查点为 False
self.gradient_checkpointing = False
# 定义设置梯度检查点的函数
def _set_gradient_checkpointing(self, module, value=False):
# 检查模块是否具有梯度检查点属性
if hasattr(module, "gradient_checkpointing"):
# 设置模块的梯度检查点属性
module.gradient_checkpointing = value
# 定义前向传播方法
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
# 版权所有 2024 Stability AI, The HuggingFace Team 和 The InstantX Team。保留所有权利。
#
# 根据 Apache 许可证第 2.0 版(“许可证”)许可;
# 除非遵循许可证,否则不得使用此文件。
# 可以在以下网址获取许可证的副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律要求或书面协议另有规定,按照许可证分发的软件是以“原样”基础提供的,
# 不提供任何形式的保证或条件,无论是明示的还是暗示的。
# 有关许可证下权限和限制的具体语言,请参见许可证。
from typing import Any, Dict, List, Optional, Union # 从 typing 模块导入各种类型注释
import torch # 导入 PyTorch 库
import torch.nn as nn # 导入 PyTorch 的神经网络模块,并命名为 nn
from ...configuration_utils import ConfigMixin, register_to_config # 从配置工具导入配置混合类和注册函数
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin # 从加载器导入模型混合类
from ...models.attention import JointTransformerBlock # 从注意力模块导入联合变换器块
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 # 导入不同的注意力处理器
from ...models.modeling_utils import ModelMixin # 导入模型混合类
from ...models.normalization import AdaLayerNormContinuous # 导入自适应层归一化模块
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers # 导入工具函数和变量
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed # 从嵌入模块导入嵌入类
from ..modeling_outputs import Transformer2DModelOutput # 导入变换器 2D 模型输出类
logger = logging.get_logger(__name__) # 创建一个记录器实例,名称为当前模块名,禁用 pylint 对名称的警告
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): # 定义 SD3 变换器 2D 模型类,继承多个混合类
"""
Stable Diffusion 3 中引入的变换器模型。
参考文献: https://arxiv.org/abs/2403.03206
参数:
sample_size (`int`): 潜在图像的宽度。训练期间固定使用,因为
它用于学习一组位置嵌入。
patch_size (`int`): 将输入数据转化为小块的块大小。
in_channels (`int`, *可选*, 默认为 16): 输入的通道数量。
num_layers (`int`, *可选*, 默认为 18): 使用的变换器块层数。
attention_head_dim (`int`, *可选*, 默认为 64): 每个头的通道数量。
num_attention_heads (`int`, *可选*, 默认为 18): 多头注意力使用的头数。
cross_attention_dim (`int`, *可选*): 用于 `encoder_hidden_states` 维度的数量。
caption_projection_dim (`int`): 用于投影 `encoder_hidden_states` 的维度数量。
pooled_projection_dim (`int`): 用于投影 `pooled_projections` 的维度数量。
out_channels (`int`, 默认为 16): 输出通道的数量。
"""
_supports_gradient_checkpointing = True # 表示模型支持梯度检查点功能
@register_to_config # 使用装饰器将此方法注册到配置中
# 初始化方法,设置模型的基本参数
def __init__(
self,
sample_size: int = 128, # 输入样本的大小,默认值为128
patch_size: int = 2, # 每个补丁的大小,默认值为2
in_channels: int = 16, # 输入通道数,默认值为16
num_layers: int = 18, # Transformer层的数量,默认值为18
attention_head_dim: int = 64, # 每个注意力头的维度,默认值为64
num_attention_heads: int = 18, # 注意力头的数量,默认值为18
joint_attention_dim: int = 4096, # 联合注意力维度,默认值为4096
caption_projection_dim: int = 1152, # 标题投影维度,默认值为1152
pooled_projection_dim: int = 2048, # 池化投影维度,默认值为2048
out_channels: int = 16, # 输出通道数,默认值为16
pos_embed_max_size: int = 96, # 位置嵌入的最大大小,默认值为96
):
super().__init__() # 调用父类的初始化方法
default_out_channels = in_channels # 设置默认的输出通道为输入通道数
# 如果指定输出通道,则使用指定值,否则使用默认值
self.out_channels = out_channels if out_channels is not None else default_out_channels
# 计算内部维度,等于注意力头数量乘以每个注意力头的维度
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
# 创建位置嵌入模块,用于将输入图像转为嵌入表示
self.pos_embed = PatchEmbed(
height=self.config.sample_size, # 高度设置为样本大小
width=self.config.sample_size, # 宽度设置为样本大小
patch_size=self.config.patch_size, # 补丁大小
in_channels=self.config.in_channels, # 输入通道数
embed_dim=self.inner_dim, # 嵌入维度
pos_embed_max_size=pos_embed_max_size, # 当前硬编码位置嵌入最大大小
)
# 创建时间与文本嵌入的组合模块
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
embedding_dim=self.inner_dim, # 嵌入维度
pooled_projection_dim=self.config.pooled_projection_dim # 池化投影维度
)
# 创建线性层,用于将上下文信息映射到标题投影维度
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
# 创建Transformer块的列表
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
dim=self.inner_dim, # 输入维度为内部维度
num_attention_heads=self.config.num_attention_heads, # 注意力头的数量
attention_head_dim=self.config.attention_head_dim, # 每个注意力头的维度
context_pre_only=i == num_layers - 1, # 仅在最后一层设置上下文优先
)
for i in range(self.config.num_layers) # 遍历创建每一层
]
)
# 创建自适应层归一化层
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
# 创建线性层,用于输出映射
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
# 设置梯度检查点开关,默认值为False
self.gradient_checkpointing = False
# 从diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking复制的方法
# 定义一个启用前馈分块的函数,接受可选的分块大小和维度
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
设置注意力处理器使用前馈分块机制。
参数:
chunk_size (`int`, *可选*):
前馈层的分块大小。如果未指定,将单独在维度为`dim`的每个张量上运行前馈层。
dim (`int`, *可选*, 默认值为`0`):
前馈计算应分块的维度。选择dim=0(批量)或dim=1(序列长度)。
"""
# 如果维度不是0或1,则抛出值错误
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# 默认分块大小为1
chunk_size = chunk_size or 1
# 定义递归前馈函数,接受模块、分块大小和维度作为参数
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块有设置分块前馈的属性,则调用该方法
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历模块的子模块并递归调用前馈函数
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前对象的子模块,应用递归前馈函数
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# 从diffusers.models.unets.unet_3d_condition复制的方法,禁用前馈分块
def disable_forward_chunking(self):
# 定义递归前馈函数,接受模块、分块大小和维度作为参数
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
# 如果模块有设置分块前馈的属性,则调用该方法
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
# 遍历模块的子模块并递归调用前馈函数
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
# 遍历当前对象的子模块,应用递归前馈函数,分块大小为None,维度为0
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
@property
# 从diffusers.models.unets.unet_2d_condition复制的属性,获取注意力处理器
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
返回:
`dict`类型的注意力处理器:一个包含模型中所有注意力处理器的字典,以其权重名称索引。
"""
# 初始化处理器字典
processors = {}
# 定义递归添加处理器的函数
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# 如果模块有获取处理器的方法,则将其添加到处理器字典
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
# 遍历子模块并递归调用添加处理器函数
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
# 遍历当前对象的子模块,应用递归添加处理器函数
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
# 返回处理器字典
return processors
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 复制而来
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
设置用于计算注意力的处理器。
参数:
processor(`dict` 类型的 `AttentionProcessor` 或仅为 `AttentionProcessor`):
实例化的处理器类或一个处理器类的字典,将被设置为 **所有** `Attention` 层的处理器。
如果 `processor` 是一个字典,键需要定义相应的交叉注意力处理器的路径。
在设置可训练的注意力处理器时,强烈建议使用这种方式。
"""
# 获取当前注意力处理器的数量
count = len(self.attn_processors.keys())
# 如果传入的处理器是字典且数量与注意力层数量不匹配,抛出异常
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"传入了处理器字典,但处理器数量 {len(processor)} 与注意力层数量 {count} 不匹配。请确保传入 {count} 个处理器类。"
)
# 定义递归函数,用于设置每个模块的注意力处理器
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
# 检查模块是否具有设置处理器的方法
if hasattr(module, "set_processor"):
# 如果处理器不是字典,直接设置
if not isinstance(processor, dict):
module.set_processor(processor)
else:
# 从字典中弹出相应的处理器并设置
module.set_processor(processor.pop(f"{name}.processor"))
# 遍历模块的子模块,递归调用
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
# 遍历当前实例的子模块,并为每个模块设置处理器
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 复制而来
def fuse_qkv_projections(self):
"""
启用融合的 QKV 投影。对于自注意力模块,所有投影矩阵(即查询、键、值)被融合。
对于交叉注意力模块,键和值投影矩阵被融合。
<Tip warning={true}>
此 API 是 🧪 实验性。
</Tip>
"""
# 初始化原始注意力处理器为 None
self.original_attn_processors = None
# 检查每个注意力处理器是否包含 "Added"
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
# 如果发现不支持的处理器,抛出异常
raise ValueError("`fuse_qkv_projections()` 不支持具有添加 KV 投影的模型。")
# 将当前的注意力处理器保存为原始处理器
self.original_attn_processors = self.attn_processors
# 遍历所有模块,如果模块是 Attention 类型,则进行投影融合
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# 设置注意力处理器为 FusedJointAttnProcessor2_0 的实例
self.set_attn_processor(FusedJointAttnProcessor2_0())
# 从 diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 复制而来
# 定义一个方法,用于禁用已启用的融合 QKV 投影
def unfuse_qkv_projections(self):
"""禁用融合的 QKV 投影(如果已启用)。
<Tip warning={true}>
此 API 为 🧪 实验性。
</Tip>
"""
# 检查原始注意力处理器是否存在
if self.original_attn_processors is not None:
# 将当前的注意力处理器设置为原始的
self.set_attn_processor(self.original_attn_processors)
# 定义一个私有方法,用于设置梯度检查点
def _set_gradient_checkpointing(self, module, value=False):
# 检查模块是否具有梯度检查点属性
if hasattr(module, "gradient_checkpointing"):
# 将梯度检查点属性设置为指定值
module.gradient_checkpointing = value
# 定义前向传播方法,接受多个输入参数
def forward(
self,
hidden_states: torch.FloatTensor, # 输入的隐藏状态张量
encoder_hidden_states: torch.FloatTensor = None, # 编码器的隐藏状态张量,可选
pooled_projections: torch.FloatTensor = None, # 池化后的投影张量,可选
timestep: torch.LongTensor = None, # 时间步长张量,可选
block_controlnet_hidden_states: List = None, # 控制网的隐藏状态列表,可选
joint_attention_kwargs: Optional[Dict[str, Any]] = None, # 联合注意力的额外参数,可选
return_dict: bool = True, # 指示是否返回字典格式的结果,默认为 True
# 版权声明,2024年HuggingFace团队所有,保留所有权利。
#
# 根据Apache许可证第2.0版("许可证")授权;
# 除非符合许可证,否则不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,
# 否则根据许可证分发的软件是在"按原样"的基础上提供的,
# 不提供任何形式的明示或暗示的保证或条件。
# 有关许可证具体条款的信息,
# 请参阅许可证中的权限和限制。
# 导入dataclass装饰器,用于简化类的定义
from dataclasses import dataclass
# 导入Any、Dict和Optional类型,用于类型注解
from typing import Any, Dict, Optional
# 导入PyTorch库
import torch
# 从torch库中导入神经网络模块
from torch import nn
# 从配置工具中导入ConfigMixin类和注册配置函数
from ...configuration_utils import ConfigMixin, register_to_config
# 从工具模块中导入BaseOutput类
from ...utils import BaseOutput
# 从注意力模块中导入基本变换器块和时间基本变换器块
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
# 从嵌入模块中导入时间步嵌入和时间步类
from ..embeddings import TimestepEmbedding, Timesteps
# 从模型工具中导入ModelMixin类
from ..modeling_utils import ModelMixin
# 从ResNet模块中导入AlphaBlender类
from ..resnet import AlphaBlender
# 定义TransformerTemporalModelOutput类,继承自BaseOutput
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
"""
[`TransformerTemporalModel`]的输出。
参数:
sample (`torch.Tensor`形状为`(batch_size x num_frames, num_channels, height, width)`):
基于`encoder_hidden_states`输入条件的隐藏状态输出。
"""
# 定义sample属性,类型为torch.Tensor
sample: torch.Tensor
# 定义TransformerTemporalModel类,继承自ModelMixin和ConfigMixin
class TransformerTemporalModel(ModelMixin, ConfigMixin):
"""
适用于视频类数据的变换器模型。
# 参数说明
Parameters:
# 多头注意力的头数,默认为 16
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
# 每个头中的通道数,默认为 88
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
# 输入和输出的通道数(如果输入为 **连续**,则需要指定)
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
# Transformer 块的层数,默认为 1
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
# 使用的 dropout 概率,默认为 0.0
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
# 使用的 `encoder_hidden_states` 维度数
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
# 配置 `TransformerBlock` 的注意力是否应包含偏置参数
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
# 潜在图像的宽度(如果输入为 **离散**,则需要指定)
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
# 在训练期间固定使用,以学习位置嵌入的数量
This is fixed during training since it is used to learn a number of position embeddings.
# 前馈中使用的激活函数,默认为 "geglu"
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
# 配置 `TransformerBlock` 是否应使用可学习的逐元素仿射参数进行归一化
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
# 配置每个 `TransformerBlock` 是否应包含两个自注意力层
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
# 应用到序列输入之前的位置信息嵌入类型
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
# 应用位置嵌入的最大序列长度
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
# 注册到配置
@register_to_config
def __init__(
# 初始化函数的参数
self,
# 多头注意力的头数,默认为 16
num_attention_heads: int = 16,
# 每个头中的通道数,默认为 88
attention_head_dim: int = 88,
# 输入通道数(可选)
in_channels: Optional[int] = None,
# 输出通道数(可选)
out_channels: Optional[int] = None,
# Transformer 块的层数,默认为 1
num_layers: int = 1,
# dropout 概率,默认为 0.0
dropout: float = 0.0,
# 归一化组数,默认为 32
norm_num_groups: int = 32,
# 使用的 `encoder_hidden_states` 维度数(可选)
cross_attention_dim: Optional[int] = None,
# 注意力是否包含偏置参数,默认为 False
attention_bias: bool = False,
# 潜在图像的宽度(可选)
sample_size: Optional[int] = None,
# 激活函数,默认为 "geglu"
activation_fn: str = "geglu",
# 是否使用可学习的逐元素仿射参数,默认为 True
norm_elementwise_affine: bool = True,
# 是否包含两个自注意力层,默认为 True
double_self_attention: bool = True,
# 位置信息嵌入类型(可选)
positional_embeddings: Optional[str] = None,
# 最大序列长度(可选)
num_positional_embeddings: Optional[int] = None,
):
# 调用父类的初始化方法
super().__init__()
# 设置注意力头的数量
self.num_attention_heads = num_attention_heads
# 设置每个注意力头的维度
self.attention_head_dim = attention_head_dim
# 计算内部维度,即注意力头数量与每个注意力头维度的乘积
inner_dim = num_attention_heads * attention_head_dim
# 设置输入通道数
self.in_channels = in_channels
# 定义组归一化层,指定组数量、通道数、稳定数和是否启用仿射变换
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
# 定义输入线性变换层,映射输入通道到内部维度
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. 定义变换器块
self.transformer_blocks = nn.ModuleList(
[
# 创建基本变换器块,并传入相应参数
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
# 重复创建多个变换器块,数量由 num_layers 决定
for d in range(num_layers)
]
)
# 定义输出线性变换层,将内部维度映射回输入通道数
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(
# 定义前向传播方法的输入参数
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
# 定义一个用于视频类数据的 Transformer 模型
class TransformerSpatioTemporalModel(nn.Module):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
out_channels (`int`, *optional*):
The number of channels in the output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
"""
# 初始化函数,设置模型参数
def __init__(
self,
num_attention_heads: int = 16, # 多头注意力机制的头数,默认为16
attention_head_dim: int = 88, # 每个头的通道数,默认为88
in_channels: int = 320, # 输入和输出的通道数,默认为320
out_channels: Optional[int] = None, # 输出的通道数,如果输入是连续的则需要指定
num_layers: int = 1, # Transformer 块的层数,默认为1
cross_attention_dim: Optional[int] = None, # 编码器隐藏状态的维度
):
super().__init__() # 调用父类的初始化方法
self.num_attention_heads = num_attention_heads # 保存多头注意力的头数
self.attention_head_dim = attention_head_dim # 保存每个头的通道数
inner_dim = num_attention_heads * attention_head_dim # 计算内部维度
self.inner_dim = inner_dim # 保存内部维度
# 2. 定义输入层
self.in_channels = in_channels # 保存输入通道数
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) # 定义分组归一化层
self.proj_in = nn.Linear(in_channels, inner_dim) # 定义输入的线性变换
# 3. 定义 Transformer 块
self.transformer_blocks = nn.ModuleList( # 创建 Transformer 块的模块列表
[
BasicTransformerBlock( # 实例化基本的 Transformer 块
inner_dim, # 传入内部维度
num_attention_heads, # 传入多头注意力的头数
attention_head_dim, # 传入每个头的通道数
cross_attention_dim=cross_attention_dim, # 传入交叉注意力维度
)
for d in range(num_layers) # 根据层数创建多个块
]
)
time_mix_inner_dim = inner_dim # 定义时间混合内部维度
self.temporal_transformer_blocks = nn.ModuleList( # 创建时间 Transformer 块的模块列表
[
TemporalBasicTransformerBlock( # 实例化时间基本 Transformer 块
inner_dim, # 传入内部维度
time_mix_inner_dim, # 传入时间混合内部维度
num_attention_heads, # 传入多头注意力的头数
attention_head_dim, # 传入每个头的通道数
cross_attention_dim=cross_attention_dim, # 传入交叉注意力维度
)
for _ in range(num_layers) # 根据层数创建多个块
]
)
time_embed_dim = in_channels * 4 # 定义时间嵌入的维度
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) # 创建时间步嵌入
self.time_proj = Timesteps(in_channels, True, 0) # 定义时间投影
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") # 定义时间混合器
# 4. 定义输出层
self.out_channels = in_channels if out_channels is None else out_channels # 确定输出通道数
# TODO: should use out_channels for continuous projections
self.proj_out = nn.Linear(inner_dim, in_channels) # 定义输出的线性变换
self.gradient_checkpointing = False # 是否启用梯度检查点,默认为False
# 定义一个名为 forward 的方法
def forward(
# 输入参数 hidden_states,类型为 torch.Tensor
self,
hidden_states: torch.Tensor,
# 可选输入参数 encoder_hidden_states,类型为 torch.Tensor,默认为 None
encoder_hidden_states: Optional[torch.Tensor] = None,
# 可选输入参数 image_only_indicator,类型为 torch.Tensor,默认为 None
image_only_indicator: Optional[torch.Tensor] = None,
# 可选输入参数 return_dict,类型为 bool,默认为 True
return_dict: bool = True,
# 从 utils 模块导入判断是否可用的函数
from ...utils import is_torch_available
# 检查是否可用 Torch 库
if is_torch_available():
# 从当前包导入 2D AuraFlow Transformer 模型
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
# 从当前包导入 3D CogVideoX Transformer 模型
from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
# 从当前包导入 2D DiT Transformer 模型
from .dit_transformer_2d import DiTTransformer2DModel
# 从当前包导入 2D Dual Transformer 模型
from .dual_transformer_2d import DualTransformer2DModel
# 从当前包导入 2D Hunyuan DiT Transformer 模型
from .hunyuan_transformer_2d import HunyuanDiT2DModel
# 从当前包导入 3D Latte Transformer 模型
from .latte_transformer_3d import LatteTransformer3DModel
# 从当前包导入 2D Lumina Next DiT Transformer 模型
from .lumina_nextdit2d import LuminaNextDiT2DModel
# 从当前包导入 2D PixArt Transformer 模型
from .pixart_transformer_2d import PixArtTransformer2DModel
# 从当前包导入 Prior Transformer 模型
from .prior_transformer import PriorTransformer
# 从当前包导入 Stable Audio DiT Transformer 模型
from .stable_audio_transformer import StableAudioDiTModel
# 从当前包导入 T5 Film Decoder 模型
from .t5_film_transformer import T5FilmDecoder
# 从当前包导入 2D Transformer 模型
from .transformer_2d import Transformer2DModel
# 从当前包导入 2D Flux Transformer 模型
from .transformer_flux import FluxTransformer2DModel
# 从当前包导入 2D SD3 Transformer 模型
from .transformer_sd3 import SD3Transformer2DModel
# 从当前包导入 Temporal Transformer 模型
from .transformer_temporal import TransformerTemporalModel
.\diffusers\models\unets\unet_1d.py
# 版权声明,声明此文件归 HuggingFace 团队所有,所有权利保留。
#
# 根据 Apache 许可证 2.0 版(“许可证”)进行许可;
# 除非遵循许可证,否则您不得使用此文件。
# 您可以在以下地址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面协议另有约定,软件按“原样”分发,
# 不提供任何明示或暗示的保证或条件。
# 请参阅许可证以了解有关权限的具体语言和
# 限制条款。
#
# 从 dataclasses 模块导入 dataclass 装饰器,用于创建数据类
from dataclasses import dataclass
# 从 typing 模块导入 Optional, Tuple, Union 类型提示
from typing import Optional, Tuple, Union
# 导入 PyTorch 库,用于深度学习
import torch
# 导入 PyTorch 的神经网络模块
import torch.nn as nn
# 从配置工具模块导入 ConfigMixin 和 register_to_config 以实现配置功能
from ...configuration_utils import ConfigMixin, register_to_config
# 从 utils 模块导入 BaseOutput 类,作为输出基类
from ...utils import BaseOutput
# 从 embeddings 模块导入 GaussianFourierProjection, TimestepEmbedding, Timesteps,用于处理嵌入
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
# 从 modeling_utils 模块导入 ModelMixin,作为模型混合基类
from ..modeling_utils import ModelMixin
# 从 unet_1d_blocks 模块导入用于构建 UNet 1D 的各个块
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
# 定义 UNet1DOutput 数据类,继承自 BaseOutput
@dataclass
class UNet1DOutput(BaseOutput):
"""
[`UNet1DModel`] 的输出。
参数:
sample (`torch.Tensor`,形状为 `(batch_size, num_channels, sample_size)`):
模型最后一层输出的隐藏状态。
"""
# 模型输出的样本张量
sample: torch.Tensor
# 定义 UNet1DModel 类,继承自 ModelMixin 和 ConfigMixin
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
1D UNet 模型,接收噪声样本和时间步并返回形状的输出样本。
该模型继承自 [`ModelMixin`]。请查看超类文档,以获取其实现的所有模型的通用方法(例如下载或保存)。
# 参数说明部分,列出可选参数及其默认值
Parameters:
# 默认样本长度,可在运行时适应的整型参数
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
# 输入样本的通道数,默认值为2
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
# 输出的通道数,默认值为2
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
# 附加的输入通道数,默认值为0
extra_in_channels (`int`, *optional*, defaults to 0):
# 首个下采样块输入中额外通道的数量,用于处理输入数据通道多于模型设计时的情况
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
# 时间嵌入类型,默认值为"fourier"
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
# 傅里叶时间嵌入的频率偏移,默认值为0.0
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
# 是否将正弦函数翻转为余弦函数,默认值为False
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
# 对于傅里叶时间嵌入,是否翻转sin为cos。
Whether to flip sin to cos for Fourier time embedding.
# 下采样块类型的元组,默认值为指定的块类型
down_block_types (`Tuple[str]`, *optional*, defaults to ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")):
# 下采样块类型的元组。
Tuple of downsample block types.
# 上采样块类型的元组,默认值为指定的块类型
up_block_types (`Tuple[str]`, *optional*, defaults to ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")):
# 上采样块类型的元组。
Tuple of upsample block types.
# 块输出通道的元组,默认值为(32, 32, 64)
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
# 块输出通道的元组。
Tuple of block output channels.
# UNet中间块的类型,默认值为"UNetMidBlock1D"
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
# UNet的可选输出处理块,默认值为None
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
# UNet块中的可选激活函数,默认值为None
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
# 归一化的组数,默认值为8
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
# 每个块的层数,默认值为1
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
# 每个块是否下采样,默认值为False
downsample_each_block (`int`, *optional*, defaults to False):
# 用于不进行上采样的UNet的实验特性。
Experimental feature for using a UNet without upsampling.
"""
# 装饰器,注册到配置中
@register_to_config
def __init__(
# 初始化方法,定义各种参数及其默认值
self,
# 默认样本大小为65536
sample_size: int = 65536,
# 可选样本速率,默认为None
sample_rate: Optional[int] = None,
# 输入通道数,默认值为2
in_channels: int = 2,
# 输出通道数,默认值为2
out_channels: int = 2,
# 附加输入通道数,默认值为0
extra_in_channels: int = 0,
# 时间嵌入类型,默认值为"fourier"
time_embedding_type: str = "fourier",
# 是否翻转正弦为余弦,默认值为True
flip_sin_to_cos: bool = True,
# 是否使用时间步长嵌入,默认值为False
use_timestep_embedding: bool = False,
# 傅里叶时间嵌入的频率偏移,默认值为0.0
freq_shift: float = 0.0,
# 下采样块类型的元组,默认值为指定的块类型
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
# 上采样块类型的元组,默认值为指定的块类型
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
# 中间块的类型,默认值为"UNetMidBlock1D"
mid_block_type: Tuple[str] = "UNetMidBlock1D",
# 可选的输出处理块,默认值为None
out_block_type: str = None,
# 块输出通道的元组,默认值为(32, 32, 64)
block_out_channels: Tuple[int] = (32, 32, 64),
# 可选激活函数,默认值为None
act_fn: str = None,
# 归一化的组数,默认值为8
norm_num_groups: int = 8,
# 每个块的层数,默认值为1
layers_per_block: int = 1,
# 是否下采样,默认值为False
downsample_each_block: bool = False,
# 定义 UNet1DModel 的前向传播方法
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
UNet1DModel 的前向传播方法。
参数:
sample (`torch.Tensor`):
噪声输入张量,形状为 `(batch_size, num_channels, sample_size)`。
timestep (`torch.Tensor` 或 `float` 或 `int`): 用于去噪输入的时间步数。
return_dict (`bool`, *可选*, 默认为 `True`):
是否返回 [`~models.unets.unet_1d.UNet1DOutput`] 而不是普通元组。
返回:
[`~models.unets.unet_1d.UNet1DOutput`] 或 `tuple`:
如果 `return_dict` 为 True,则返回 [`~models.unets.unet_1d.UNet1DOutput`],否则返回一个元组,
其中第一个元素是样本张量。
"""
# 1. 时间处理
timesteps = timestep
# 检查 timesteps 是否为张量,如果不是,则将其转换为张量
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
# 如果 timesteps 是张量且没有形状,则扩展其维度
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# 对时间步进行嵌入处理
timestep_embed = self.time_proj(timesteps)
# 如果使用时间步嵌入,则通过 MLP 进行处理
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
# 否则,调整嵌入的形状
else:
timestep_embed = timestep_embed[..., None]
# 重复嵌入以匹配样本的大小
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
# 广播嵌入以匹配样本的形状
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
# 2. 向下采样
down_block_res_samples = ()
# 遍历下采样块
for downsample_block in self.down_blocks:
# 在下采样块中处理样本和时间嵌入
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
# 收集残差样本
down_block_res_samples += res_samples
# 3. 中间块处理
if self.mid_block:
# 如果存在中间块,则进行处理
sample = self.mid_block(sample, timestep_embed)
# 4. 向上采样
for i, upsample_block in enumerate(self.up_blocks):
# 获取最后一个残差样本
res_samples = down_block_res_samples[-1:]
# 移除最后一个残差样本
down_block_res_samples = down_block_res_samples[:-1]
# 在上采样块中处理样本和时间嵌入
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. 后处理
if self.out_block:
# 如果存在输出块,则进行处理
sample = self.out_block(sample, timestep_embed)
# 如果不需要返回字典,则返回样本元组
if not return_dict:
return (sample,)
# 返回 UNet1DOutput 对象
return UNet1DOutput(sample=sample)
.\diffusers\models\unets\unet_1d_blocks.py
# 版权声明,指定版权所有者及其保留权利
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# 根据 Apache 许可证,版本 2.0(“许可证”)进行许可;
# 除非遵守许可证,否则不得使用此文件。
# 可以在以下网址获取许可证副本:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# 除非适用法律或书面同意,否则根据许可证分发的软件是按“现状”基础提供的,
# 不提供任何形式的明示或暗示的担保或条件。
# 请参阅许可证以获取管理权限和限制的具体条款。
import math # 导入数学库,以便使用数学函数
from typing import Optional, Tuple, Union # 导入类型注解工具
import torch # 导入 PyTorch 库
import torch.nn.functional as F # 导入 PyTorch 的函数式神经网络接口
from torch import nn # 从 PyTorch 导入神经网络模块
from ..activations import get_activation # 导入自定义激活函数获取工具
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims # 导入自定义的 ResNet 组件
class DownResnetBlock1D(nn.Module): # 定义一个一维下采样的 ResNet 模块
def __init__( # 初始化方法,定义模块的参数
self,
in_channels: int, # 输入通道数
out_channels: Optional[int] = None, # 输出通道数(可选)
num_layers: int = 1, # 残差层数,默认为1
conv_shortcut: bool = False, # 是否使用卷积快捷连接
temb_channels: int = 32, # 时间嵌入通道数
groups: int = 32, # 组数
groups_out: Optional[int] = None, # 输出组数(可选)
non_linearity: Optional[str] = None, # 非线性激活函数(可选)
time_embedding_norm: str = "default", # 时间嵌入的归一化方式
output_scale_factor: float = 1.0, # 输出缩放因子
add_downsample: bool = True, # 是否添加下采样层
):
super().__init__() # 调用父类初始化方法
self.in_channels = in_channels # 设置输入通道数
out_channels = in_channels if out_channels is None else out_channels # 如果未指定输出通道数,则设置为输入通道数
self.out_channels = out_channels # 设置输出通道数
self.use_conv_shortcut = conv_shortcut # 保存是否使用卷积快捷连接的标志
self.time_embedding_norm = time_embedding_norm # 设置时间嵌入的归一化方式
self.add_downsample = add_downsample # 保存是否添加下采样层的标志
self.output_scale_factor = output_scale_factor # 设置输出缩放因子
if groups_out is None: # 如果未指定输出组数
groups_out = groups # 设置输出组数为输入组数
# 始终至少有一个残差块
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] # 创建第一个残差块
for _ in range(num_layers): # 根据指定的层数添加残差块
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) # 添加后续的残差块
self.resnets = nn.ModuleList(resnets) # 将残差块列表转换为 PyTorch 模块列表
if non_linearity is None: # 如果未指定非线性激活函数
self.nonlinearity = None # 设置为 None
else:
self.nonlinearity = get_activation(non_linearity) # 获取指定的激活函数
self.downsample = None # 初始化下采样层为 None
if add_downsample: # 如果需要添加下采样层
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) # 创建下采样层
# 定义前向传播函数,接收隐藏状态和可选的时间嵌入,返回处理后的隐藏状态和输出状态
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 初始化一个空元组用于存储输出状态
output_states = ()
# 使用第一个残差网络处理输入的隐藏状态和时间嵌入
hidden_states = self.resnets[0](hidden_states, temb)
# 遍历后续的残差网络,逐个处理隐藏状态
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
# 将当前的隐藏状态添加到输出状态元组中
output_states += (hidden_states,)
# 如果非线性激活函数存在,则应用于隐藏状态
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
# 如果下采样层存在,则对隐藏状态进行下采样处理
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
# 返回处理后的隐藏状态和输出状态
return hidden_states, output_states
# 定义一个一维的上采样残差块类,继承自 nn.Module
class UpResnetBlock1D(nn.Module):
# 初始化方法,定义输入输出通道、层数等参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: Optional[int] = None, # 输出通道数,默认为 None
num_layers: int = 1, # 残差层数,默认为 1
temb_channels: int = 32, # 时间嵌入通道数
groups: int = 32, # 分组数
groups_out: Optional[int] = None, # 输出分组数,默认为 None
non_linearity: Optional[str] = None, # 非线性激活函数,默认为 None
time_embedding_norm: str = "default", # 时间嵌入归一化方式,默认为 "default"
output_scale_factor: float = 1.0, # 输出缩放因子
add_upsample: bool = True, # 是否添加上采样层,默认为 True
):
# 调用父类初始化方法
super().__init__()
self.in_channels = in_channels # 保存输入通道数
# 如果输出通道数为 None,则设置为输入通道数
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels # 保存输出通道数
self.time_embedding_norm = time_embedding_norm # 保存时间嵌入归一化方式
self.add_upsample = add_upsample # 保存是否添加上采样层
self.output_scale_factor = output_scale_factor # 保存输出缩放因子
# 如果输出分组数为 None,则设置为输入分组数
if groups_out is None:
groups_out = groups
# 初始化至少一个残差块
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
# 根据层数添加残差块
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
# 将残差块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根据非线性激活函数的设置,初始化激活函数
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
# 初始化上采样层为 None
self.upsample = None
# 如果需要添加上采样层,则初始化它
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
# 前向传播方法
def forward(
self,
hidden_states: torch.Tensor, # 输入的隐藏状态
res_hidden_states_tuple: Optional[Tuple[torch.Tensor, ...]] = None, # 残差隐藏状态元组,默认为 None
temb: Optional[torch.Tensor] = None, # 时间嵌入,默认为 None
) -> torch.Tensor:
# 如果有残差隐藏状态,则将其与当前隐藏状态拼接
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1] # 取最后一个残差状态
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) # 拼接操作
# 通过第一个残差块处理隐藏状态
hidden_states = self.resnets[0](hidden_states, temb)
# 依次通过后续的残差块处理隐藏状态
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
# 如果有非线性激活函数,则应用它
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
# 如果有上采样层,则应用它
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
# 返回最终的隐藏状态
return hidden_states
# 定义一个值函数中间块类,继承自 nn.Module
class ValueFunctionMidBlock1D(nn.Module):
# 初始化方法,定义输入输出通道和嵌入维度
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
# 调用父类初始化方法
super().__init__()
self.in_channels = in_channels # 保存输入通道数
self.out_channels = out_channels # 保存输出通道数
self.embed_dim = embed_dim # 保存嵌入维度
# 初始化第一个残差块
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
# 初始化第一个下采样层
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
# 初始化第二个残差块
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
# 初始化第二个下采样层
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
# 定义前向传播函数,接受输入张量和可选的嵌入张量,返回输出张量
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 将输入张量 x 通过第一个残差块处理,可能使用嵌入张量 temb
x = self.res1(x, temb)
# 将处理后的张量 x 通过第一个下采样层进行下采样
x = self.down1(x)
# 将下采样后的张量 x 通过第二个残差块处理,可能使用嵌入张量 temb
x = self.res2(x, temb)
# 将处理后的张量 x 通过第二个下采样层进行下采样
x = self.down2(x)
# 返回最终处理后的张量 x
return x
# 定义一个中间分辨率的时间块类,继承自 nn.Module
class MidResTemporalBlock1D(nn.Module):
# 初始化方法,定义该类的参数
def __init__(
self,
in_channels: int, # 输入通道数
out_channels: int, # 输出通道数
embed_dim: int, # 嵌入维度
num_layers: int = 1, # 层数,默认值为 1
add_downsample: bool = False, # 是否添加下采样,默认为 False
add_upsample: bool = False, # 是否添加上采样,默认为 False
non_linearity: Optional[str] = None, # 非线性激活函数的类型,默认为 None
):
# 调用父类构造函数
super().__init__()
# 设置输入通道数
self.in_channels = in_channels
# 设置输出通道数
self.out_channels = out_channels
# 设置是否添加下采样
self.add_downsample = add_downsample
# 至少会有一个残差网络
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
# 根据层数添加残差网络层
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
# 将残差网络层列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 根据是否提供非线性激活函数初始化相应属性
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
# 初始化上采样层为 None
self.upsample = None
# 如果添加上采样,则创建上采样层
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv=True)
# 初始化下采样层为 None
self.downsample = None
# 如果添加下采样,则创建下采样层
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True)
# 如果同时添加了上采样和下采样,抛出错误
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
# 定义前向传播方法
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
# 通过第一个残差网络处理隐藏状态
hidden_states = self.resnets[0](hidden_states, temb)
# 遍历其余的残差网络进行处理
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
# 如果有上采样层,则执行上采样
if self.upsample:
hidden_states = self.upsample(hidden_states)
# 如果有下采样层,则执行下采样
if self.downsample:
self.downsample = self.downsample(hidden_states)
# 返回处理后的隐藏状态
return hidden_states
# 定义输出卷积块类,继承自 nn.Module
class OutConv1DBlock(nn.Module):
# 初始化方法,定义该类的参数
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
# 调用父类构造函数
super().__init__()
# 创建第一层 1D 卷积,kernel_size 为 5,padding 为 2
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
# 创建 GroupNorm 层,指定组数和嵌入维度
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
# 根据激活函数名称获取激活函数
self.final_conv1d_act = get_activation(act_fn)
# 创建第二层 1D 卷积,kernel_size 为 1
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
# 定义前向传播方法
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 通过第一层卷积处理隐藏状态
hidden_states = self.final_conv1d_1(hidden_states)
# 调整维度
hidden_states = rearrange_dims(hidden_states)
# 通过 GroupNorm 层处理
hidden_states = self.final_conv1d_gn(hidden_states)
# 再次调整维度
hidden_states = rearrange_dims(hidden_states)
# 通过激活函数处理
hidden_states = self.final_conv1d_act(hidden_states)
# 通过第二层卷积处理
hidden_states = self.final_conv1d_2(hidden_states)
# 返回处理后的隐藏状态
return hidden_states
# 定义输出值函数块类,继承自 nn.Module
class OutValueFunctionBlock(nn.Module):
# 初始化方法,设置全连接层维度和激活函数
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
# 调用父类的初始化方法
super().__init__()
# 创建一个模块列表,包含线性层和激活函数
self.final_block = nn.ModuleList(
[
# 第一个线性层,将输入维度从 fc_dim + embed_dim 转换到 fc_dim // 2
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
# 获取指定的激活函数
get_activation(act_fn),
# 第二个线性层,将输入维度从 fc_dim // 2 转换到 1
nn.Linear(fc_dim // 2, 1),
]
)
# 前向传播方法,接受隐藏状态和额外的嵌入信息
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
# 重塑隐藏状态,使其成为二维张量
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
# 将重塑后的隐藏状态与额外的嵌入信息在最后一个维度上连接
hidden_states = torch.cat((hidden_states, temb), dim=-1)
# 遍历 final_block 中的每一层,逐层处理隐藏状态
for layer in self.final_block:
hidden_states = layer(hidden_states)
# 返回最终的隐藏状态
return hidden_states
# 定义包含不同内核函数系数的字典
_kernels = {
# 线性内核的系数
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
# 三次插值内核的系数
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
# Lanczos3 内核的系数
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
# 定义一维下采样的模块
class Downsample1d(nn.Module):
# 初始化方法,接受内核类型和填充模式
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
# 保存填充模式
self.pad_mode = pad_mode
# 根据内核类型创建一维内核的张量
kernel_1d = torch.tensor(_kernels[kernel])
# 计算填充大小
self.pad = kernel_1d.shape[0] // 2 - 1
# 注册内核张量为缓冲区
self.register_buffer("kernel", kernel_1d)
# 前向传播方法,处理输入的张量
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# 对输入张量进行填充
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
# 创建权重张量,用于卷积操作
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
# 生成索引,用于选择权重
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
# 扩展内核张量以适应权重张量的形状
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
# 将内核填充到权重张量的对角线上
weight[indices, indices] = kernel
# 使用一维卷积对输入张量进行处理并返回结果
return F.conv1d(hidden_states, weight, stride=2)
# 定义一维上采样的模块
class Upsample1d(nn.Module):
# 初始化方法,接受内核类型和填充模式
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
# 保存填充模式
self.pad_mode = pad_mode
# 根据内核类型创建一维内核的张量,并乘以2以扩展作用
kernel_1d = torch.tensor(_kernels[kernel]) * 2
# 计算填充大小
self.pad = kernel_1d.shape[0] // 2 - 1
# 注册内核张量为缓冲区
self.register_buffer("kernel", kernel_1d)
# 前向传播方法,处理输入的张量
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 对输入张量进行填充
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
# 创建权重张量,用于反卷积操作
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
# 生成索引,用于选择权重
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
# 扩展内核张量以适应权重张量的形状
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
# 将内核填充到权重张量的对角线上
weight[indices, indices] = kernel
# 使用一维反卷积对输入张量进行处理并返回结果
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
# 定义一维自注意力模块
class SelfAttention1d(nn.Module):
# 初始化方法,接受输入通道数、头数和丢弃率
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
# 保存输入通道数
self.channels = in_channels
# 创建分组归一化层
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
# 保存头数
self.num_heads = n_head
# 定义查询、键、值的线性变换
self.query = nn.Linear(self.channels, self.channels)
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
# 定义注意力投影的线性变换
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
# 创建丢弃层
self.dropout = nn.Dropout(dropout_rate, inplace=True)
# 将输入投影张量进行转置以适应多头注意力机制
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
# 获取新的形状,将最后一个维度分割为头数和每个头的维度
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# 移动头的位置,调整形状从 (B, T, H * D) 变为 (B, T, H, D),再变为 (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
# 返回调整后的张量
return new_projection
# 前向传播函数,处理输入的隐藏状态
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# 保存输入的隐藏状态以进行残差连接
residual = hidden_states
# 获取批量大小、通道维度和序列长度
batch, channel_dim, seq = hidden_states.shape
# 应用分组归一化到隐藏状态
hidden_states = self.group_norm(hidden_states)
# 转置隐藏状态的维度以便后续处理
hidden_states = hidden_states.transpose(1, 2)
# 通过查询、键和值的线性层投影隐藏状态
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# 转置查询、键和值的投影以适应注意力机制
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# 计算缩放因子以防止梯度消失
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
# 计算注意力得分,进行矩阵乘法
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
# 计算注意力概率分布
attention_probs = torch.softmax(attention_scores, dim=-1)
# 计算注意力输出
hidden_states = torch.matmul(attention_probs, value_states)
# 调整输出的维度顺序
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
# 获取新的隐藏状态形状以匹配通道数
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
# 重塑隐藏状态的形状
hidden_states = hidden_states.view(new_hidden_states_shape)
# 计算下一步的隐藏状态
hidden_states = self.proj_attn(hidden_states)
# 再次转置隐藏状态的维度
hidden_states = hidden_states.transpose(1, 2)
# 应用 dropout 正则化
hidden_states = self.dropout(hidden_states)
# 将最终输出与残差相加
output = hidden_states + residual
# 返回最终输出
return output
# 定义残差卷积块类,继承自 nn.Module
class ResConvBlock(nn.Module):
# 初始化函数,定义输入、中间和输出通道,以及是否为最后一层的标志
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
# 调用父类初始化方法
super().__init__()
# 设置是否为最后一层的标志
self.is_last = is_last
# 检查输入通道和输出通道是否相同,决定是否需要卷积跳跃连接
self.has_conv_skip = in_channels != out_channels
# 如果需要卷积跳跃连接,则定义 1D 卷积层
if self.has_conv_skip:
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
# 定义第一个卷积层,卷积核大小为 5,使用填充保持尺寸
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
# 定义第一个组归一化层
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
# 定义第一个 GELU 激活函数
self.gelu_1 = nn.GELU()
# 定义第二个卷积层,卷积核大小为 5,使用填充保持尺寸
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
# 如果不是最后一层,则定义第二个组归一化层和激活函数
if not self.is_last:
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
# 前向传播函数
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# 如果有卷积跳跃连接,则对输入进行跳跃连接处理,否则直接使用输入
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
# 依次通过第一个卷积层、组归一化层和激活函数处理隐藏状态
hidden_states = self.conv_1(hidden_states)
hidden_states = self.group_norm_1(hidden_states)
hidden_states = self.gelu_1(hidden_states)
# 通过第二个卷积层处理隐藏状态
hidden_states = self.conv_2(hidden_states)
# 如果不是最后一层,则继续通过第二个组归一化层和激活函数处理
if not self.is_last:
hidden_states = self.group_norm_2(hidden_states)
hidden_states = self.gelu_2(hidden_states)
# 将处理后的隐藏状态与残差相加,得到最终输出
output = hidden_states + residual
# 返回最终输出
return output
# 定义 UNet 中间块类,继承自 nn.Module
class UNetMidBlock1D(nn.Module):
# 初始化函数,定义中间通道、输入通道和可选输出通道
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
# 调用父类初始化方法
super().__init__()
# 如果未指定输出通道,则将输出通道设为输入通道
out_channels = in_channels if out_channels is None else out_channels
# 定义下采样模块,使用立方插值
self.down = Downsample1d("cubic")
# 创建包含多个残差卷积块的列表
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
# 创建包含自注意力模块的列表
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
# 定义上采样模块,使用立方插值
self.up = Upsample1d(kernel="cubic")
# 将自注意力模块和残差卷积块转换为模块列表
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
# 定义前向传播函数,接受隐藏状态和可选的时间嵌入
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 将隐藏状态通过下采样模块处理
hidden_states = self.down(hidden_states)
# 遍历注意力层和残差网络,对隐藏状态进行处理
for attn, resnet in zip(self.attentions, self.resnets):
# 先通过残差网络处理隐藏状态
hidden_states = resnet(hidden_states)
# 然后通过注意力层处理隐藏状态
hidden_states = attn(hidden_states)
# 将隐藏状态通过上采样模块处理
hidden_states = self.up(hidden_states)
# 返回最终的隐藏状态
return hidden_states
# 定义一个一维的注意力下采样块,继承自 nn.Module
class AttnDownBlock1D(nn.Module):
# 初始化方法,定义输入和输出通道数及中间通道数
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
# 调用父类的初始化方法
super().__init__()
# 如果中间通道数为 None,则设置为输出通道数
mid_channels = out_channels if mid_channels is None else mid_channels
# 创建一个下采样模块,采用三次插值法
self.down = Downsample1d("cubic")
# 定义残差卷积块的列表
resnets = [
# 第一个残差卷积块,输入通道为 in_channels,输出和中间通道为 mid_channels
ResConvBlock(in_channels, mid_channels, mid_channels),
# 第二个残差卷积块,输入和输出通道为 mid_channels
ResConvBlock(mid_channels, mid_channels, mid_channels),
# 第三个残差卷积块,输入通道为 mid_channels,输出通道为 out_channels
ResConvBlock(mid_channels, mid_channels, out_channels),
]
# 定义注意力模块的列表
attentions = [
# 第一个自注意力模块,输入通道为 mid_channels,输出通道为 mid_channels // 32
SelfAttention1d(mid_channels, mid_channels // 32),
# 第二个自注意力模块,输入通道为 mid_channels,输出通道为 mid_channels // 32
SelfAttention1d(mid_channels, mid_channels // 32),
# 第三个自注意力模块,输入通道为 out_channels,输出通道为 out_channels // 32
SelfAttention1d(out_channels, out_channels // 32),
]
# 将注意力模块列表封装成 nn.ModuleList
self.attentions = nn.ModuleList(attentions)
# 将残差卷积块列表封装成 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 前向传播方法,定义输入和输出
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 对隐藏状态进行下采样
hidden_states = self.down(hidden_states)
# 遍历残差卷积块和注意力模块
for resnet, attn in zip(self.resnets, self.attentions):
# 通过残差卷积块处理隐藏状态
hidden_states = resnet(hidden_states)
# 通过注意力模块处理隐藏状态
hidden_states = attn(hidden_states)
# 返回处理后的隐藏状态和一个包含隐藏状态的元组
return hidden_states, (hidden_states,)
# 定义一个一维的下采样块,继承自 nn.Module
class DownBlock1D(nn.Module):
# 初始化方法,定义输入和输出通道数及中间通道数
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
# 调用父类的初始化方法
super().__init__()
# 如果中间通道数为 None,则设置为输出通道数
mid_channels = out_channels if mid_channels is None else mid_channels
# 创建一个下采样模块,采用三次插值法
self.down = Downsample1d("cubic")
# 定义残差卷积块的列表
resnets = [
# 第一个残差卷积块,输入通道为 in_channels,输出和中间通道为 mid_channels
ResConvBlock(in_channels, mid_channels, mid_channels),
# 第二个残差卷积块,输入和输出通道为 mid_channels
ResConvBlock(mid_channels, mid_channels, mid_channels),
# 第三个残差卷积块,输入通道为 mid_channels,输出通道为 out_channels
ResConvBlock(mid_channels, mid_channels, out_channels),
]
# 将残差卷积块列表封装成 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 前向传播方法,定义输入和输出
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 对隐藏状态进行下采样
hidden_states = self.down(hidden_states)
# 遍历残差卷积块
for resnet in self.resnets:
# 通过残差卷积块处理隐藏状态
hidden_states = resnet(hidden_states)
# 返回处理后的隐藏状态和一个包含隐藏状态的元组
return hidden_states, (hidden_states,)
# 定义一个没有跳过连接的一维下采样块,继承自 nn.Module
class DownBlock1DNoSkip(nn.Module):
# 初始化方法,定义输入和输出通道数及中间通道数
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
# 调用父类的初始化方法
super().__init__()
# 如果中间通道数为 None,则设置为输出通道数
mid_channels = out_channels if mid_channels is None else mid_channels
# 定义残差卷积块的列表
resnets = [
# 第一个残差卷积块,输入通道为 in_channels,输出和中间通道为 mid_channels
ResConvBlock(in_channels, mid_channels, mid_channels),
# 第二个残差卷积块,输入和输出通道为 mid_channels
ResConvBlock(mid_channels, mid_channels, mid_channels),
# 第三个残差卷积块,输入通道为 mid_channels,输出通道为 out_channels
ResConvBlock(mid_channels, mid_channels, out_channels),
]
# 将残差卷积块列表封装成 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 前向传播方法,定义输入和输出
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
# 将隐藏状态和 temb 在通道维度上拼接
hidden_states = torch.cat([hidden_states, temb], dim=1)
# 遍历残差卷积块
for resnet in self.resnets:
# 通过残差卷积块处理隐藏状态
hidden_states = resnet(hidden_states)
# 返回处理后的隐藏状态和一个包含隐藏状态的元组
return hidden_states, (hidden_states,)
# 定义一个一维的注意力上采样块,继承自 nn.Module
class AttnUpBlock1D(nn.Module):
# 初始化方法,用于创建类的实例,设置输入、输出和中间通道数
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
# 调用父类初始化方法
super().__init__()
# 如果中间通道数未提供,则将其设置为输出通道数
mid_channels = out_channels if mid_channels is None else mid_channels
# 创建残差卷积块列表,配置输入、中间和输出通道数
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
# 创建自注意力层列表,配置通道数
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
# 将注意力层添加到模块列表中,以便在前向传播中使用
self.attentions = nn.ModuleList(attentions)
# 将残差卷积块添加到模块列表中,以便在前向传播中使用
self.resnets = nn.ModuleList(resnets)
# 初始化上采样层,使用立方插值
self.up = Upsample1d(kernel="cubic")
# 前向传播方法,定义输入张量和输出张量之间的计算
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 获取残差隐藏状态元组中的最后一个状态
res_hidden_states = res_hidden_states_tuple[-1]
# 将隐藏状态与残差隐藏状态在通道维度上拼接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 遍历残差块和注意力层并依次处理隐藏状态
for resnet, attn in zip(self.resnets, self.attentions):
# 使用残差块处理隐藏状态
hidden_states = resnet(hidden_states)
# 使用注意力层处理隐藏状态
hidden_states = attn(hidden_states)
# 对处理后的隐藏状态进行上采样
hidden_states = self.up(hidden_states)
# 返回最终的隐藏状态
return hidden_states
# 定义一维上采样块的类,继承自 nn.Module
class UpBlock1D(nn.Module):
# 初始化方法,接收输入通道、输出通道和中间通道的参数
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
# 调用父类初始化方法
super().__init__()
# 如果中间通道为 None,则将其设置为输入通道数
mid_channels = in_channels if mid_channels is None else mid_channels
# 定义包含三个残差卷积块的列表
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels), # 第一个残差块,输入通道是输入通道的两倍
ResConvBlock(mid_channels, mid_channels, mid_channels), # 第二个残差块,输入输出通道均为中间通道
ResConvBlock(mid_channels, mid_channels, out_channels), # 第三个残差块,输出通道为目标输出通道
]
# 将残差块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 定义一维上采样层,使用立方插值核
self.up = Upsample1d(kernel="cubic")
# 前向传播方法,接收隐藏状态和残差隐藏状态元组
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 获取最后一个残差隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 将隐藏状态和残差隐藏状态在通道维度上连接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 遍历每个残差块进行前向传播
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
# 对隐藏状态进行上采样
hidden_states = self.up(hidden_states)
# 返回上采样后的隐藏状态
return hidden_states
# 定义不使用跳过连接的一维上采样块的类,继承自 nn.Module
class UpBlock1DNoSkip(nn.Module):
# 初始化方法,接收输入通道、输出通道和中间通道的参数
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
# 调用父类初始化方法
super().__init__()
# 如果中间通道为 None,则将其设置为输入通道数
mid_channels = in_channels if mid_channels is None else mid_channels
# 定义包含三个残差卷积块的列表
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels), # 第一个残差块,输入通道是输入通道的两倍
ResConvBlock(mid_channels, mid_channels, mid_channels), # 第二个残差块,输入输出通道均为中间通道
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), # 第三个残差块,输出通道为目标输出通道,标记为最后一个块
]
# 将残差块列表转换为 nn.ModuleList
self.resnets = nn.ModuleList(resnets)
# 前向传播方法,接收隐藏状态和残差隐藏状态元组
def forward(
self,
hidden_states: torch.Tensor,
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 获取最后一个残差隐藏状态
res_hidden_states = res_hidden_states_tuple[-1]
# 将隐藏状态和残差隐藏状态在通道维度上连接
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
# 遍历每个残差块进行前向传播
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
# 返回处理后的隐藏状态
return hidden_states
# 定义各种下采样块的类型
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
# 定义各种中间块的类型
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
# 定义各种输出块的类型
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
# 定义各种上采样块的类型
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
# 根据类型获取对应的下采样块
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
# 如果指定的下采样块类型为 DownResnetBlock1D,返回相应的块
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
# 如果指定的下采样块类型为 DownBlock1D,返回相应的块
elif down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
# 检查下采样块类型是否为 "AttnDownBlock1D"
elif down_block_type == "AttnDownBlock1D":
# 返回一个 AttnDownBlock1D 对象,传入输出和输入通道参数
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
# 检查下采样块类型是否为 "DownBlock1DNoSkip"
elif down_block_type == "DownBlock1DNoSkip":
# 返回一个 DownBlock1DNoSkip 对象,传入输出和输入通道参数
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
# 如果下采样块类型不匹配,抛出一个值错误异常
raise ValueError(f"{down_block_type} does not exist.")
# 根据给定的上采样块类型,创建并返回对应的上采样块实例
def get_up_block(
# 上采样块类型
up_block_type: str,
# 网络层数
num_layers: int,
# 输入通道数
in_channels: int,
# 输出通道数
out_channels: int,
# 时间嵌入通道数
temb_channels: int,
# 是否添加上采样
add_upsample: bool
) -> UpBlockType:
# 检查上采样块类型是否为 "UpResnetBlock1D"
if up_block_type == "UpResnetBlock1D":
# 创建并返回 UpResnetBlock1D 实例
return UpResnetBlock1D(
# 设置输入通道数
in_channels=in_channels,
# 设置网络层数
num_layers=num_layers,
# 设置输出通道数
out_channels=out_channels,
# 设置时间嵌入通道数
temb_channels=temb_channels,
# 设置是否添加上采样
add_upsample=add_upsample,
)
# 检查上采样块类型是否为 "UpBlock1D"
elif up_block_type == "UpBlock1D":
# 创建并返回 UpBlock1D 实例
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
# 检查上采样块类型是否为 "AttnUpBlock1D"
elif up_block_type == "AttnUpBlock1D":
# 创建并返回 AttnUpBlock1D 实例
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
# 检查上采样块类型是否为 "UpBlock1DNoSkip"
elif up_block_type == "UpBlock1DNoSkip":
# 创建并返回 UpBlock1DNoSkip 实例
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
# 抛出错误,表示该上采样块类型不存在
raise ValueError(f"{up_block_type} does not exist.")
# 根据给定的中间块类型,创建并返回对应的中间块实例
def get_mid_block(
# 中间块类型
mid_block_type: str,
# 网络层数
num_layers: int,
# 输入通道数
in_channels: int,
# 中间通道数
mid_channels: int,
# 输出通道数
out_channels: int,
# 嵌入维度
embed_dim: int,
# 是否添加下采样
add_downsample: bool,
) -> MidBlockType:
# 检查中间块类型是否为 "MidResTemporalBlock1D"
if mid_block_type == "MidResTemporalBlock1D":
# 创建并返回 MidResTemporalBlock1D 实例
return MidResTemporalBlock1D(
# 设置网络层数
num_layers=num_layers,
# 设置输入通道数
in_channels=in_channels,
# 设置输出通道数
out_channels=out_channels,
# 设置嵌入维度
embed_dim=embed_dim,
# 设置是否添加下采样
add_downsample=add_downsample,
)
# 检查中间块类型是否为 "ValueFunctionMidBlock1D"
elif mid_block_type == "ValueFunctionMidBlock1D":
# 创建并返回 ValueFunctionMidBlock1D 实例
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
# 检查中间块类型是否为 "UNetMidBlock1D"
elif mid_block_type == "UNetMidBlock1D":
# 创建并返回 UNetMidBlock1D 实例
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
# 抛出错误,表示该中间块类型不存在
raise ValueError(f"{mid_block_type} does not exist.")
# 根据给定的输出块类型,创建并返回对应的输出块实例
def get_out_block(
# 输出块类型
*, out_block_type: str,
# 输出组数
num_groups_out: int,
# 嵌入维度
embed_dim: int,
# 输出通道数
out_channels: int,
# 激活函数类型
act_fn: str,
# 全连接层维度
fc_dim: int
) -> Optional[OutBlockType]:
# 检查输出块类型是否为 "OutConv1DBlock"
if out_block_type == "OutConv1DBlock":
# 创建并返回 OutConv1DBlock 实例
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
# 检查输出块类型是否为 "ValueFunction"
elif out_block_type == "ValueFunction":
# 创建并返回 OutValueFunctionBlock 实例
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
# 如果输出块类型不匹配,返回 None
return None